Compare commits

...

10 commits

Author SHA1 Message Date
82accabc73 fix: properly log settings as structlog kwargs instead of dropping them 2026-01-16 00:40:53 +05:30
6f9b545c15 fix: remove invalid positional arg from structlog call 2026-01-16 00:38:41 +05:30
5bcecf4649 fix: use structlog instead of logging in mem0_manager for kwargs support 2026-01-16 00:34:51 +05:30
638a591dc5 add setup script with volume reset option, fix default embedding dims 2026-01-16 00:29:29 +05:30
a190527076 fix: use npm_network for NPM proxy, expose instead of ports 2026-01-16 00:01:41 +05:30
9e86c30548 fix: pass OLLAMA_BASE_URL, EMBEDDING_MODEL, EMBEDDING_DIMS to container 2026-01-15 23:55:44 +05:30
2c1d73a1ec add OpenAI-compatible endpoint and improved login UI
- Add /v1/chat/completions and /chat/completions endpoints (OpenAI SDK compatible)
- Add streaming support with SSE for chat completions
- Add get_current_user_openai auth supporting Bearer token and X-API-Key
- Add OpenAI-compatible request/response models (OpenAIChatCompletionRequest, etc.)
- Cherry-pick improved login UI from cloud branch (styled login screen, logout button)
2026-01-15 23:29:08 +05:30
a228780146 production improvements: configurable embeddings, v1.1, O(1) ownership, retries
- Make Ollama URL configurable via OLLAMA_BASE_URL env var
- Add version: v1.1 to Mem0 config (required for latest features)
- Make embedding model and dimensions configurable
- Fix ownership check: O(1) lookup instead of fetching 10k records
- Add tenacity retry logic for database operations
2026-01-15 23:01:18 +05:30
50edce2d3c security hardening: add auth, rate limiting, fix info disclosure
- Add auth to /models and /users endpoints
- Add rate limiting to all endpoints (10-120/min based on operation type)
- Fix 11 info disclosure issues (detail=str(e) -> generic message)
- Fix 2 silent except blocks with proper logging
- Fix 7 raise e -> raise for proper exception chaining
- Fix health check to not expose exception details
- Update tests with X-API-Key headers and security tests
2026-01-15 22:41:24 +05:30
35c1bbec4e added MCP HTTP endpoint with auth
Exposes memory operations as MCP tools over /mcp endpoint:
- add_memory, search_memory, remove_memory, chat
- API key auth via x-api-key or Authorization header
- User isolation enforced via contextvars
2026-01-11 14:00:16 +05:30
11 changed files with 2088 additions and 522 deletions

View file

@ -1,7 +1,7 @@
"""Simple API key authentication for Mem0 Interface.""" """Simple API key authentication for Mem0 Interface."""
from typing import Optional from typing import Optional
from fastapi import HTTPException, Security, status from fastapi import HTTPException, Security, status, Header
from fastapi.security import APIKeyHeader from fastapi.security import APIKeyHeader
import structlog import structlog
@ -19,7 +19,9 @@ class AuthService:
def __init__(self): def __init__(self):
"""Initialize auth service with API key to user mapping.""" """Initialize auth service with API key to user mapping."""
self.api_key_to_user = settings.api_key_mapping self.api_key_to_user = settings.api_key_mapping
logger.info(f"Auth service initialized with {len(self.api_key_to_user)} API keys") logger.info(
f"Auth service initialized with {len(self.api_key_to_user)} API keys"
)
def verify_api_key(self, api_key: str) -> str: def verify_api_key(self, api_key: str) -> str:
""" """
@ -37,8 +39,7 @@ class AuthService:
if api_key not in self.api_key_to_user: if api_key not in self.api_key_to_user:
logger.warning(f"Invalid API key attempted: {api_key[:10]}...") logger.warning(f"Invalid API key attempted: {api_key[:10]}...")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
detail="Invalid API key"
) )
user_id = self.api_key_to_user[api_key] user_id = self.api_key_to_user[api_key]
@ -68,7 +69,7 @@ class AuthService:
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied: You can only access your own memories" detail=f"Access denied: You can only access your own memories",
) )
return authenticated_user_id return authenticated_user_id
@ -91,9 +92,46 @@ async def get_current_user(api_key: str = Security(api_key_header)) -> str:
return auth_service.verify_api_key(api_key) return auth_service.verify_api_key(api_key)
async def get_current_user_openai(
authorization: Optional[str] = Header(None),
x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> str:
"""
FastAPI dependency for OpenAI-compatible authentication.
Supports both Authorization: Bearer and X-API-Key headers.
Args:
authorization: Authorization header (Bearer token)
x_api_key: X-API-Key header
Returns:
str: Authenticated user_id
Raises:
HTTPException: If no valid API key is provided
"""
api_key = None
# Try Bearer token first (OpenAI standard)
if authorization and authorization.startswith("Bearer "):
api_key = authorization[7:] # Remove "Bearer " prefix
logger.debug("Extracted API key from Authorization Bearer token")
# Fall back to X-API-Key header
elif x_api_key:
api_key = x_api_key
logger.debug("Extracted API key from X-API-Key header")
else:
logger.warning("No API key provided in Authorization or X-API-Key headers")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing API key. Provide either 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header",
)
return auth_service.verify_api_key(api_key)
async def verify_user_access( async def verify_user_access(
api_key: str = Security(api_key_header), api_key: str = Security(api_key_header), user_id: Optional[str] = None
user_id: Optional[str] = None
) -> str: ) -> str:
""" """
FastAPI dependency to verify user can access the requested user_id. FastAPI dependency to verify user can access the requested user_id.
@ -114,7 +152,7 @@ async def verify_user_access(
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied: You can only access your own memories" detail="Access denied: You can only access your own memories",
) )
return authenticated_user_id return authenticated_user_id

View file

@ -11,39 +11,87 @@ class Settings(BaseSettings):
"""Application settings loaded from environment variables.""" """Application settings loaded from environment variables."""
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
env_file=".env", env_file=".env", case_sensitive=False, extra="ignore"
case_sensitive=False,
extra='ignore'
) )
# API Configuration # API Configuration
# Accept both OPENAI_API_KEY (from docker-compose) and OPENAI_COMPAT_API_KEY (from direct .env) # Accept both OPENAI_API_KEY (from docker-compose) and OPENAI_COMPAT_API_KEY (from direct .env)
openai_api_key: str = Field(validation_alias=AliasChoices('OPENAI_API_KEY', 'OPENAI_COMPAT_API_KEY', 'openai_api_key')) openai_api_key: str = Field(
openai_base_url: str = Field(validation_alias=AliasChoices('OPENAI_BASE_URL', 'OPENAI_COMPAT_BASE_URL', 'openai_base_url')) validation_alias=AliasChoices(
cohere_api_key: str = Field(validation_alias=AliasChoices('COHERE_API_KEY', 'cohere_api_key')) "OPENAI_API_KEY", "OPENAI_COMPAT_API_KEY", "openai_api_key"
)
)
openai_base_url: str = Field(
validation_alias=AliasChoices(
"OPENAI_BASE_URL", "OPENAI_COMPAT_BASE_URL", "openai_base_url"
)
)
cohere_api_key: str = Field(
validation_alias=AliasChoices("COHERE_API_KEY", "cohere_api_key")
)
# Database Configuration # Database Configuration
qdrant_host: str = Field(default="localhost", validation_alias=AliasChoices('QDRANT_HOST', 'qdrant_host')) qdrant_host: str = Field(
qdrant_port: int = Field(default=6333, validation_alias=AliasChoices('QDRANT_PORT', 'qdrant_port')) default="localhost", validation_alias=AliasChoices("QDRANT_HOST", "qdrant_host")
qdrant_collection_name: str = Field(default="mem0", validation_alias=AliasChoices('QDRANT_COLLECTION_NAME', 'qdrant_collection_name')) )
qdrant_port: int = Field(
default=6333, validation_alias=AliasChoices("QDRANT_PORT", "qdrant_port")
)
qdrant_collection_name: str = Field(
default="mem0",
validation_alias=AliasChoices(
"QDRANT_COLLECTION_NAME", "qdrant_collection_name"
),
)
# Neo4j Configuration # Neo4j Configuration
neo4j_uri: str = Field(default="bolt://localhost:7687", validation_alias=AliasChoices('NEO4J_URI', 'neo4j_uri')) neo4j_uri: str = Field(
neo4j_username: str = Field(default="neo4j", validation_alias=AliasChoices('NEO4J_USERNAME', 'neo4j_username')) default="bolt://localhost:7687",
neo4j_password: str = Field(default="mem0_neo4j_password", validation_alias=AliasChoices('NEO4J_PASSWORD', 'neo4j_password')) validation_alias=AliasChoices("NEO4J_URI", "neo4j_uri"),
)
neo4j_username: str = Field(
default="neo4j",
validation_alias=AliasChoices("NEO4J_USERNAME", "neo4j_username"),
)
neo4j_password: str = Field(
default="mem0_neo4j_password",
validation_alias=AliasChoices("NEO4J_PASSWORD", "neo4j_password"),
)
# Application Configuration # Application Configuration
log_level: str = Field(default="INFO", validation_alias=AliasChoices('LOG_LEVEL', 'log_level')) log_level: str = Field(
cors_origins: str = Field(default="http://localhost:3000", validation_alias=AliasChoices('CORS_ORIGINS', 'cors_origins')) default="INFO", validation_alias=AliasChoices("LOG_LEVEL", "log_level")
)
cors_origins: str = Field(
default="http://localhost:3000",
validation_alias=AliasChoices("CORS_ORIGINS", "cors_origins"),
)
# Model Configuration - Ultra-minimal (single model) # Model Configuration - Ultra-minimal (single model)
default_model: str = Field(default="claude-sonnet-4", validation_alias=AliasChoices('DEFAULT_MODEL', 'default_model')) default_model: str = Field(
default="claude-sonnet-4",
validation_alias=AliasChoices("DEFAULT_MODEL", "default_model"),
)
# Embedder Configuration
ollama_base_url: str = Field(
default="http://host.docker.internal:11434",
validation_alias=AliasChoices("OLLAMA_BASE_URL", "ollama_base_url"),
)
embedding_model: str = Field(
default="qwen3-embedding:4b-q8_0",
validation_alias=AliasChoices("EMBEDDING_MODEL", "embedding_model"),
)
embedding_dims: int = Field(
default=2560, validation_alias=AliasChoices("EMBEDDING_DIMS", "embedding_dims")
)
# Authentication Configuration # Authentication Configuration
# Format: JSON string mapping API keys to user IDs # Format: JSON string mapping API keys to user IDs
# Example: {"api_key_123": "alice", "api_key_456": "bob"} # Example: {"api_key_123": "alice", "api_key_456": "bob"}
api_keys: str = Field(default="{}", validation_alias=AliasChoices('API_KEYS', 'api_keys')) api_keys: str = Field(
default="{}", validation_alias=AliasChoices("API_KEYS", "api_keys")
)
@property @property
def cors_origins_list(self) -> List[str]: def cors_origins_list(self) -> List[str]:

File diff suppressed because it is too large Load diff

240
backend/mcp_server.py Normal file
View file

@ -0,0 +1,240 @@
"""MCP Server for Mem0 Memory Service.
Exposes memory operations as MCP tools over HTTP with API key authentication.
"""
import contextlib
import logging
from contextvars import ContextVar
from typing import Optional
from pydantic import Field
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse
from starlette.routing import Mount
from mcp.server.fastmcp import FastMCP
from config import settings
logger = logging.getLogger(__name__)
# Context variable for authenticated user_id
# Set by middleware, read by tools
current_user_id: ContextVar[str] = ContextVar("current_user_id", default="")
def get_authenticated_user() -> str:
"""Get the authenticated user_id from context.
Raises:
ValueError: If no authenticated user in context.
"""
user_id = current_user_id.get()
if not user_id:
raise ValueError("No authenticated user in context")
return user_id
class MCPAuthMiddleware(BaseHTTPMiddleware):
"""Middleware to authenticate MCP requests using API key."""
async def dispatch(self, request, call_next):
# Extract API key from headers
api_key = (
request.headers.get("x-api-key") or
request.headers.get("X-API-Key") or
request.headers.get("Authorization", "").replace("Bearer ", "")
)
if not api_key:
return JSONResponse(
{"error": "Missing API key. Provide x-api-key or Authorization header."},
status_code=401
)
# Map API key to user_id
user_id = settings.api_key_mapping.get(api_key)
if not user_id:
return JSONResponse(
{"error": "Invalid API key"},
status_code=401
)
# Store user_id in context for tools to access
token = current_user_id.set(user_id)
try:
response = await call_next(request)
return response
finally:
current_user_id.reset(token)
# Create FastMCP server
# streamable_http_path="/" since we mount at /mcp in main.py
mcp = FastMCP(
"Mem0 Memory Service",
stateless_http=True,
json_response=True,
streamable_http_path="/"
)
@mcp.tool()
async def add_memory(
content: str = Field(description="Content to add to memory"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
metadata: Optional[dict] = Field(default=None, description="Optional metadata to attach to the memory"),
) -> dict:
"""Add content to the user's memory.
The user_id is automatically determined from the API key authentication.
Use agent_id and run_id for multi-agent or session-based memory organization.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP add_memory: user={user_id}, agent={agent_id}, run={run_id}")
result = await mem0_manager.add_memories(
messages=[{"role": "user", "content": content}],
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
metadata=metadata
)
return result
@mcp.tool()
async def search_memory(
query: str = Field(description="Search query to find relevant memories"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier to filter memories"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier to filter memories"),
limit: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return"),
) -> dict:
"""Search the user's memories.
Returns memories most relevant to the query. The user_id is automatically
determined from API key authentication.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP search_memory: user={user_id}, query={query[:50]}..., limit={limit}")
result = await mem0_manager.search_memories(
query=query,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=limit
)
return result
@mcp.tool()
async def remove_memory(
memory_id: str = Field(description="The ID of the memory to remove"),
) -> dict:
"""Remove a specific memory by its ID.
Only memories belonging to the authenticated user can be deleted.
Verifies ownership before deletion.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP remove_memory: user={user_id}, memory_id={memory_id}")
# Verify ownership: get user's memories and check if memory_id exists
user_memories = await mem0_manager.get_user_memories(
user_id=user_id,
limit=10000 # Get all to check ownership
)
memory_ids = {m.get("id") for m in user_memories if m.get("id")}
if memory_id not in memory_ids:
raise ValueError(f"Memory '{memory_id}' not found or access denied")
result = await mem0_manager.delete_memory(memory_id=memory_id)
return result
@mcp.tool()
async def chat(
message: str = Field(description="The user's message to chat with"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
) -> str:
"""Chat with memory context.
Retrieves relevant memories based on the message, generates a response
using the configured LLM, and stores the conversation in memory.
The user_id is automatically determined from API key authentication.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP chat: user={user_id}, agent={agent_id}, message={message[:50]}...")
result = await mem0_manager.chat_with_memory(
message=message,
user_id=user_id,
agent_id=agent_id,
run_id=run_id
)
return result.get("response", "")
@contextlib.asynccontextmanager
async def mcp_lifespan():
"""Context manager for MCP session lifecycle.
Must be used in the main FastAPI lifespan since mounted app
lifespans don't run automatically.
"""
async with mcp.session_manager.run():
logger.info("MCP session manager started")
yield
logger.info("MCP session manager stopped")
def create_mcp_app() -> Starlette:
"""Create and configure the MCP Starlette application.
Returns a Starlette app with MCP endpoints, authentication middleware,
and CORS support.
Note: The MCP session manager must be started via mcp_lifespan() in
the main FastAPI lifespan, not here.
"""
# Get the StreamableHTTP app - it handles requests at "/" since we mount at /mcp
streamable_app = mcp.streamable_http_app()
# Create Starlette app with MCP routes
app = Starlette(
routes=[Mount("/", app=streamable_app)],
)
# Add authentication middleware
app.add_middleware(MCPAuthMiddleware)
# Add CORS middleware for browser clients
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False,
allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
allow_headers=["*"],
expose_headers=["Mcp-Session-Id"],
)
return app

View file

@ -5,24 +5,48 @@ from typing import Dict, List, Optional, Any
from datetime import datetime from datetime import datetime
from mem0 import Memory from mem0 import Memory
from openai import OpenAI from openai import OpenAI
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
before_sleep_log,
)
import structlog
from config import settings from config import settings
from monitoring import timed from monitoring import timed
logger = logging.getLogger(__name__) logger = structlog.get_logger(__name__)
# Retry decorator for database operations (Qdrant, Neo4j)
db_retry = retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((ConnectionError, TimeoutError, OSError)),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
# Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility # Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility
from mem0.llms.openai import OpenAILLM from mem0.llms.openai import OpenAILLM
_original_generate_response = OpenAILLM.generate_response _original_generate_response = OpenAILLM.generate_response
def patched_generate_response(self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs):
def patched_generate_response(
self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs
):
# Remove 'store' parameter as LiteLLM doesn't support it # Remove 'store' parameter as LiteLLM doesn't support it
if hasattr(self.config, 'store'): if hasattr(self.config, "store"):
self.config.store = None self.config.store = None
# Remove 'top_p' to avoid conflict with temperature for Claude models # Remove 'top_p' to avoid conflict with temperature for Claude models
if hasattr(self.config, 'top_p'): if hasattr(self.config, "top_p"):
self.config.top_p = None self.config.top_p = None
return _original_generate_response(self, messages, response_format, tools, tool_choice, **kwargs) return _original_generate_response(
self, messages, response_format, tools, tool_choice, **kwargs
)
OpenAILLM.generate_response = patched_generate_response OpenAILLM.generate_response = patched_generate_response
logger.info("Applied LiteLLM compatibility patch: disabled 'store' parameter") logger.info("Applied LiteLLM compatibility patch: disabled 'store' parameter")
@ -33,11 +57,19 @@ class Mem0Manager:
Ultra-minimal manager that bridges custom OpenAI endpoint with pure Mem0. Ultra-minimal manager that bridges custom OpenAI endpoint with pure Mem0.
No custom logic - let Mem0 handle all memory intelligence. No custom logic - let Mem0 handle all memory intelligence.
""" """
def __init__(self): def __init__(self):
# Custom endpoint configuration with graph memory enabled # Custom endpoint configuration with graph memory enabled
logger.info("Initializing ultra-minimal Mem0Manager with custom endpoint with settings:", settings) logger.info(
"Initializing Mem0Manager with custom endpoint",
model=settings.default_model,
embedding_model=settings.embedding_model,
embedding_dims=settings.embedding_dims,
qdrant_host=settings.qdrant_host,
neo4j_uri=settings.neo4j_uri,
)
config = { config = {
"version": "v1.1",
"enable_graph": True, "enable_graph": True,
"llm": { "llm": {
"provider": "openai", "provider": "openai",
@ -46,17 +78,16 @@ class Mem0Manager:
"api_key": settings.openai_api_key, "api_key": settings.openai_api_key,
"openai_base_url": settings.openai_base_url, "openai_base_url": settings.openai_base_url,
"temperature": 0.1, "temperature": 0.1,
"top_p": None # Don't use top_p with Claude models "top_p": None,
} },
}, },
"embedder": { "embedder": {
"provider": "ollama", "provider": "ollama",
"config": { "config": {
"model": "qwen3-embedding:4b-q8_0", "model": settings.embedding_model,
# "api_key": settings.embedder_api_key, "ollama_base_url": settings.ollama_base_url,
"ollama_base_url": "http://172.17.0.1:11434", "embedding_dims": settings.embedding_dims,
"embedding_dims": 2560 },
}
}, },
"vector_store": { "vector_store": {
"provider": "qdrant", "provider": "qdrant",
@ -64,38 +95,39 @@ class Mem0Manager:
"collection_name": settings.qdrant_collection_name, "collection_name": settings.qdrant_collection_name,
"host": settings.qdrant_host, "host": settings.qdrant_host,
"port": settings.qdrant_port, "port": settings.qdrant_port,
"embedding_model_dims": 2560, "embedding_model_dims": settings.embedding_dims,
"on_disk": True "on_disk": True,
} },
}, },
"graph_store": { "graph_store": {
"provider": "neo4j", "provider": "neo4j",
"config": { "config": {
"url": settings.neo4j_uri, "url": settings.neo4j_uri,
"username": settings.neo4j_username, "username": settings.neo4j_username,
"password": settings.neo4j_password "password": settings.neo4j_password,
} },
}, },
"reranker": { "reranker": {
"provider": "cohere", "provider": "cohere",
"config": { "config": {
"api_key": settings.cohere_api_key, "api_key": settings.cohere_api_key,
"model": "rerank-english-v3.0", "model": "rerank-english-v3.0",
"top_n": 10 "top_n": 10,
} },
} },
} }
self.memory = Memory.from_config(config) self.memory = Memory.from_config(config)
self.openai_client = OpenAI( self.openai_client = OpenAI(
api_key=settings.openai_api_key, api_key=settings.openai_api_key,
base_url=settings.openai_base_url base_url=settings.openai_base_url,
timeout=60.0, # 60 second timeout for LLM calls
max_retries=2, # Retry failed requests up to 2 times
) )
logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint") logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")
# Pure passthrough methods - no custom logic # Pure passthrough methods - no custom logic
@db_retry
@timed("add_memories") @timed("add_memories")
async def add_memories( async def add_memories(
self, self,
@ -103,18 +135,18 @@ class Mem0Manager:
user_id: Optional[str] = "default", user_id: Optional[str] = "default",
agent_id: Optional[str] = None, agent_id: Optional[str] = None,
run_id: Optional[str] = None, run_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Add memories - simplified native Mem0 pattern (10 lines vs 45).""" """Add memories - simplified native Mem0 pattern (10 lines vs 45)."""
try: try:
# Convert ChatMessage objects to dict if needed # Convert ChatMessage objects to dict if needed
formatted_messages = [] formatted_messages = []
for msg in messages: for msg in messages:
if hasattr(msg, 'dict'): if hasattr(msg, "dict"):
formatted_messages.append(msg.dict()) formatted_messages.append(msg.dict())
else: else:
formatted_messages.append(msg) formatted_messages.append(msg)
# Auto-enhance metadata for better memory quality # Auto-enhance metadata for better memory quality
combined_metadata = metadata or {} combined_metadata = metadata or {}
@ -123,26 +155,35 @@ class Mem0Manager:
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"source": "chat_conversation", "source": "chat_conversation",
"message_count": len(formatted_messages), "message_count": len(formatted_messages),
"auto_generated": True "auto_generated": True,
} }
# Merge user metadata with auto metadata (user metadata takes precedence) # Merge user metadata with auto metadata (user metadata takes precedence)
enhanced_metadata = {**auto_metadata, **combined_metadata} enhanced_metadata = {**auto_metadata, **combined_metadata}
# Direct Mem0 add with enhanced metadata # Direct Mem0 add with enhanced metadata
result = self.memory.add(formatted_messages, user_id=user_id, result = self.memory.add(
agent_id=agent_id, run_id=run_id, formatted_messages,
metadata=enhanced_metadata) user_id=user_id,
agent_id=agent_id,
run_id=run_id,
metadata=enhanced_metadata,
)
return { return {
"added_memories": result if isinstance(result, list) else [result], "added_memories": result if isinstance(result, list) else [result],
"message": "Memories added successfully", "message": "Memories added successfully",
"hierarchy": {"user_id": user_id, "agent_id": agent_id, "run_id": run_id} "hierarchy": {
"user_id": user_id,
"agent_id": agent_id,
"run_id": run_id,
},
} }
except Exception as e: except Exception as e:
logger.error(f"Error adding memories: {e}") logger.error(f"Error adding memories: {e}")
raise e raise
@db_retry
@timed("search_memories") @timed("search_memories")
async def search_memories( async def search_memories(
self, self,
@ -155,37 +196,79 @@ class Mem0Manager:
# rerank: bool = False, # rerank: bool = False,
# filter_memories: bool = False, # filter_memories: bool = False,
agent_id: Optional[str] = None, agent_id: Optional[str] = None,
run_id: Optional[str] = None run_id: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Search memories - native Mem0 pattern""" """Search memories - native Mem0 pattern"""
try: try:
# Minimal empty query protection for API compatibility # Minimal empty query protection for API compatibility
if not query or query.strip() == "": if not query or query.strip() == "":
return {"memories": [], "total_count": 0, "query": query, "note": "Empty query provided, no results returned. Use a specific query to search memories."} return {
"memories": [],
"total_count": 0,
"query": query,
"note": "Empty query provided, no results returned. Use a specific query to search memories.",
}
# Direct Mem0 search - trust native handling # Direct Mem0 search - trust native handling
result = self.memory.search(query=query, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit, threshold=threshold, filters=filters) result = self.memory.search(
return {"memories": result.get("results", []), "total_count": len(result.get("results", [])), "query": query} query=query,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=limit,
threshold=threshold,
filters=filters,
)
return {
"memories": result.get("results", []),
"total_count": len(result.get("results", [])),
"query": query,
}
except Exception as e: except Exception as e:
logger.error(f"Error searching memories: {e}") logger.error(f"Error searching memories: {e}")
raise e raise
@db_retry
async def get_user_memories( async def get_user_memories(
self, self,
user_id: str, user_id: str,
limit: int = 10, limit: int = 10,
agent_id: Optional[str] = None, agent_id: Optional[str] = None,
run_id: Optional[str] = None, run_id: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None filters: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Get all memories for a user - native Mem0 pattern.""" """Get all memories for a user - native Mem0 pattern."""
try: try:
# Direct Mem0 get_all call - trust native parameter handling # Direct Mem0 get_all call - trust native parameter handling
result = self.memory.get_all(user_id=user_id, limit=limit, agent_id=agent_id, run_id=run_id, filters=filters) result = self.memory.get_all(
user_id=user_id,
limit=limit,
agent_id=agent_id,
run_id=run_id,
filters=filters,
)
return result.get("results", []) return result.get("results", [])
except Exception as e: except Exception as e:
logger.error(f"Error getting user memories: {e}") logger.error(f"Error getting user memories: {e}")
raise e raise
@db_retry
async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
"""Get a single memory by ID. Returns None if not found."""
try:
result = self.memory.get(memory_id=memory_id)
return result
except Exception as e:
logger.debug(f"Memory {memory_id} not found or error: {e}")
return None
async def verify_memory_ownership(self, memory_id: str, user_id: str) -> bool:
"""Check if a memory belongs to a user. O(1) instead of O(n)."""
memory = await self.get_memory(memory_id)
if memory is None:
return False
return memory.get("user_id") == user_id
@db_retry
@timed("update_memory") @timed("update_memory")
async def update_memory( async def update_memory(
self, self,
@ -194,15 +277,13 @@ class Mem0Manager:
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Update memory - pure Mem0 passthrough.""" """Update memory - pure Mem0 passthrough."""
try: try:
result = self.memory.update( result = self.memory.update(memory_id=memory_id, data=content)
memory_id=memory_id,
data=content
)
return {"message": "Memory updated successfully", "result": result} return {"message": "Memory updated successfully", "result": result}
except Exception as e: except Exception as e:
logger.error(f"Error updating memory: {e}") logger.error(f"Error updating memory: {e}")
raise e raise
@db_retry
@timed("delete_memory") @timed("delete_memory")
async def delete_memory(self, memory_id: str) -> Dict[str, Any]: async def delete_memory(self, memory_id: str) -> Dict[str, Any]:
"""Delete memory - pure Mem0 passthrough.""" """Delete memory - pure Mem0 passthrough."""
@ -211,7 +292,7 @@ class Mem0Manager:
return {"message": "Memory deleted successfully"} return {"message": "Memory deleted successfully"}
except Exception as e: except Exception as e:
logger.error(f"Error deleting memory: {e}") logger.error(f"Error deleting memory: {e}")
raise e raise
async def delete_user_memories(self, user_id: Optional[str]) -> Dict[str, Any]: async def delete_user_memories(self, user_id: Optional[str]) -> Dict[str, Any]:
"""Delete all user memories - pure Mem0 passthrough.""" """Delete all user memories - pure Mem0 passthrough."""
@ -220,8 +301,8 @@ class Mem0Manager:
return {"message": "All user memories deleted successfully"} return {"message": "All user memories deleted successfully"}
except Exception as e: except Exception as e:
logger.error(f"Error deleting user memories: {e}") logger.error(f"Error deleting user memories: {e}")
raise e raise
async def get_memory_history(self, memory_id: str) -> Dict[str, Any]: async def get_memory_history(self, memory_id: str) -> Dict[str, Any]:
"""Get memory change history - pure Mem0 passthrough.""" """Get memory change history - pure Mem0 passthrough."""
try: try:
@ -229,42 +310,44 @@ class Mem0Manager:
return { return {
"memory_id": memory_id, "memory_id": memory_id,
"history": history, "history": history,
"message": "Memory history retrieved successfully" "message": "Memory history retrieved successfully",
} }
except Exception as e: except Exception as e:
logger.error(f"Error getting memory history: {e}") logger.error(f"Error getting memory history: {e}")
raise e raise
async def get_graph_relationships(
async def get_graph_relationships(self, user_id: Optional[str], agent_id: Optional[str], run_id: Optional[str], limit: int = 50) -> Dict[str, Any]: self,
user_id: Optional[str],
agent_id: Optional[str],
run_id: Optional[str],
limit: int = 50,
) -> Dict[str, Any]:
"""Get graph relationships - using correct Mem0 get_all() method.""" """Get graph relationships - using correct Mem0 get_all() method."""
try: try:
# Use get_all() to retrieve memories with graph relationships # Use get_all() to retrieve memories with graph relationships
result = self.memory.get_all( result = self.memory.get_all(
user_id=user_id, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit
agent_id=agent_id,
run_id=run_id,
limit=limit
) )
# Extract relationships from Mem0's response structure # Extract relationships from Mem0's response structure
relationships = result.get("relations", []) relationships = result.get("relations", [])
# For entities, we can derive them from memory results or relations # For entities, we can derive them from memory results or relations
entities = [] entities = []
if "results" in result: if "results" in result:
# Extract unique entities from memories and relationships # Extract unique entities from memories and relationships
entity_set = set() entity_set = set()
# Add entities from relationships # Add entities from relationships
for rel in relationships: for rel in relationships:
if "source" in rel: if "source" in rel:
entity_set.add(rel["source"]) entity_set.add(rel["source"])
if "target" in rel: if "target" in rel:
entity_set.add(rel["target"]) entity_set.add(rel["target"])
entities = [{"name": entity} for entity in entity_set] entities = [{"name": entity} for entity in entity_set]
return { return {
"relationships": relationships, "relationships": relationships,
"entities": entities, "entities": entities,
@ -272,9 +355,9 @@ class Mem0Manager:
"agent_id": agent_id, "agent_id": agent_id,
"run_id": run_id, "run_id": run_id,
"total_memories": len(result.get("results", [])), "total_memories": len(result.get("results", [])),
"total_relationships": len(relationships) "total_relationships": len(relationships),
} }
except Exception as e: except Exception as e:
logger.error(f"Error getting graph relationships: {e}") logger.error(f"Error getting graph relationships: {e}")
# Return empty but structured response on error # Return empty but structured response on error
@ -286,9 +369,9 @@ class Mem0Manager:
"run_id": run_id, "run_id": run_id,
"total_memories": 0, "total_memories": 0,
"total_relationships": 0, "total_relationships": 0,
"error": str(e) "error": str(e),
} }
@timed("chat_with_memory") @timed("chat_with_memory")
async def chat_with_memory( async def chat_with_memory(
self, self,
@ -301,53 +384,74 @@ class Mem0Manager:
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Chat with memory - native Mem0 pattern with detailed timing.""" """Chat with memory - native Mem0 pattern with detailed timing."""
import time import time
try: try:
total_start_time = time.time() total_start_time = time.time()
print(f"\n🚀 Starting chat request for user: {user_id}") logger.info("Starting chat request", user_id=user_id)
# Stage 1: Memory Search
search_start_time = time.time() search_start_time = time.time()
search_result = self.memory.search(query=message, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=10, threshold=0.3) search_result = self.memory.search(
query=message,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=10,
threshold=0.3,
)
relevant_memories = search_result.get("results", []) relevant_memories = search_result.get("results", [])
memories_str = "\n".join(f"- {entry['memory']}" for entry in relevant_memories) memories_str = "\n".join(
f"- {entry['memory']}" for entry in relevant_memories
)
search_time = time.time() - search_start_time search_time = time.time() - search_start_time
print(f"🔍 Memory search took: {search_time:.2f}s (found {len(relevant_memories)} memories)") logger.debug(
"Memory search completed",
# Stage 2: Prepare LLM messages search_time_s=round(search_time, 2),
memories_found=len(relevant_memories),
)
prep_start_time = time.time() prep_start_time = time.time()
system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}" system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}"
messages = [{"role": "system", "content": system_prompt}] messages = [{"role": "system", "content": system_prompt}]
# Add conversation context if provided (last 50 messages)
if context: if context:
messages.extend(context) messages.extend(context)
print(f"📝 Added {len(context)} context messages") logger.debug("Added context messages", context_count=len(context))
# Add current user message
messages.append({"role": "user", "content": message}) messages.append({"role": "user", "content": message})
prep_time = time.time() - prep_start_time prep_time = time.time() - prep_start_time
print(f"📋 Message preparation took: {prep_time:.3f}s")
# Stage 3: LLM Call
llm_start_time = time.time() llm_start_time = time.time()
response = self.openai_client.chat.completions.create(model=settings.default_model, messages=messages) response = self.openai_client.chat.completions.create(
model=settings.default_model, messages=messages
)
assistant_response = response.choices[0].message.content assistant_response = response.choices[0].message.content
llm_time = time.time() - llm_start_time llm_time = time.time() - llm_start_time
print(f"🤖 LLM call took: {llm_time:.2f}s (model: {settings.default_model})") logger.debug(
"LLM call completed",
# Stage 4: Memory Add llm_time_s=round(llm_time, 2),
model=settings.default_model,
)
add_start_time = time.time() add_start_time = time.time()
memory_messages = [{"role": "user", "content": message}, {"role": "assistant", "content": assistant_response}] memory_messages = [
{"role": "user", "content": message},
{"role": "assistant", "content": assistant_response},
]
self.memory.add(memory_messages, user_id=user_id) self.memory.add(memory_messages, user_id=user_id)
add_time = time.time() - add_start_time add_time = time.time() - add_start_time
print(f"💾 Memory add took: {add_time:.2f}s")
# Total timing summary
total_time = time.time() - total_start_time total_time = time.time() - total_start_time
print(f"⏱️ TOTAL: {total_time:.2f}s | Search: {search_time:.2f}s | LLM: {llm_time:.2f}s | Add: {add_time:.2f}s | Prep: {prep_time:.3f}s") logger.info(
print(f"📊 Breakdown: Search {(search_time/total_time)*100:.1f}% | LLM {(llm_time/total_time)*100:.1f}% | Add {(add_time/total_time)*100:.1f}%\n") "Chat request completed",
user_id=user_id,
total_time_s=round(total_time, 2),
search_time_s=round(search_time, 2),
llm_time_s=round(llm_time, 2),
add_time_s=round(add_time, 2),
memories_used=len(relevant_memories),
model=settings.default_model,
)
return { return {
"response": assistant_response, "response": assistant_response,
"memories_used": len(relevant_memories), "memories_used": len(relevant_memories),
@ -356,37 +460,42 @@ class Mem0Manager:
"total": round(total_time, 2), "total": round(total_time, 2),
"search": round(search_time, 2), "search": round(search_time, 2),
"llm": round(llm_time, 2), "llm": round(llm_time, 2),
"add": round(add_time, 2) "add": round(add_time, 2),
} },
} }
except Exception as e: except Exception as e:
logger.error(f"Error in chat_with_memory: {e}") logger.error(
"Error in chat_with_memory",
error=str(e),
user_id=user_id,
exc_info=True,
)
return { return {
"error": str(e), "error": str(e),
"response": "I apologize, but I encountered an error processing your request.", "response": "I apologize, but I encountered an error processing your request.",
"memories_used": 0, "memories_used": 0,
"model_used": None "model_used": None,
} }
async def health_check(self) -> Dict[str, str]: async def health_check(self) -> Dict[str, str]:
"""Basic health check - just connectivity.""" """Basic health check - just connectivity."""
status = {} status = {}
# Check custom OpenAI endpoint # Check custom OpenAI endpoint
try: try:
models = self.openai_client.models.list() models = self.openai_client.models.list()
status["openai_endpoint"] = "healthy" status["openai_endpoint"] = "healthy"
except Exception as e: except Exception as e:
status["openai_endpoint"] = f"unhealthy: {str(e)}" status["openai_endpoint"] = f"unhealthy: {str(e)}"
# Check Mem0 memory # Check Mem0 memory
try: try:
self.memory.search(query="test", user_id="health_check", limit=1) self.memory.search(query="test", user_id="health_check", limit=1)
status["mem0_memory"] = "healthy" status["mem0_memory"] = "healthy"
except Exception as e: except Exception as e:
status["mem0_memory"] = f"unhealthy: {str(e)}" status["mem0_memory"] = f"unhealthy: {str(e)}"
return status return status

View file

@ -2,53 +2,115 @@
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import re
# Constants for input validation
MAX_MESSAGE_LENGTH = 50000 # ~12k tokens max per message
MAX_QUERY_LENGTH = 10000 # ~2.5k tokens max per query
MAX_USER_ID_LENGTH = 100 # Reasonable user ID length
MAX_MEMORY_ID_LENGTH = 100 # Memory IDs are typically UUIDs
MAX_CONTEXT_MESSAGES = 100 # Max conversation context messages
USER_ID_PATTERN = r"^[a-zA-Z0-9_\-\.@]+$" # Alphanumeric with common separators
# Request Models # Request Models
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
"""Chat message structure.""" """Chat message structure."""
role: str = Field(..., description="Message role (user, assistant, system)")
content: str = Field(..., description="Message content") role: str = Field(
..., max_length=20, description="Message role (user, assistant, system)"
)
content: str = Field(
..., max_length=MAX_MESSAGE_LENGTH, description="Message content"
)
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
"""Ultra-minimal chat request.""" """Ultra-minimal chat request."""
message: str = Field(..., description="User message")
user_id: Optional[str] = Field("default", description="User identifier") message: str = Field(..., max_length=MAX_MESSAGE_LENGTH, description="User message")
agent_id: Optional[str] = Field(None, description="Agent identifier") user_id: Optional[str] = Field(
run_id: Optional[str] = Field(None, description="Run identifier") "default",
context: Optional[List[ChatMessage]] = Field(None, description="Previous conversation context") max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier (alphanumeric, _, -, ., @)",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
context: Optional[List[ChatMessage]] = Field(
None,
max_length=MAX_CONTEXT_MESSAGES,
description="Previous conversation context (max 100 messages)",
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata") metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
class MemoryAddRequest(BaseModel): class MemoryAddRequest(BaseModel):
"""Request to add memories with hierarchy support - open-source compatible.""" """Request to add memories with hierarchy support - open-source compatible."""
messages: List[ChatMessage] = Field(..., description="Messages to process")
user_id: Optional[str] = Field("default", description="User identifier") messages: List[ChatMessage] = Field(
agent_id: Optional[str] = Field(None, description="Agent identifier") ...,
run_id: Optional[str] = Field(None, description="Run identifier") max_length=MAX_CONTEXT_MESSAGES,
description="Messages to process (max 100 messages)",
)
user_id: Optional[str] = Field(
"default",
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata") metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
class MemorySearchRequest(BaseModel): class MemorySearchRequest(BaseModel):
"""Request to search memories with hierarchy filtering.""" """Request to search memories with hierarchy filtering."""
query: str = Field(..., description="Search query")
user_id: Optional[str] = Field("default", description="User identifier") query: str = Field(..., max_length=MAX_QUERY_LENGTH, description="Search query")
agent_id: Optional[str] = Field(None, description="Agent identifier") user_id: Optional[str] = Field(
run_id: Optional[str] = Field(None, description="Run identifier") "default",
limit: int = Field(5, description="Maximum number of results") max_length=MAX_USER_ID_LENGTH,
threshold: Optional[float] = Field(None, description="Minimum relevance score") pattern=USER_ID_PATTERN,
description="User identifier",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
limit: int = Field(5, ge=1, le=100, description="Maximum number of results (1-100)")
threshold: Optional[float] = Field(
None, ge=0.0, le=1.0, description="Minimum relevance score (0-1)"
)
filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters") filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters")
# Hierarchy filters (open-source compatible)
agent_id: Optional[str] = Field(None, description="Filter by agent identifier")
run_id: Optional[str] = Field(None, description="Filter by run identifier")
class MemoryUpdateRequest(BaseModel): class MemoryUpdateRequest(BaseModel):
"""Request to update a memory.""" """Request to update a memory."""
memory_id: str = Field(..., description="Memory ID to update")
content: str = Field(..., description="New memory content") memory_id: str = Field(
..., max_length=MAX_MEMORY_ID_LENGTH, description="Memory ID to update"
)
user_id: str = Field(
...,
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier for ownership verification",
)
content: str = Field(
..., max_length=MAX_MESSAGE_LENGTH, description="New memory content"
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Updated metadata") metadata: Optional[Dict[str, Any]] = Field(None, description="Updated metadata")
@ -57,19 +119,23 @@ class MemoryUpdateRequest(BaseModel):
class MemoryItem(BaseModel): class MemoryItem(BaseModel):
"""Individual memory item.""" """Individual memory item."""
id: str = Field(..., description="Memory unique identifier") id: str = Field(..., description="Memory unique identifier")
memory: str = Field(..., description="Memory content") memory: str = Field(..., description="Memory content")
user_id: Optional[str] = Field(None, description="Associated user ID") user_id: Optional[str] = Field(None, description="Associated user ID")
agent_id: Optional[str] = Field(None, description="Associated agent ID") agent_id: Optional[str] = Field(None, description="Associated agent ID")
run_id: Optional[str] = Field(None, description="Associated run ID") run_id: Optional[str] = Field(None, description="Associated run ID")
metadata: Optional[Dict[str, Any]] = Field(None, description="Memory metadata") metadata: Optional[Dict[str, Any]] = Field(None, description="Memory metadata")
score: Optional[float] = Field(None, description="Relevance score (for search results)") score: Optional[float] = Field(
None, description="Relevance score (for search results)"
)
created_at: Optional[str] = Field(None, description="Creation timestamp") created_at: Optional[str] = Field(None, description="Creation timestamp")
updated_at: Optional[str] = Field(None, description="Last update timestamp") updated_at: Optional[str] = Field(None, description="Last update timestamp")
class MemorySearchResponse(BaseModel): class MemorySearchResponse(BaseModel):
"""Memory search results - pure Mem0 structure.""" """Memory search results - pure Mem0 structure."""
memories: List[MemoryItem] = Field(..., description="Found memories") memories: List[MemoryItem] = Field(..., description="Found memories")
total_count: int = Field(..., description="Total number of memories found") total_count: int = Field(..., description="Total number of memories found")
query: str = Field(..., description="Original search query") query: str = Field(..., description="Original search query")
@ -77,27 +143,37 @@ class MemorySearchResponse(BaseModel):
class MemoryAddResponse(BaseModel): class MemoryAddResponse(BaseModel):
"""Response from adding memories - pure Mem0 structure.""" """Response from adding memories - pure Mem0 structure."""
added_memories: List[Dict[str, Any]] = Field(..., description="Memories that were added")
added_memories: List[Dict[str, Any]] = Field(
..., description="Memories that were added"
)
message: str = Field(..., description="Success message") message: str = Field(..., description="Success message")
class GraphRelationship(BaseModel): class GraphRelationship(BaseModel):
"""Graph relationship structure.""" """Graph relationship structure."""
source: str = Field(..., description="Source entity") source: str = Field(..., description="Source entity")
relationship: str = Field(..., description="Relationship type") relationship: str = Field(..., description="Relationship type")
target: str = Field(..., description="Target entity") target: str = Field(..., description="Target entity")
properties: Optional[Dict[str, Any]] = Field(None, description="Relationship properties") properties: Optional[Dict[str, Any]] = Field(
None, description="Relationship properties"
)
class GraphResponse(BaseModel): class GraphResponse(BaseModel):
"""Graph relationships - pure Mem0 structure.""" """Graph relationships - pure Mem0 structure."""
relationships: List[GraphRelationship] = Field(..., description="Found relationships")
relationships: List[GraphRelationship] = Field(
..., description="Found relationships"
)
entities: List[str] = Field(..., description="Unique entities") entities: List[str] = Field(..., description="Unique entities")
user_id: str = Field(..., description="User identifier") user_id: str = Field(..., description="User identifier")
class HealthResponse(BaseModel): class HealthResponse(BaseModel):
"""Health check response.""" """Health check response."""
status: str = Field(..., description="Service status") status: str = Field(..., description="Service status")
services: Dict[str, str] = Field(..., description="Individual service statuses") services: Dict[str, str] = Field(..., description="Individual service statuses")
timestamp: str = Field(..., description="Health check timestamp") timestamp: str = Field(..., description="Health check timestamp")
@ -105,6 +181,7 @@ class HealthResponse(BaseModel):
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
"""Error response structure.""" """Error response structure."""
error: str = Field(..., description="Error message") error: str = Field(..., description="Error message")
detail: Optional[str] = Field(None, description="Detailed error information") detail: Optional[str] = Field(None, description="Detailed error information")
status_code: int = Field(..., description="HTTP status code") status_code: int = Field(..., description="HTTP status code")
@ -112,8 +189,10 @@ class ErrorResponse(BaseModel):
# Statistics and Monitoring Models # Statistics and Monitoring Models
class MemoryOperationStats(BaseModel): class MemoryOperationStats(BaseModel):
"""Memory operation statistics.""" """Memory operation statistics."""
add: int = Field(..., description="Number of add operations") add: int = Field(..., description="Number of add operations")
search: int = Field(..., description="Number of search operations") search: int = Field(..., description="Number of search operations")
update: int = Field(..., description="Number of update operations") update: int = Field(..., description="Number of update operations")
@ -122,19 +201,111 @@ class MemoryOperationStats(BaseModel):
class GlobalStatsResponse(BaseModel): class GlobalStatsResponse(BaseModel):
"""Global application statistics.""" """Global application statistics."""
total_memories: int = Field(..., description="Total memories across all users") total_memories: int = Field(..., description="Total memories across all users")
total_users: int = Field(..., description="Total number of users") total_users: int = Field(..., description="Total number of users")
api_calls_today: int = Field(..., description="Total API calls today") api_calls_today: int = Field(..., description="Total API calls today")
avg_response_time_ms: float = Field(..., description="Average response time in milliseconds") avg_response_time_ms: float = Field(
memory_operations: MemoryOperationStats = Field(..., description="Memory operation breakdown") ..., description="Average response time in milliseconds"
)
memory_operations: MemoryOperationStats = Field(
..., description="Memory operation breakdown"
)
uptime_seconds: float = Field(..., description="Application uptime in seconds") uptime_seconds: float = Field(..., description="Application uptime in seconds")
class UserStatsResponse(BaseModel): class UserStatsResponse(BaseModel):
"""User-specific statistics.""" """User-specific statistics."""
user_id: str = Field(..., description="User identifier") user_id: str = Field(..., description="User identifier")
memory_count: int = Field(..., description="Number of memories for this user") memory_count: int = Field(..., description="Number of memories for this user")
relationship_count: int = Field(..., description="Number of graph relationships for this user") relationship_count: int = Field(
..., description="Number of graph relationships for this user"
)
last_activity: Optional[str] = Field(None, description="Last activity timestamp") last_activity: Optional[str] = Field(None, description="Last activity timestamp")
api_calls_today: int = Field(..., description="API calls made by this user today") api_calls_today: int = Field(..., description="API calls made by this user today")
avg_response_time_ms: float = Field(..., description="Average response time for this user's requests") avg_response_time_ms: float = Field(
..., description="Average response time for this user's requests"
)
# OpenAI-Compatible API Models
class OpenAIMessage(BaseModel):
"""OpenAI message format."""
role: str = Field(..., description="Message role (system, user, assistant)")
content: str = Field(..., description="Message content")
class OpenAIChatCompletionRequest(BaseModel):
"""OpenAI chat completion request format."""
model: str = Field(..., description="Model to use (will use configured default)")
messages: List[Dict[str, str]] = Field(..., description="List of messages")
temperature: Optional[float] = Field(0.7, description="Sampling temperature")
max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
stream: Optional[bool] = Field(False, description="Whether to stream responses")
top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
n: Optional[int] = Field(1, description="Number of completions to generate")
stop: Optional[List[str]] = Field(None, description="Stop sequences")
presence_penalty: Optional[float] = Field(0, description="Presence penalty")
frequency_penalty: Optional[float] = Field(0, description="Frequency penalty")
user: Optional[str] = Field(
None, description="User identifier (ignored, uses API key)"
)
class OpenAIUsage(BaseModel):
"""Token usage information."""
prompt_tokens: int = Field(..., description="Tokens in the prompt")
completion_tokens: int = Field(..., description="Tokens in the completion")
total_tokens: int = Field(..., description="Total tokens used")
class OpenAIChoiceMessage(BaseModel):
"""Message in a choice."""
role: str = Field(..., description="Role of the message")
content: str = Field(..., description="Content of the message")
class OpenAIChoice(BaseModel):
"""Individual completion choice."""
index: int = Field(..., description="Choice index")
message: OpenAIChoiceMessage = Field(..., description="Message content")
finish_reason: str = Field(..., description="Reason for completion finish")
class OpenAIChatCompletionResponse(BaseModel):
"""OpenAI chat completion response format."""
id: str = Field(..., description="Unique completion ID")
object: str = Field(default="chat.completion", description="Object type")
created: int = Field(..., description="Unix timestamp of creation")
model: str = Field(..., description="Model used for completion")
choices: List[OpenAIChoice] = Field(..., description="List of completion choices")
usage: Optional[OpenAIUsage] = Field(None, description="Token usage information")
# Streaming-specific models
class OpenAIStreamDelta(BaseModel):
"""Delta content in a streaming chunk."""
role: Optional[str] = Field(None, description="Role (only in first chunk)")
content: Optional[str] = Field(None, description="Incremental content")
class OpenAIStreamChoice(BaseModel):
"""Individual streaming choice."""
index: int = Field(..., description="Choice index")
delta: OpenAIStreamDelta = Field(..., description="Delta content")
finish_reason: Optional[str] = Field(
None, description="Reason for completion finish"
)

View file

@ -19,6 +19,7 @@ ollama
# Utilities # Utilities
pydantic pydantic
pydantic-settings pydantic-settings
tenacity
python-dotenv python-dotenv
httpx httpx
aiofiles aiofiles
@ -31,3 +32,9 @@ python-json-logger
# CORS and Security # CORS and Security
python-jose[cryptography] python-jose[cryptography]
passlib[bcrypt] passlib[bcrypt]
# Rate Limiting
slowapi
# MCP Server
mcp[server]>=1.0.0

View file

@ -5,6 +5,8 @@ services:
container_name: mem0-qdrant container_name: mem0-qdrant
expose: expose:
- "6333" - "6333"
networks:
- mem0_network
volumes: volumes:
- qdrant_data:/qdrant/storage - qdrant_data:/qdrant/storage
command: > command: >
@ -32,6 +34,8 @@ services:
expose: expose:
- "7474" # HTTP - Internal only - "7474" # HTTP - Internal only
- "7687" # Bolt - Internal only - "7687" # Bolt - Internal only
networks:
- mem0_network
volumes: volumes:
- neo4j_data:/data - neo4j_data:/data
- neo4j_logs:/logs - neo4j_logs:/logs
@ -65,8 +69,14 @@ services:
CORS_ORIGINS: ${CORS_ORIGINS:-http://localhost:3000} CORS_ORIGINS: ${CORS_ORIGINS:-http://localhost:3000}
DEFAULT_MODEL: ${DEFAULT_MODEL:-claude-sonnet-4} DEFAULT_MODEL: ${DEFAULT_MODEL:-claude-sonnet-4}
API_KEYS: ${API_KEYS:-{}} API_KEYS: ${API_KEYS:-{}}
ports: OLLAMA_BASE_URL: ${OLLAMA_BASE_URL:-http://host.docker.internal:11434}
- "${BACKEND_PORT:-8000}:8000" EMBEDDING_MODEL: ${EMBEDDING_MODEL:-nomic-embed-text}
EMBEDDING_DIMS: ${EMBEDDING_DIMS:-2560}
expose:
- "8000"
networks:
- npm_network
- mem0_network
depends_on: depends_on:
qdrant: qdrant:
condition: service_healthy condition: service_healthy
@ -75,7 +85,8 @@ services:
restart: unless-stopped restart: unless-stopped
volumes: volumes:
- ./backend:/app - ./backend:/app
command: ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] - ./frontend:/app/frontend
command: ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
volumes: volumes:
qdrant_data: qdrant_data:
@ -85,5 +96,6 @@ volumes:
neo4j_plugins: neo4j_plugins:
networks: networks:
default: mem0_network:
name: mem0-network npm_network:
external: true

View file

@ -18,12 +18,106 @@
display: flex; display: flex;
} }
/* Login Screen */
.login-screen {
display: flex;
align-items: center;
justify-content: center;
width: 100%;
height: 100vh;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
.login-screen.hidden {
display: none;
}
.login-box {
background: white;
padding: 40px;
border-radius: 12px;
box-shadow: 0 10px 40px rgba(0, 0, 0, 0.2);
width: 100%;
max-width: 400px;
}
.login-box h1 {
margin-bottom: 10px;
color: #333;
font-size: 28px;
text-align: center;
}
.login-box p {
color: #666;
font-size: 14px;
text-align: center;
margin-bottom: 30px;
}
.login-box input {
width: 100%;
padding: 14px;
border: 2px solid #e0e0e0;
border-radius: 8px;
font-size: 14px;
margin-bottom: 20px;
outline: none;
transition: border-color 0.3s;
}
.login-box input:focus {
border-color: #667eea;
}
.login-box button {
width: 100%;
padding: 14px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 8px;
font-size: 16px;
font-weight: 600;
cursor: pointer;
transition: transform 0.2s, opacity 0.3s;
}
.login-box button:hover {
transform: translateY(-2px);
}
.login-box button:disabled {
opacity: 0.6;
cursor: not-allowed;
transform: none;
}
.login-error {
background: #ffe6e6;
border: 1px solid #ffcccc;
color: #cc0000;
padding: 12px;
border-radius: 8px;
margin-bottom: 20px;
font-size: 14px;
display: none;
}
.login-error.show {
display: block;
}
.container { .container {
display: flex; display: flex;
width: 100%; width: 100%;
height: 100vh; height: 100vh;
} }
.container.hidden {
display: none;
}
/* Chat Section */ /* Chat Section */
.chat-section { .chat-section {
flex: 1; flex: 1;
@ -58,7 +152,12 @@
font-size: 14px; font-size: 14px;
} }
.clear-chat-btn { .header-buttons {
display: flex;
gap: 10px;
}
.clear-chat-btn, .logout-btn {
background: #f8f9fa; background: #f8f9fa;
color: #666; color: #666;
border: 1px solid #e0e0e0; border: 1px solid #e0e0e0;
@ -73,16 +172,27 @@
transition: all 0.2s ease; transition: all 0.2s ease;
} }
.clear-chat-btn:hover { .clear-chat-btn:hover, .logout-btn:hover {
background: #e9ecef; background: #e9ecef;
border-color: #ced4da; border-color: #ced4da;
color: #495057; color: #495057;
} }
.clear-chat-btn:active { .clear-chat-btn:active, .logout-btn:active {
background: #dee2e6; background: #dee2e6;
} }
.logout-btn {
background: #fff3cd;
border-color: #ffc107;
color: #856404;
}
.logout-btn:hover {
background: #ffe69c;
border-color: #ffb300;
}
.chat-messages { .chat-messages {
flex: 1; flex: 1;
overflow-y: auto; overflow-y: auto;
@ -281,17 +391,42 @@
</style> </style>
</head> </head>
<body> <body>
<div class="container"> <!-- Login Screen -->
<div class="login-screen" id="loginScreen">
<div class="login-box">
<h1>🧠 Mem0 Chat</h1>
<p>Enter your API key to access your memory-powered assistant</p>
<div class="login-error" id="loginError"></div>
<input
type="password"
id="apiKeyInput"
placeholder="Enter your API key (e.g., sk-xxxxx)"
autocomplete="off"
/>
<button id="loginButton">Connect</button>
</div>
</div>
<!-- Main Chat Interface (hidden initially) -->
<div class="container hidden" id="mainContainer">
<!-- Chat Section --> <!-- Chat Section -->
<div class="chat-section"> <div class="chat-section">
<div class="chat-header"> <div class="chat-header">
<div class="chat-header-content"> <div class="chat-header-content">
<h1>What can I help you with?</h1> <h1>What can I help you with?</h1>
<p>Chat with your memories - User: pratik</p> <p>Chat with your memories - User: <span id="currentUser">...</span></p>
</div>
<div class="header-buttons">
<button class="clear-chat-btn" id="clearChatBtn" title="Clear chat history">
🗑️ Clear Chat
</button>
<button class="logout-btn" id="logoutBtn" title="Logout">
🚪 Logout
</button>
</div> </div>
<button class="clear-chat-btn" id="clearChatBtn" title="Clear chat history">
🗑️ Clear Chat
</button>
</div> </div>
<div class="chat-messages" id="chatMessages"> <div class="chat-messages" id="chatMessages">
@ -319,10 +454,20 @@
<script> <script>
// Configuration // Configuration
const API_BASE = 'http://localhost:8000'; const API_BASE = window.location.origin;
const USER_ID = 'pratik';
// State
let API_KEY = null;
let USER_ID = null;
// DOM Elements // DOM Elements
const loginScreen = document.getElementById('loginScreen');
const mainContainer = document.getElementById('mainContainer');
const apiKeyInput = document.getElementById('apiKeyInput');
const loginButton = document.getElementById('loginButton');
const loginError = document.getElementById('loginError');
const logoutBtn = document.getElementById('logoutBtn');
const currentUser = document.getElementById('currentUser');
const chatMessages = document.getElementById('chatMessages'); const chatMessages = document.getElementById('chatMessages');
const messageInput = document.getElementById('messageInput'); const messageInput = document.getElementById('messageInput');
const sendButton = document.getElementById('sendButton'); const sendButton = document.getElementById('sendButton');
@ -336,19 +481,143 @@
// Initialize // Initialize
document.addEventListener('DOMContentLoaded', function() { document.addEventListener('DOMContentLoaded', function() {
loadChatHistory(); // Check if already logged in
loadMemories(); const savedApiKey = localStorage.getItem('apiKey');
const savedUserId = localStorage.getItem('userId');
if (savedApiKey && savedUserId) {
// Auto-login with saved credentials
API_KEY = savedApiKey;
USER_ID = savedUserId;
showMainInterface();
}
// Event listeners // Event listeners
loginButton.addEventListener('click', handleLogin);
apiKeyInput.addEventListener('keydown', (e) => {
if (e.key === 'Enter') handleLogin();
});
logoutBtn.addEventListener('click', handleLogout);
sendButton.addEventListener('click', sendMessage); sendButton.addEventListener('click', sendMessage);
messageInput.addEventListener('keydown', handleKeyDown); messageInput.addEventListener('keydown', handleKeyDown);
messageInput.addEventListener('input', autoResizeTextarea); messageInput.addEventListener('input', autoResizeTextarea);
refreshButton.addEventListener('click', loadMemories); refreshButton.addEventListener('click', loadMemories);
clearChatBtn.addEventListener('click', clearChatWithConfirmation); clearChatBtn.addEventListener('click', clearChatWithConfirmation);
});
// Handle login
async function handleLogin() {
const apiKey = apiKeyInput.value.trim();
if (!apiKey) {
showLoginError('Please enter an API key');
return;
}
loginButton.disabled = true;
loginButton.textContent = 'Verifying...';
hideLoginError();
try {
// Verify API key by calling /health with auth
const response = await fetch(`${API_BASE}/health`, {
headers: {
'X-API-Key': apiKey
}
});
if (!response.ok) {
throw new Error('Invalid API key');
}
// Get user_id by trying to call a test endpoint
// We'll use /models since it doesn't require auth parameters
const userResponse = await fetch(`${API_BASE}/v1/chat/completions`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-API-Key': apiKey
},
body: JSON.stringify({
model: 'gpt-4',
messages: [{ role: 'user', content: 'test' }],
stream: false
})
});
if (!userResponse.ok) {
throw new Error('Failed to verify user');
}
// API key is valid, extract user_id from auth_service mapping
// We'll store both and use a simple username extraction
API_KEY = apiKey;
// Try to extract username from API key (e.g., sk-alice -> alice)
if (apiKey.startsWith('sk-')) {
const parts = apiKey.substring(3).split('-');
USER_ID = parts[0]; // Get first part after sk-
} else {
USER_ID = 'user';
}
// Save to localStorage
localStorage.setItem('apiKey', API_KEY);
localStorage.setItem('userId', USER_ID);
// Show main interface
showMainInterface();
} catch (error) {
console.error('Login error:', error);
showLoginError('Invalid API key. Please check and try again.');
loginButton.disabled = false;
loginButton.textContent = 'Connect';
}
}
// Show main interface
function showMainInterface() {
loginScreen.classList.add('hidden');
mainContainer.classList.remove('hidden');
currentUser.textContent = USER_ID;
// Load data
loadChatHistory();
loadMemories();
// Initialize textarea height // Initialize textarea height
autoResizeTextarea(); autoResizeTextarea();
}); messageInput.focus();
}
// Handle logout
function handleLogout() {
if (confirm('Are you sure you want to logout?')) {
// Clear credentials
localStorage.removeItem('apiKey');
localStorage.removeItem('userId');
API_KEY = null;
USER_ID = null;
// Show login screen
mainContainer.classList.add('hidden');
loginScreen.classList.remove('hidden');
apiKeyInput.value = '';
hideLoginError();
}
}
// Show login error
function showLoginError(message) {
loginError.textContent = message;
loginError.classList.add('show');
}
// Hide login error
function hideLoginError() {
loginError.classList.remove('show');
}
// Load chat history from localStorage // Load chat history from localStorage
function loadChatHistory() { function loadChatHistory() {
@ -419,6 +688,7 @@
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'X-API-Key': API_KEY
}, },
body: JSON.stringify({ body: JSON.stringify({
message: message, message: message,
@ -459,8 +729,12 @@
// Load memories from backend // Load memories from backend
async function loadMemories() { async function loadMemories() {
try { try{
const response = await fetch(`${API_BASE}/memories/${USER_ID}?limit=50`); const response = await fetch(`${API_BASE}/memories/${USER_ID}?limit=50`, {
headers: {
'X-API-Key': API_KEY
}
});
if (!response.ok) { if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`); throw new Error(`HTTP error! status: ${response.status}`);
@ -512,8 +786,11 @@
} }
try { try {
const response = await fetch(`${API_BASE}/memories/${memoryId}`, { const response = await fetch(`${API_BASE}/memories/${memoryId}?user_id=${USER_ID}`, {
method: 'DELETE' method: 'DELETE',
headers: {
'X-API-Key': API_KEY
}
}); });
if (!response.ok) { if (!response.ok) {

122
setup.sh Executable file
View file

@ -0,0 +1,122 @@
#!/bin/bash
set -e
COMPOSE_PROJECT_NAME="${COMPOSE_PROJECT_NAME:-mem0}"
VOLUMES=("qdrant_data" "neo4j_data" "neo4j_logs" "neo4j_import" "neo4j_plugins")
echo "=========================================="
echo " Mem0 Setup Script"
echo "=========================================="
echo ""
check_volumes_exist() {
local exists=false
for vol in "${VOLUMES[@]}"; do
full_name="${COMPOSE_PROJECT_NAME}_${vol}"
if docker volume ls -q | grep -q "^${full_name}$"; then
exists=true
break
fi
done
echo "$exists"
}
list_existing_volumes() {
echo "Existing volumes:"
for vol in "${VOLUMES[@]}"; do
full_name="${COMPOSE_PROJECT_NAME}_${vol}"
if docker volume ls -q | grep -q "^${full_name}$"; then
size=$(docker system df -v 2>/dev/null | grep "$full_name" | awk '{print $4}' || echo "unknown")
echo " - $full_name ($size)"
fi
done
}
remove_volumes() {
echo "Stopping containers..."
docker compose down 2>/dev/null || true
echo "Removing volumes..."
for vol in "${VOLUMES[@]}"; do
full_name="${COMPOSE_PROJECT_NAME}_${vol}"
if docker volume ls -q | grep -q "^${full_name}$"; then
echo " Removing $full_name..."
docker volume rm "$full_name" 2>/dev/null || true
fi
done
echo "Volumes removed."
}
build_and_start() {
echo ""
echo "Building containers..."
docker compose build
echo ""
echo "Starting services..."
docker compose up -d
echo ""
echo "Waiting for services to be healthy..."
sleep 5
echo ""
echo "Checking health..."
curl -s http://localhost:8000/health 2>/dev/null | jq . || echo "Health check not available yet. Services may still be starting."
echo ""
echo "=========================================="
echo " Setup Complete!"
echo "=========================================="
echo ""
echo "Services:"
echo " - Backend API: http://localhost:8000"
echo " - API Docs: http://localhost:8000/docs"
echo " - Health Check: http://localhost:8000/health"
echo ""
echo "Logs: docker compose logs -f backend"
}
if [ "$(check_volumes_exist)" = "true" ]; then
echo "Existing data volumes detected!"
echo ""
list_existing_volumes
echo ""
echo "Options:"
echo " 1) Keep existing data and start services"
echo " 2) Reset everything (DELETE ALL DATA) and start fresh"
echo " 3) Exit"
echo ""
read -p "Choose an option [1/2/3]: " choice
case $choice in
1)
echo ""
echo "Keeping existing data..."
build_and_start
;;
2)
echo ""
read -p "Are you sure you want to DELETE ALL DATA? Type 'yes' to confirm: " confirm
if [ "$confirm" = "yes" ]; then
remove_volumes
build_and_start
else
echo "Aborted."
exit 1
fi
;;
3)
echo "Exiting."
exit 0
;;
*)
echo "Invalid option. Exiting."
exit 1
;;
esac
else
echo "No existing volumes found. Starting fresh setup..."
build_and_start
fi

View file

@ -19,27 +19,41 @@ import time
BASE_URL = "http://localhost:8000" BASE_URL = "http://localhost:8000"
TEST_USER = f"test_user_{int(datetime.now().timestamp())}" TEST_USER = f"test_user_{int(datetime.now().timestamp())}"
# API Key for authentication - set via environment or use default test key
import os
API_KEY = os.environ.get("MEM0_API_KEY", "test-api-key")
AUTH_HEADERS = {"X-API-Key": API_KEY}
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Mem0 Integration Tests - Real API Testing (Zero Mocking)", description="Mem0 Integration Tests - Real API Testing (Zero Mocking)",
formatter_class=argparse.RawDescriptionHelpFormatter formatter_class=argparse.RawDescriptionHelpFormatter,
) )
parser.add_argument("--verbose", "-v", action="store_true", parser.add_argument(
help="Show detailed output and API responses") "--verbose",
"-v",
action="store_true",
help="Show detailed output and API responses",
)
args = parser.parse_args() args = parser.parse_args()
verbose = args.verbose verbose = args.verbose
print("🧪 Mem0 Integration Tests - Real API Testing") print("🧪 Mem0 Integration Tests - Real API Testing")
print(f"🎯 Target: {BASE_URL}") print(f"🎯 Target: {BASE_URL}")
print(f"👤 Test User: {TEST_USER}") print(f"👤 Test User: {TEST_USER}")
print(f"⏰ Started: {datetime.now().strftime('%H:%M:%S')}") print(f"⏰ Started: {datetime.now().strftime('%H:%M:%S')}")
print("=" * 50) print("=" * 50)
# Test sequence - order matters for data dependencies # Test sequence - order matters for data dependencies
tests = [ tests = [
test_health_check, test_health_check,
test_empty_search_protection, test_auth_required_endpoints,
test_ownership_verification,
test_request_size_limit,
test_empty_search_protection,
test_add_memories_with_hierarchy, test_add_memories_with_hierarchy,
test_search_memories_basic, test_search_memories_basic,
test_search_memories_hierarchy_filters, test_search_memories_hierarchy_filters,
@ -51,30 +65,30 @@ def main():
test_graph_relationships, test_graph_relationships,
test_delete_specific_memory, test_delete_specific_memory,
test_delete_all_user_memories, test_delete_all_user_memories,
test_cleanup_verification test_cleanup_verification,
] ]
results = [] results = []
start_time = time.time() start_time = time.time()
for test in tests: for test in tests:
result = run_test(test.__name__, test, verbose) result = run_test(test.__name__, test, verbose)
results.append(result) results.append(result)
# Small delay between tests for API stability # Small delay between tests for API stability
time.sleep(0.5) time.sleep(0.5)
# Summary # Summary
end_time = time.time() end_time = time.time()
duration = end_time - start_time duration = end_time - start_time
passed = sum(1 for r in results if r) passed = sum(1 for r in results if r)
total = len(results) total = len(results)
print("=" * 50) print("=" * 50)
print(f"📊 Test Results: {passed}/{total} tests passed") print(f"📊 Test Results: {passed}/{total} tests passed")
print(f"⏱️ Duration: {duration:.2f} seconds") print(f"⏱️ Duration: {duration:.2f} seconds")
if passed == total: if passed == total:
print("✅ All tests passed! System is working correctly.") print("✅ All tests passed! System is working correctly.")
sys.exit(0) sys.exit(0)
@ -82,16 +96,17 @@ def main():
print("❌ Some tests failed! Check the output above.") print("❌ Some tests failed! Check the output above.")
sys.exit(1) sys.exit(1)
def run_test(name, test_func, verbose): def run_test(name, test_func, verbose):
"""Run a single test with error handling""" """Run a single test with error handling"""
try: try:
if verbose: if verbose:
print(f"\n🔍 Running {name}...") print(f"\n🔍 Running {name}...")
test_func(verbose) test_func(verbose)
print(f"{name}") print(f"{name}")
return True return True
except AssertionError as e: except AssertionError as e:
print(f"{name}: Assertion failed - {e}") print(f"{name}: Assertion failed - {e}")
return False return False
@ -102,6 +117,7 @@ def run_test(name, test_func, verbose):
print(f"{name}: {e}") print(f"{name}: {e}")
return False return False
def log_response(response, verbose, context=""): def log_response(response, verbose, context=""):
"""Log API response details if verbose""" """Log API response details if verbose"""
if verbose: if verbose:
@ -111,78 +127,101 @@ def log_response(response, verbose, context=""):
if isinstance(data, dict) and len(data) < 5: if isinstance(data, dict) and len(data) < 5:
print(f" {context} Response: {data}") print(f" {context} Response: {data}")
else: else:
print(f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}") print(
f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}"
)
except: except:
print(f" {context} Response: {response.text[:100]}...") print(f" {context} Response: {response.text[:100]}...")
# ================== TEST FUNCTIONS ================== # ================== TEST FUNCTIONS ==================
def test_health_check(verbose): def test_health_check(verbose):
"""Test service health endpoint""" """Test service health endpoint"""
response = requests.get(f"{BASE_URL}/health", timeout=10) response = requests.get(
f"{BASE_URL}/health", timeout=10
) # Health doesn't require auth
log_response(response, verbose, "Health") log_response(response, verbose, "Health")
assert response.status_code == 200, f"Expected 200, got {response.status_code}" assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json() data = response.json()
assert "status" in data, "Health response missing 'status' field" assert "status" in data, "Health response missing 'status' field"
assert data["status"] in ["healthy", "degraded"], f"Invalid status: {data['status']}" assert data["status"] in ["healthy", "degraded"], (
f"Invalid status: {data['status']}"
)
# Check individual services # Check individual services
assert "services" in data, "Health response missing 'services' field" assert "services" in data, "Health response missing 'services' field"
if verbose: if verbose:
print(f" Overall status: {data['status']}") print(f" Overall status: {data['status']}")
for service, status in data["services"].items(): for service, status in data["services"].items():
print(f" {service}: {status}") print(f" {service}: {status}")
def test_empty_search_protection(verbose): def test_empty_search_protection(verbose):
"""Test empty query protection (should not return 500 error)""" """Test empty query protection (should not return 500 error)"""
payload = { payload = {"query": "", "user_id": TEST_USER, "limit": 5}
"query": "",
"user_id": TEST_USER, response = requests.post(
"limit": 5 f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=10
} )
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=10)
log_response(response, verbose, "Empty Search") log_response(response, verbose, "Empty Search")
assert response.status_code == 200, f"Empty query failed with {response.status_code}" assert response.status_code == 200, (
f"Empty query failed with {response.status_code}"
)
data = response.json() data = response.json()
assert data["memories"] == [], "Empty query should return empty memories list" assert data["memories"] == [], "Empty query should return empty memories list"
assert "note" in data, "Empty query response should include explanatory note" assert "note" in data, "Empty query response should include explanatory note"
assert data["query"] == "", "Query should be echoed back" assert data["query"] == "", "Query should be echoed back"
if verbose: if verbose:
print(f" Empty search note: {data['note']}") print(f" Empty search note: {data['note']}")
print(f" Total count: {data.get('total_count', 0)}") print(f" Total count: {data.get('total_count', 0)}")
def test_add_memories_with_hierarchy(verbose): def test_add_memories_with_hierarchy(verbose):
"""Test adding memories with multi-level hierarchy support""" """Test adding memories with multi-level hierarchy support"""
payload = { payload = {
"messages": [ "messages": [
{"role": "user", "content": "I work at TechCorp as a Senior Software Engineer"}, {
{"role": "user", "content": "My colleague Sarah from Marketing team helped with Q3 presentation"}, "role": "user",
{"role": "user", "content": "Meeting with John the Product Manager tomorrow about new feature development"} "content": "I work at TechCorp as a Senior Software Engineer",
},
{
"role": "user",
"content": "My colleague Sarah from Marketing team helped with Q3 presentation",
},
{
"role": "user",
"content": "Meeting with John the Product Manager tomorrow about new feature development",
},
], ],
"user_id": TEST_USER, "user_id": TEST_USER,
"agent_id": "test_agent", "agent_id": "test_agent",
"run_id": "test_run_001", "run_id": "test_run_001",
"session_id": "test_session_001", "session_id": "test_session_001",
"metadata": {"test": "integration", "scenario": "work_context"} "metadata": {"test": "integration", "scenario": "work_context"},
} }
response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60) response = requests.post(
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60
)
log_response(response, verbose, "Add Memories") log_response(response, verbose, "Add Memories")
assert response.status_code == 200, f"Add memories failed with {response.status_code}" assert response.status_code == 200, (
f"Add memories failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "added_memories" in data, "Response missing 'added_memories'" assert "added_memories" in data, "Response missing 'added_memories'"
assert "message" in data, "Response missing success message" assert "message" in data, "Response missing success message"
assert len(data["added_memories"]) > 0, "No memories were added" assert len(data["added_memories"]) > 0, "No memories were added"
# Verify graph extraction (if available) # Verify graph extraction (if available)
memories = data["added_memories"] memories = data["added_memories"]
if isinstance(memories, list) and len(memories) > 0: if isinstance(memories, list) and len(memories) > 0:
@ -191,47 +230,51 @@ def test_add_memories_with_hierarchy(verbose):
relations = first_memory["relations"] relations = first_memory["relations"]
if "added_entities" in relations and relations["added_entities"]: if "added_entities" in relations and relations["added_entities"]:
if verbose: if verbose:
print(f" Graph extracted: {len(relations['added_entities'])} relationships") print(
f" Graph extracted: {len(relations['added_entities'])} relationships"
)
print(f" Sample relations: {relations['added_entities'][:3]}") print(f" Sample relations: {relations['added_entities'][:3]}")
if verbose: if verbose:
print(f" Added {len(memories)} memory blocks") print(f" Added {len(memories)} memory blocks")
print(f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001") print(
f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001"
)
def test_search_memories_basic(verbose): def test_search_memories_basic(verbose):
"""Test basic memory search functionality""" """Test basic memory search functionality"""
# Test meaningful search # Test meaningful search
payload = { payload = {"query": "TechCorp", "user_id": TEST_USER, "limit": 10}
"query": "TechCorp",
"user_id": TEST_USER, response = requests.post(
"limit": 10 f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15
} )
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15)
log_response(response, verbose, "Search") log_response(response, verbose, "Search")
assert response.status_code == 200, f"Search failed with {response.status_code}" assert response.status_code == 200, f"Search failed with {response.status_code}"
data = response.json() data = response.json()
assert "memories" in data, "Search response missing 'memories'" assert "memories" in data, "Search response missing 'memories'"
assert "total_count" in data, "Search response missing 'total_count'" assert "total_count" in data, "Search response missing 'total_count'"
assert "query" in data, "Search response missing 'query'" assert "query" in data, "Search response missing 'query'"
assert data["query"] == "TechCorp", "Query not echoed correctly" assert data["query"] == "TechCorp", "Query not echoed correctly"
# Should find memories since we just added some # Should find memories since we just added some
assert data["total_count"] > 0, "Search should find previously added memories" assert data["total_count"] > 0, "Search should find previously added memories"
assert len(data["memories"]) > 0, "Search should return memory results" assert len(data["memories"]) > 0, "Search should return memory results"
# Verify memory structure # Verify memory structure
memory = data["memories"][0] memory = data["memories"][0]
assert "id" in memory, "Memory missing 'id'" assert "id" in memory, "Memory missing 'id'"
assert "memory" in memory, "Memory missing 'memory' content" assert "memory" in memory, "Memory missing 'memory' content"
assert "user_id" in memory, "Memory missing 'user_id'" assert "user_id" in memory, "Memory missing 'user_id'"
if verbose: if verbose:
print(f" Found {data['total_count']} memories") print(f" Found {data['total_count']} memories")
print(f" First memory: {memory['memory'][:50]}...") print(f" First memory: {memory['memory'][:50]}...")
def test_search_memories_hierarchy_filters(verbose): def test_search_memories_hierarchy_filters(verbose):
"""Test multi-level hierarchy filtering in search""" """Test multi-level hierarchy filtering in search"""
# Test with hierarchy filters # Test with hierarchy filters
@ -241,23 +284,30 @@ def test_search_memories_hierarchy_filters(verbose):
"agent_id": "test_agent", "agent_id": "test_agent",
"run_id": "test_run_001", "run_id": "test_run_001",
"session_id": "test_session_001", "session_id": "test_session_001",
"limit": 10 "limit": 10,
} }
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15) response = requests.post(
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Hierarchy Search") log_response(response, verbose, "Hierarchy Search")
assert response.status_code == 200, f"Hierarchy search failed with {response.status_code}" assert response.status_code == 200, (
f"Hierarchy search failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "memories" in data, "Hierarchy search response missing 'memories'" assert "memories" in data, "Hierarchy search response missing 'memories'"
# Should find memories since we added with these exact hierarchy values # Should find memories since we added with these exact hierarchy values
assert len(data["memories"]) > 0, "Should find memories with matching hierarchy" assert len(data["memories"]) > 0, "Should find memories with matching hierarchy"
if verbose: if verbose:
print(f" Found {len(data['memories'])} memories with hierarchy filters") print(f" Found {len(data['memories'])} memories with hierarchy filters")
print(f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001") print(
f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001"
)
def test_get_user_memories_with_hierarchy(verbose): def test_get_user_memories_with_hierarchy(verbose):
"""Test retrieving user memories with hierarchy filtering""" """Test retrieving user memories with hierarchy filtering"""
@ -266,23 +316,30 @@ def test_get_user_memories_with_hierarchy(verbose):
"limit": 20, "limit": 20,
"agent_id": "test_agent", "agent_id": "test_agent",
"run_id": "test_run_001", "run_id": "test_run_001",
"session_id": "test_session_001" "session_id": "test_session_001",
} }
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}", params=params, timeout=15) response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}",
params=params,
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Get User Memories with Hierarchy") log_response(response, verbose, "Get User Memories with Hierarchy")
assert response.status_code == 200, f"Get user memories with hierarchy failed with {response.status_code}" assert response.status_code == 200, (
f"Get user memories with hierarchy failed with {response.status_code}"
)
memories = response.json() memories = response.json()
assert isinstance(memories, list), "User memories should return a list" assert isinstance(memories, list), "User memories should return a list"
if len(memories) > 0: if len(memories) > 0:
memory = memories[0] memory = memories[0]
assert "id" in memory, "Memory missing 'id'" assert "id" in memory, "Memory missing 'id'"
assert "memory" in memory, "Memory missing 'memory' content" assert "memory" in memory, "Memory missing 'memory' content"
assert memory["user_id"] == TEST_USER, f"Wrong user_id: {memory['user_id']}" assert memory["user_id"] == TEST_USER, f"Wrong user_id: {memory['user_id']}"
if verbose: if verbose:
print(f" Retrieved {len(memories)} memories with hierarchy filters") print(f" Retrieved {len(memories)} memories with hierarchy filters")
print(f" First memory: {memory['memory'][:40]}...") print(f" First memory: {memory['memory'][:40]}...")
@ -290,248 +347,320 @@ def test_get_user_memories_with_hierarchy(verbose):
if verbose: if verbose:
print(" No memories found with hierarchy filters (may be expected)") print(" No memories found with hierarchy filters (may be expected)")
def test_memory_history(verbose): def test_memory_history(verbose):
"""Test memory history endpoint""" """Test memory history endpoint"""
# First get a memory to check history for # First get a memory to check history for
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10) response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for history test" assert response.status_code == 200, "Failed to get memory for history test"
memories = response.json() memories = response.json()
if len(memories) == 0: if len(memories) == 0:
if verbose: if verbose:
print(" No memories available for history test (skipping)") print(" No memories available for history test (skipping)")
return return
memory_id = memories[0]["id"] memory_id = memories[0]["id"]
# Test memory history endpoint # Test memory history endpoint
response = requests.get(f"{BASE_URL}/memories/{memory_id}/history", timeout=15) response = requests.get(
f"{BASE_URL}/memories/{memory_id}/history?user_id={TEST_USER}",
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Memory History") log_response(response, verbose, "Memory History")
assert response.status_code == 200, f"Memory history failed with {response.status_code}" assert response.status_code == 200, (
f"Memory history failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "memory_id" in data, "History response missing 'memory_id'" assert "memory_id" in data, "History response missing 'memory_id'"
assert "history" in data, "History response missing 'history'" assert "history" in data, "History response missing 'history'"
assert "message" in data, "History response missing success message" assert "message" in data, "History response missing success message"
assert data["memory_id"] == memory_id, f"Wrong memory_id in response: {data['memory_id']}" assert data["memory_id"] == memory_id, (
f"Wrong memory_id in response: {data['memory_id']}"
)
if verbose: if verbose:
print(f" Retrieved history for memory {memory_id}") print(f" Retrieved history for memory {memory_id}")
print(f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}") print(
f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}"
)
def test_update_memory(verbose): def test_update_memory(verbose):
"""Test updating a specific memory""" """Test updating a specific memory"""
# First get a memory to update # First get a memory to update
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10) response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for update test" assert response.status_code == 200, "Failed to get memory for update test"
memories = response.json() memories = response.json()
assert len(memories) > 0, "No memories available to update" assert len(memories) > 0, "No memories available to update"
memory_id = memories[0]["id"] memory_id = memories[0]["id"]
original_content = memories[0]["memory"] original_content = memories[0]["memory"]
# Update the memory # Update the memory
payload = { payload = {
"memory_id": memory_id, "memory_id": memory_id,
"content": f"UPDATED: {original_content}" "user_id": TEST_USER,
"content": f"UPDATED: {original_content}",
} }
response = requests.put(f"{BASE_URL}/memories", json=payload, timeout=10) response = requests.put(
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Update") log_response(response, verbose, "Update")
assert response.status_code == 200, f"Update failed with {response.status_code}" assert response.status_code == 200, f"Update failed with {response.status_code}"
data = response.json() data = response.json()
assert "message" in data, "Update response missing success message" assert "message" in data, "Update response missing success message"
if verbose: if verbose:
print(f" Updated memory {memory_id}") print(f" Updated memory {memory_id}")
print(f" Original: {original_content[:30]}...") print(f" Original: {original_content[:30]}...")
def test_chat_with_memory(verbose): def test_chat_with_memory(verbose):
"""Test memory-enhanced chat functionality""" """Test memory-enhanced chat functionality"""
payload = { payload = {"message": "What company do I work for?", "user_id": TEST_USER}
"message": "What company do I work for?",
"user_id": TEST_USER
}
try: try:
response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=90) response = requests.post(
f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=90
)
log_response(response, verbose, "Chat") log_response(response, verbose, "Chat")
assert response.status_code == 200, f"Chat failed with {response.status_code}" assert response.status_code == 200, f"Chat failed with {response.status_code}"
data = response.json() data = response.json()
assert "response" in data, "Chat response missing 'response'" assert "response" in data, "Chat response missing 'response'"
assert "memories_used" in data, "Chat response missing 'memories_used'" assert "memories_used" in data, "Chat response missing 'memories_used'"
assert "model_used" in data, "Chat response missing 'model_used'" assert "model_used" in data, "Chat response missing 'model_used'"
# Should use some memories for context # Should use some memories for context
assert data["memories_used"] >= 0, "Memories used should be non-negative" assert data["memories_used"] >= 0, "Memories used should be non-negative"
if verbose: if verbose:
print(f" Chat response: {data['response'][:60]}...") print(f" Chat response: {data['response'][:60]}...")
print(f" Memories used: {data['memories_used']}") print(f" Memories used: {data['memories_used']}")
print(f" Model: {data['model_used']}") print(f" Model: {data['model_used']}")
except requests.exceptions.ReadTimeout: except requests.exceptions.ReadTimeout:
if verbose: if verbose:
print(" Chat endpoint timed out (LLM API may be slow)") print(" Chat endpoint timed out (LLM API may be slow)")
# Still test that the endpoint exists and accepts requests # Still test that the endpoint exists and accepts requests
try: try:
response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=5) response = requests.post(
f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=5
)
except requests.exceptions.ReadTimeout: except requests.exceptions.ReadTimeout:
# This is expected - endpoint exists but processing is slow # This is expected - endpoint exists but processing is slow
if verbose: if verbose:
print(" Chat endpoint confirmed active (processing timeout expected)") print(" Chat endpoint confirmed active (processing timeout expected)")
def test_graph_relationships_creation(verbose): def test_graph_relationships_creation(verbose):
"""Test graph relationships creation with entity-rich memories""" """Test graph relationships creation with entity-rich memories"""
# Create a separate test user for graph relationship testing # Create a separate test user for graph relationship testing
graph_test_user = f"graph_test_user_{int(datetime.now().timestamp())}" graph_test_user = f"graph_test_user_{int(datetime.now().timestamp())}"
# Add memories with clear entity relationships # Add memories with clear entity relationships
payload = { payload = {
"messages": [ "messages": [
{"role": "user", "content": "John Smith works at Microsoft as a Senior Software Engineer"}, {
{"role": "user", "content": "John Smith is friends with Sarah Johnson who works at Google"}, "role": "user",
{"role": "user", "content": "Sarah Johnson lives in Seattle and loves hiking"}, "content": "John Smith works at Microsoft as a Senior Software Engineer",
},
{
"role": "user",
"content": "John Smith is friends with Sarah Johnson who works at Google",
},
{
"role": "user",
"content": "Sarah Johnson lives in Seattle and loves hiking",
},
{"role": "user", "content": "Microsoft is located in Redmond, Washington"}, {"role": "user", "content": "Microsoft is located in Redmond, Washington"},
{"role": "user", "content": "John Smith and Sarah Johnson both graduated from Stanford University"} {
"role": "user",
"content": "John Smith and Sarah Johnson both graduated from Stanford University",
},
], ],
"user_id": graph_test_user, "user_id": graph_test_user,
"metadata": {"test": "graph_relationships", "scenario": "entity_creation"} "metadata": {"test": "graph_relationships", "scenario": "entity_creation"},
} }
response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60) response = requests.post(
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60
)
log_response(response, verbose, "Add Graph Memories") log_response(response, verbose, "Add Graph Memories")
assert response.status_code == 200, f"Add graph memories failed with {response.status_code}" assert response.status_code == 200, (
f"Add graph memories failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "added_memories" in data, "Response missing 'added_memories'" assert "added_memories" in data, "Response missing 'added_memories'"
if verbose: if verbose:
print(f" Added {len(data['added_memories'])} memories for graph relationship testing") print(
f" Added {len(data['added_memories'])} memories for graph relationship testing"
)
# Wait a moment for graph processing (Mem0 graph extraction can be async) # Wait a moment for graph processing (Mem0 graph extraction can be async)
time.sleep(2) time.sleep(2)
# Test graph relationships endpoint # Test graph relationships endpoint
response = requests.get(f"{BASE_URL}/graph/relationships/{graph_test_user}", timeout=15) response = requests.get(
f"{BASE_URL}/graph/relationships/{graph_test_user}",
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Graph Relationships") log_response(response, verbose, "Graph Relationships")
assert response.status_code == 200, f"Graph relationships failed with {response.status_code}" assert response.status_code == 200, (
f"Graph relationships failed with {response.status_code}"
)
graph_data = response.json() graph_data = response.json()
assert "relationships" in graph_data, "Graph response missing 'relationships'" assert "relationships" in graph_data, "Graph response missing 'relationships'"
assert "entities" in graph_data, "Graph response missing 'entities'" assert "entities" in graph_data, "Graph response missing 'entities'"
assert "user_id" in graph_data, "Graph response missing 'user_id'" assert "user_id" in graph_data, "Graph response missing 'user_id'"
assert graph_data["user_id"] == graph_test_user, f"Wrong user_id in graph: {graph_data['user_id']}" assert graph_data["user_id"] == graph_test_user, (
f"Wrong user_id in graph: {graph_data['user_id']}"
)
relationships = graph_data["relationships"] relationships = graph_data["relationships"]
entities = graph_data["entities"] entities = graph_data["entities"]
if verbose: if verbose:
print(f" Found {len(relationships)} relationships") print(f" Found {len(relationships)} relationships")
print(f" Found {len(entities)} entities") print(f" Found {len(entities)} entities")
# Print sample relationships if they exist # Print sample relationships if they exist
if relationships: if relationships:
print(f" Sample relationships:") print(f" Sample relationships:")
for i, rel in enumerate(relationships[:3]): # Show first 3 for i, rel in enumerate(relationships[:3]): # Show first 3
source = rel.get("source", "unknown") source = rel.get("source", "unknown")
target = rel.get("target", "unknown") target = rel.get("target", "unknown")
relationship = rel.get("relationship", "unknown") relationship = rel.get("relationship", "unknown")
print(f" {i+1}. {source} --{relationship}--> {target}") print(f" {i + 1}. {source} --{relationship}--> {target}")
# Print sample entities if they exist # Print sample entities if they exist
if entities: if entities:
print(f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}") print(
f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}"
)
# Verify relationship structure (if relationships exist) # Verify relationship structure (if relationships exist)
for rel in relationships: for rel in relationships:
assert "source" in rel or "from" in rel, f"Relationship missing source/from: {rel}" assert "source" in rel or "from" in rel, (
f"Relationship missing source/from: {rel}"
)
assert "target" in rel or "to" in rel, f"Relationship missing target/to: {rel}" assert "target" in rel or "to" in rel, f"Relationship missing target/to: {rel}"
assert "relationship" in rel or "type" in rel, f"Relationship missing type: {rel}" assert "relationship" in rel or "type" in rel, (
f"Relationship missing type: {rel}"
)
# Clean up graph test user memories # Clean up graph test user memories
cleanup_response = requests.delete(f"{BASE_URL}/memories/user/{graph_test_user}", timeout=15) cleanup_response = requests.delete(
f"{BASE_URL}/memories/user/{graph_test_user}", headers=AUTH_HEADERS, timeout=15
)
assert cleanup_response.status_code == 200, "Failed to cleanup graph test memories" assert cleanup_response.status_code == 200, "Failed to cleanup graph test memories"
if verbose: if verbose:
print(f" Cleaned up graph test user: {graph_test_user}") print(f" Cleaned up graph test user: {graph_test_user}")
# Note: We expect some relationships even if graph extraction is basic # Note: We expect some relationships even if graph extraction is basic
# The test passes if the endpoint works and returns proper structure # The test passes if the endpoint works and returns proper structure
def test_graph_relationships(verbose): def test_graph_relationships(verbose):
"""Test graph relationships endpoint""" """Test graph relationships endpoint"""
response = requests.get(f"{BASE_URL}/graph/relationships/{TEST_USER}", timeout=15) response = requests.get(
f"{BASE_URL}/graph/relationships/{TEST_USER}", headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Graph") log_response(response, verbose, "Graph")
assert response.status_code == 200, f"Graph endpoint failed with {response.status_code}" assert response.status_code == 200, (
f"Graph endpoint failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "relationships" in data, "Graph response missing 'relationships'" assert "relationships" in data, "Graph response missing 'relationships'"
assert "entities" in data, "Graph response missing 'entities'" assert "entities" in data, "Graph response missing 'entities'"
assert "user_id" in data, "Graph response missing 'user_id'" assert "user_id" in data, "Graph response missing 'user_id'"
assert data["user_id"] == TEST_USER, f"Wrong user_id in graph: {data['user_id']}" assert data["user_id"] == TEST_USER, f"Wrong user_id in graph: {data['user_id']}"
if verbose: if verbose:
print(f" Relationships: {len(data['relationships'])}") print(f" Relationships: {len(data['relationships'])}")
print(f" Entities: {len(data['entities'])}") print(f" Entities: {len(data['entities'])}")
def test_delete_specific_memory(verbose): def test_delete_specific_memory(verbose):
"""Test deleting a specific memory""" """Test deleting a specific memory"""
# Get a memory to delete # Get a memory to delete
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10) response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for deletion test" assert response.status_code == 200, "Failed to get memory for deletion test"
memories = response.json() memories = response.json()
assert len(memories) > 0, "No memories available to delete" assert len(memories) > 0, "No memories available to delete"
memory_id = memories[0]["id"] memory_id = memories[0]["id"]
# Delete the memory # Delete the memory
response = requests.delete(f"{BASE_URL}/memories/{memory_id}", timeout=10) response = requests.delete(
f"{BASE_URL}/memories/{memory_id}", headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Delete") log_response(response, verbose, "Delete")
assert response.status_code == 200, f"Delete failed with {response.status_code}" assert response.status_code == 200, f"Delete failed with {response.status_code}"
data = response.json() data = response.json()
assert "message" in data, "Delete response missing success message" assert "message" in data, "Delete response missing success message"
if verbose: if verbose:
print(f" Deleted memory {memory_id}") print(f" Deleted memory {memory_id}")
def test_delete_all_user_memories(verbose): def test_delete_all_user_memories(verbose):
"""Test deleting all memories for a user""" """Test deleting all memories for a user"""
response = requests.delete(f"{BASE_URL}/memories/user/{TEST_USER}", timeout=15) response = requests.delete(
f"{BASE_URL}/memories/user/{TEST_USER}", headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Delete All") log_response(response, verbose, "Delete All")
assert response.status_code == 200, f"Delete all failed with {response.status_code}" assert response.status_code == 200, f"Delete all failed with {response.status_code}"
data = response.json() data = response.json()
assert "message" in data, "Delete all response missing success message" assert "message" in data, "Delete all response missing success message"
if verbose: if verbose:
print(f"Deleted all memories for {TEST_USER}") print(f"Deleted all memories for {TEST_USER}")
def test_cleanup_verification(verbose): def test_cleanup_verification(verbose):
"""Verify cleanup was successful""" """Verify cleanup was successful"""
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=10", timeout=10) response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=10", headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Cleanup Check") log_response(response, verbose, "Cleanup Check")
assert response.status_code == 200, f"Cleanup verification failed with {response.status_code}" assert response.status_code == 200, (
f"Cleanup verification failed with {response.status_code}"
)
memories = response.json() memories = response.json()
assert isinstance(memories, list), "Should return list even if empty" assert isinstance(memories, list), "Should return list even if empty"
# Should be empty after deletion # Should be empty after deletion
if len(memories) > 0: if len(memories) > 0:
print(f" Warning: {len(memories)} memories still exist after cleanup") print(f" Warning: {len(memories)} memories still exist after cleanup")
@ -539,5 +668,79 @@ def test_cleanup_verification(verbose):
if verbose: if verbose:
print(" Cleanup successful - no memories remain") print(" Cleanup successful - no memories remain")
# ================== SECURITY TEST FUNCTIONS ==================
def test_auth_required_endpoints(verbose):
"""Test that protected endpoints require authentication"""
endpoints_requiring_auth = [
("GET", f"{BASE_URL}/memories/{TEST_USER}"),
("POST", f"{BASE_URL}/memories/search"),
("GET", f"{BASE_URL}/stats"),
("GET", f"{BASE_URL}/models"),
("GET", f"{BASE_URL}/users"),
]
for method, url in endpoints_requiring_auth:
if method == "GET":
response = requests.get(url, timeout=5)
else:
response = requests.post(
url, json={"query": "test", "user_id": TEST_USER}, timeout=5
)
assert response.status_code in [401, 403], (
f"{method} {url} should require auth, got {response.status_code}"
)
if verbose:
print(f" {method} {url}: {response.status_code} (auth required)")
def test_ownership_verification(verbose):
"""Test that users can only access their own data"""
other_user = "other_user_not_me"
response = requests.get(
f"{BASE_URL}/memories/{other_user}", headers=AUTH_HEADERS, timeout=5
)
assert response.status_code in [403, 404], (
f"Accessing other user's memories should be denied, got {response.status_code}"
)
if verbose:
print(f" Ownership check passed: {response.status_code}")
def test_request_size_limit(verbose):
"""Test request size limit enforcement (10MB max)"""
large_payload = {
"messages": [{"role": "user", "content": "x" * (11 * 1024 * 1024)}],
"user_id": TEST_USER,
}
try:
response = requests.post(
f"{BASE_URL}/memories",
json=large_payload,
headers={**AUTH_HEADERS, "Content-Length": str(11 * 1024 * 1024)},
timeout=5,
)
assert response.status_code == 413, (
f"Large request should return 413, got {response.status_code}"
)
if verbose:
print(f" Request size limit enforced: {response.status_code}")
except requests.exceptions.RequestException as e:
if verbose:
print(
f" Request size limit test: connection issue (expected for large payload)"
)
if __name__ == "__main__": if __name__ == "__main__":
main() main()