Compare commits


10 commits

SHA1 Message Date
82accabc73 fix: properly log settings as structlog kwargs instead of dropping them 2026-01-16 00:40:53 +05:30
6f9b545c15 fix: remove invalid positional arg from structlog call 2026-01-16 00:38:41 +05:30
5bcecf4649 fix: use structlog instead of logging in mem0_manager for kwargs support 2026-01-16 00:34:51 +05:30
638a591dc5 add setup script with volume reset option, fix default embedding dims 2026-01-16 00:29:29 +05:30
a190527076 fix: use npm_network for NPM proxy, expose instead of ports 2026-01-16 00:01:41 +05:30
9e86c30548 fix: pass OLLAMA_BASE_URL, EMBEDDING_MODEL, EMBEDDING_DIMS to container 2026-01-15 23:55:44 +05:30
2c1d73a1ec add OpenAI-compatible endpoint and improved login UI
- Add /v1/chat/completions and /chat/completions endpoints (OpenAI SDK compatible)
- Add streaming support with SSE for chat completions
- Add get_current_user_openai auth supporting Bearer token and X-API-Key
- Add OpenAI-compatible request/response models (OpenAIChatCompletionRequest, etc.)
- Cherry-pick improved login UI from cloud branch (styled login screen, logout button)
2026-01-15 23:29:08 +05:30
a228780146 production improvements: configurable embeddings, v1.1, O(1) ownership, retries
- Make Ollama URL configurable via OLLAMA_BASE_URL env var
- Add version: v1.1 to Mem0 config (required for latest features)
- Make embedding model and dimensions configurable
- Fix ownership check: O(1) lookup instead of fetching 10k records
- Add tenacity retry logic for database operations
2026-01-15 23:01:18 +05:30
50edce2d3c security hardening: add auth, rate limiting, fix info disclosure
- Add auth to /models and /users endpoints
- Add rate limiting to all endpoints (10-120/min based on operation type)
- Fix 11 info disclosure issues (detail=str(e) -> generic message)
- Fix 2 silent except blocks with proper logging
- Fix 7 raise e -> raise for proper exception chaining
- Fix health check to not expose exception details
- Update tests with X-API-Key headers and security tests
2026-01-15 22:41:24 +05:30
35c1bbec4e added MCP HTTP endpoint with auth
Exposes memory operations as MCP tools over /mcp endpoint:
- add_memory, search_memory, remove_memory, chat
- API key auth via x-api-key or Authorization header
- User isolation enforced via contextvars
2026-01-11 14:00:16 +05:30
11 changed files with 1598 additions and 615 deletions


@@ -11,14 +11,11 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements and install Python dependencies
COPY backend/requirements.txt .
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy backend code
COPY backend/ .
# Copy frontend directory
COPY frontend/ /app/frontend/
# Copy application code
COPY . .
# Set Python path
ENV PYTHONPATH=/app


@@ -19,7 +19,9 @@ class AuthService:
def __init__(self):
"""Initialize auth service with API key to user mapping."""
self.api_key_to_user = settings.api_key_mapping
logger.info(f"Auth service initialized with {len(self.api_key_to_user)} API keys")
logger.info(
f"Auth service initialized with {len(self.api_key_to_user)} API keys"
)
def verify_api_key(self, api_key: str) -> str:
"""
@@ -37,8 +39,7 @@ class AuthService:
if api_key not in self.api_key_to_user:
logger.warning(f"Invalid API key attempted: {api_key[:10]}...")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key"
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
)
user_id = self.api_key_to_user[api_key]
@@ -68,7 +69,7 @@ class AuthService:
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied: You can only access your own memories"
detail=f"Access denied: You can only access your own memories",
)
return authenticated_user_id
@@ -93,7 +94,7 @@ async def get_current_user(api_key: str = Security(api_key_header)) -> str:
async def get_current_user_openai(
authorization: Optional[str] = Header(None),
x_api_key: Optional[str] = Header(None, alias="X-API-Key")
x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> str:
"""
FastAPI dependency for OpenAI-compatible authentication.
@@ -114,24 +115,23 @@ async def get_current_user_openai(
# Try Bearer token first (OpenAI standard)
if authorization and authorization.startswith("Bearer "):
api_key = authorization[7:] # Remove "Bearer " prefix
logger.debug(f"Extracted API key from Authorization Bearer token")
logger.debug("Extracted API key from Authorization Bearer token")
# Fall back to X-API-Key header
elif x_api_key:
api_key = x_api_key
logger.debug(f"Extracted API key from X-API-Key header")
logger.debug("Extracted API key from X-API-Key header")
else:
logger.warning("No API key provided in Authorization or X-API-Key headers")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing API key. Provide either 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header"
detail="Missing API key. Provide either 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header",
)
return auth_service.verify_api_key(api_key)
async def verify_user_access(
api_key: str = Security(api_key_header),
user_id: Optional[str] = None
api_key: str = Security(api_key_header), user_id: Optional[str] = None
) -> str:
"""
FastAPI dependency to verify user can access the requested user_id.
@@ -152,7 +152,7 @@ async def verify_user_access(
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied: You can only access your own memories"
detail="Access denied: You can only access your own memories",
)
return authenticated_user_id
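The dependency above accepts either header form. Since the main.py diff is suppressed further down, here is a hedged sketch of how a client might reach the OpenAI-compatible endpoint added in commit 2c1d73a1ec; the base URL, model name, and key are placeholders, and only the endpoint paths come from the commit message.

from openai import OpenAI

# Illustrative only: "my-api-key" must be one of the keys configured in
# API_KEYS, and localhost:8000 assumes the docker-compose setup below.
client = OpenAI(
    api_key="my-api-key",                 # sent as "Authorization: Bearer my-api-key"
    base_url="http://localhost:8000/v1",  # served by the new /v1/chat/completions route
)

resp = client.chat.completions.create(
    model="claude-sonnet-4",  # the server substitutes its configured default model
    messages=[{"role": "user", "content": "What do you remember about me?"}],
)
print(resp.choices[0].message.content)

# Streaming is delivered as SSE chunks:
for chunk in client.chat.completions.create(
    model="claude-sonnet-4",
    messages=[{"role": "user", "content": "Hi"}],
    stream=True,
):
    print(chunk.choices[0].delta.content or "", end="")

The same request works with an X-API-Key header instead of the Bearer token, per get_current_user_openai above.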


@@ -11,39 +11,87 @@ class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
model_config = SettingsConfigDict(
env_file=".env",
case_sensitive=False,
extra='ignore'
env_file=".env", case_sensitive=False, extra="ignore"
)
# API Configuration
# Accept both OPENAI_API_KEY (from docker-compose) and OPENAI_COMPAT_API_KEY (from direct .env)
openai_api_key: str = Field(validation_alias=AliasChoices('OPENAI_API_KEY', 'OPENAI_COMPAT_API_KEY', 'openai_api_key'))
openai_base_url: str = Field(validation_alias=AliasChoices('OPENAI_BASE_URL', 'OPENAI_COMPAT_BASE_URL', 'openai_base_url'))
cohere_api_key: str = Field(validation_alias=AliasChoices('COHERE_API_KEY', 'cohere_api_key'))
openai_api_key: str = Field(
validation_alias=AliasChoices(
"OPENAI_API_KEY", "OPENAI_COMPAT_API_KEY", "openai_api_key"
)
)
openai_base_url: str = Field(
validation_alias=AliasChoices(
"OPENAI_BASE_URL", "OPENAI_COMPAT_BASE_URL", "openai_base_url"
)
)
cohere_api_key: str = Field(
validation_alias=AliasChoices("COHERE_API_KEY", "cohere_api_key")
)
# Database Configuration
qdrant_host: str = Field(default="localhost", validation_alias=AliasChoices('QDRANT_HOST', 'qdrant_host'))
qdrant_port: int = Field(default=6333, validation_alias=AliasChoices('QDRANT_PORT', 'qdrant_port'))
qdrant_collection_name: str = Field(default="mem0", validation_alias=AliasChoices('QDRANT_COLLECTION_NAME', 'qdrant_collection_name'))
qdrant_host: str = Field(
default="localhost", validation_alias=AliasChoices("QDRANT_HOST", "qdrant_host")
)
qdrant_port: int = Field(
default=6333, validation_alias=AliasChoices("QDRANT_PORT", "qdrant_port")
)
qdrant_collection_name: str = Field(
default="mem0",
validation_alias=AliasChoices(
"QDRANT_COLLECTION_NAME", "qdrant_collection_name"
),
)
# Neo4j Configuration
neo4j_uri: str = Field(default="bolt://localhost:7687", validation_alias=AliasChoices('NEO4J_URI', 'neo4j_uri'))
neo4j_username: str = Field(default="neo4j", validation_alias=AliasChoices('NEO4J_USERNAME', 'neo4j_username'))
neo4j_password: str = Field(default="mem0_neo4j_password", validation_alias=AliasChoices('NEO4J_PASSWORD', 'neo4j_password'))
neo4j_uri: str = Field(
default="bolt://localhost:7687",
validation_alias=AliasChoices("NEO4J_URI", "neo4j_uri"),
)
neo4j_username: str = Field(
default="neo4j",
validation_alias=AliasChoices("NEO4J_USERNAME", "neo4j_username"),
)
neo4j_password: str = Field(
default="mem0_neo4j_password",
validation_alias=AliasChoices("NEO4J_PASSWORD", "neo4j_password"),
)
# Application Configuration
log_level: str = Field(default="INFO", validation_alias=AliasChoices('LOG_LEVEL', 'log_level'))
cors_origins: str = Field(default="http://localhost:3000", validation_alias=AliasChoices('CORS_ORIGINS', 'cors_origins'))
log_level: str = Field(
default="INFO", validation_alias=AliasChoices("LOG_LEVEL", "log_level")
)
cors_origins: str = Field(
default="http://localhost:3000",
validation_alias=AliasChoices("CORS_ORIGINS", "cors_origins"),
)
# Model Configuration - Ultra-minimal (single model)
default_model: str = Field(default="claude-sonnet-4", validation_alias=AliasChoices('DEFAULT_MODEL', 'default_model'))
default_model: str = Field(
default="claude-sonnet-4",
validation_alias=AliasChoices("DEFAULT_MODEL", "default_model"),
)
# Embedder Configuration
ollama_base_url: str = Field(
default="http://host.docker.internal:11434",
validation_alias=AliasChoices("OLLAMA_BASE_URL", "ollama_base_url"),
)
embedding_model: str = Field(
default="qwen3-embedding:4b-q8_0",
validation_alias=AliasChoices("EMBEDDING_MODEL", "embedding_model"),
)
embedding_dims: int = Field(
default=2560, validation_alias=AliasChoices("EMBEDDING_DIMS", "embedding_dims")
)
# Authentication Configuration
# Format: JSON string mapping API keys to user IDs
# Example: {"api_key_123": "alice", "api_key_456": "bob"}
api_keys: str = Field(default="{}", validation_alias=AliasChoices('API_KEYS', 'api_keys'))
api_keys: str = Field(
default="{}", validation_alias=AliasChoices("API_KEYS", "api_keys")
)
@property
def cors_origins_list(self) -> List[str]:
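A quick check of the AliasChoices pattern above, as an assumed usage example: either environment spelling populates the same field, which is why the compose file can export OPENAI_API_KEY while a direct .env uses OPENAI_COMPAT_API_KEY.

import os

# Hypothetical smoke test; values are placeholders. Required fields must be
# present in the environment before Settings() is constructed.
os.environ["OPENAI_COMPAT_API_KEY"] = "sk-demo"  # alias of OPENAI_API_KEY
os.environ["OPENAI_BASE_URL"] = "https://llm.example.test/v1"
os.environ["COHERE_API_KEY"] = "demo"

from config import Settings

settings = Settings()
assert settings.openai_api_key == "sk-demo"  # resolved via AliasChoices
assert settings.embedding_dims == 2560       # default when EMBEDDING_DIMS is unset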

File diff suppressed because it is too large.
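The rate limiting from commit 50edce2d3c lands in this suppressed diff (main.py, going by the mount comment in mcp_server.py below), so nothing above shows it. Below is a minimal sketch of the usual slowapi wiring; the names and the exact per-route rate are assumptions, not the suppressed file's actual code.

from fastapi import FastAPI, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

@app.post("/memories/search")
@limiter.limit("60/minute")  # the commit notes 10-120/min depending on operation type
async def search_memories(request: Request):  # slowapi requires the Request parameter
    ...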

backend/mcp_server.py (new file, 240 lines)

@@ -0,0 +1,240 @@
"""MCP Server for Mem0 Memory Service.
Exposes memory operations as MCP tools over HTTP with API key authentication.
"""
import contextlib
import logging
from contextvars import ContextVar
from typing import Optional
from pydantic import Field
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse
from starlette.routing import Mount
from mcp.server.fastmcp import FastMCP
from config import settings
logger = logging.getLogger(__name__)
# Context variable for authenticated user_id
# Set by middleware, read by tools
current_user_id: ContextVar[str] = ContextVar("current_user_id", default="")
def get_authenticated_user() -> str:
"""Get the authenticated user_id from context.
Raises:
ValueError: If no authenticated user in context.
"""
user_id = current_user_id.get()
if not user_id:
raise ValueError("No authenticated user in context")
return user_id
class MCPAuthMiddleware(BaseHTTPMiddleware):
"""Middleware to authenticate MCP requests using API key."""
async def dispatch(self, request, call_next):
# Extract API key from headers
api_key = (
request.headers.get("x-api-key") or
request.headers.get("X-API-Key") or
request.headers.get("Authorization", "").replace("Bearer ", "")
)
if not api_key:
return JSONResponse(
{"error": "Missing API key. Provide x-api-key or Authorization header."},
status_code=401
)
# Map API key to user_id
user_id = settings.api_key_mapping.get(api_key)
if not user_id:
return JSONResponse(
{"error": "Invalid API key"},
status_code=401
)
# Store user_id in context for tools to access
token = current_user_id.set(user_id)
try:
response = await call_next(request)
return response
finally:
current_user_id.reset(token)
# Create FastMCP server
# streamable_http_path="/" since we mount at /mcp in main.py
mcp = FastMCP(
"Mem0 Memory Service",
stateless_http=True,
json_response=True,
streamable_http_path="/"
)
@mcp.tool()
async def add_memory(
content: str = Field(description="Content to add to memory"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
metadata: Optional[dict] = Field(default=None, description="Optional metadata to attach to the memory"),
) -> dict:
"""Add content to the user's memory.
The user_id is automatically determined from the API key authentication.
Use agent_id and run_id for multi-agent or session-based memory organization.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP add_memory: user={user_id}, agent={agent_id}, run={run_id}")
result = await mem0_manager.add_memories(
messages=[{"role": "user", "content": content}],
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
metadata=metadata
)
return result
@mcp.tool()
async def search_memory(
query: str = Field(description="Search query to find relevant memories"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier to filter memories"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier to filter memories"),
limit: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return"),
) -> dict:
"""Search the user's memories.
Returns memories most relevant to the query. The user_id is automatically
determined from API key authentication.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP search_memory: user={user_id}, query={query[:50]}..., limit={limit}")
result = await mem0_manager.search_memories(
query=query,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=limit
)
return result
@mcp.tool()
async def remove_memory(
memory_id: str = Field(description="The ID of the memory to remove"),
) -> dict:
"""Remove a specific memory by its ID.
Only memories belonging to the authenticated user can be deleted.
Verifies ownership before deletion.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP remove_memory: user={user_id}, memory_id={memory_id}")
# Verify ownership: get user's memories and check if memory_id exists
user_memories = await mem0_manager.get_user_memories(
user_id=user_id,
limit=10000 # Get all to check ownership
)
memory_ids = {m.get("id") for m in user_memories if m.get("id")}
if memory_id not in memory_ids:
raise ValueError(f"Memory '{memory_id}' not found or access denied")
result = await mem0_manager.delete_memory(memory_id=memory_id)
return result
@mcp.tool()
async def chat(
message: str = Field(description="The user's message to chat with"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
) -> str:
"""Chat with memory context.
Retrieves relevant memories based on the message, generates a response
using the configured LLM, and stores the conversation in memory.
The user_id is automatically determined from API key authentication.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP chat: user={user_id}, agent={agent_id}, message={message[:50]}...")
result = await mem0_manager.chat_with_memory(
message=message,
user_id=user_id,
agent_id=agent_id,
run_id=run_id
)
return result.get("response", "")
@contextlib.asynccontextmanager
async def mcp_lifespan():
"""Context manager for MCP session lifecycle.
Must be used in the main FastAPI lifespan since mounted app
lifespans don't run automatically.
"""
async with mcp.session_manager.run():
logger.info("MCP session manager started")
yield
logger.info("MCP session manager stopped")
def create_mcp_app() -> Starlette:
"""Create and configure the MCP Starlette application.
Returns a Starlette app with MCP endpoints, authentication middleware,
and CORS support.
Note: The MCP session manager must be started via mcp_lifespan() in
the main FastAPI lifespan, not here.
"""
# Get the StreamableHTTP app - it handles requests at "/" since we mount at /mcp
streamable_app = mcp.streamable_http_app()
# Create Starlette app with MCP routes
app = Starlette(
routes=[Mount("/", app=streamable_app)],
)
# Add authentication middleware
app.add_middleware(MCPAuthMiddleware)
# Add CORS middleware for browser clients
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False,
allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
allow_headers=["*"],
expose_headers=["Mcp-Session-Id"],
)
return app
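Two hedged sketches to go with this file. First, the wiring that the mcp_lifespan() docstring calls for, roughly as it might appear in main.py (whose diff is suppressed above):

import contextlib

from fastapi import FastAPI

from mcp_server import create_mcp_app, mcp_lifespan

@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
    # Keep the MCP session manager running for the lifetime of the app.
    async with mcp_lifespan():
        yield

app = FastAPI(lifespan=lifespan)
app.mount("/mcp", create_mcp_app())  # MCPAuthMiddleware guards everything under /mcp

Second, because the server is stateless and returns JSON, a tool call can be a single JSON-RPC POST. The envelope below follows the MCP streamable-HTTP transport; clients and handshake details vary, so treat it as illustrative:

import httpx

resp = httpx.post(
    "http://localhost:8000/mcp",
    headers={
        "x-api-key": "my-api-key",  # mapped to a user_id by MCPAuthMiddleware
        "Accept": "application/json, text/event-stream",
    },
    json={
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/call",
        "params": {"name": "search_memory", "arguments": {"query": "employer"}},
    },
)
print(resp.json())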


@@ -5,24 +5,48 @@ from typing import Dict, List, Optional, Any
from datetime import datetime
from mem0 import Memory
from openai import OpenAI
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
before_sleep_log,
)
import structlog
from config import settings
from monitoring import timed
logger = logging.getLogger(__name__)
logger = structlog.get_logger(__name__)
# Retry decorator for database operations (Qdrant, Neo4j)
db_retry = retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((ConnectionError, TimeoutError, OSError)),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
# Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility
from mem0.llms.openai import OpenAILLM
_original_generate_response = OpenAILLM.generate_response
def patched_generate_response(self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs):
def patched_generate_response(
self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs
):
# Remove 'store' parameter as LiteLLM doesn't support it
if hasattr(self.config, 'store'):
if hasattr(self.config, "store"):
self.config.store = None
# Remove 'top_p' to avoid conflict with temperature for Claude models
if hasattr(self.config, 'top_p'):
if hasattr(self.config, "top_p"):
self.config.top_p = None
return _original_generate_response(self, messages, response_format, tools, tool_choice, **kwargs)
return _original_generate_response(
self, messages, response_format, tools, tool_choice, **kwargs
)
OpenAILLM.generate_response = patched_generate_response
logger.info("Applied LiteLLM compatibility patch: disabled 'store' parameter")
@@ -36,8 +60,16 @@ class Mem0Manager:
def __init__(self):
# Custom endpoint configuration with graph memory enabled
logger.info("Initializing ultra-minimal Mem0Manager with custom endpoint with settings:", settings)
logger.info(
"Initializing Mem0Manager with custom endpoint",
model=settings.default_model,
embedding_model=settings.embedding_model,
embedding_dims=settings.embedding_dims,
qdrant_host=settings.qdrant_host,
neo4j_uri=settings.neo4j_uri,
)
config = {
"version": "v1.1",
"enable_graph": True,
"llm": {
"provider": "openai",
@@ -46,17 +78,16 @@ class Mem0Manager:
"api_key": settings.openai_api_key,
"openai_base_url": settings.openai_base_url,
"temperature": 0.1,
"top_p": None # Don't use top_p with Claude models
}
"top_p": None,
},
},
"embedder": {
"provider": "ollama",
"config": {
"model": "hf.co/Qwen/Qwen3-Embedding-0.6B-GGUF:Q8_0",
# "api_key": settings.embedder_api_key,
"ollama_base_url": "https://models.breezehq.dev",
"embedding_dims": 1024
}
"model": settings.embedding_model,
"ollama_base_url": settings.ollama_base_url,
"embedding_dims": settings.embedding_dims,
},
},
"vector_store": {
"provider": "qdrant",
@@ -64,38 +95,39 @@ class Mem0Manager:
"collection_name": settings.qdrant_collection_name,
"host": settings.qdrant_host,
"port": settings.qdrant_port,
"embedding_model_dims": 1024,
"on_disk": True
}
"embedding_model_dims": settings.embedding_dims,
"on_disk": True,
},
},
"graph_store": {
"provider": "neo4j",
"config": {
"url": settings.neo4j_uri,
"username": settings.neo4j_username,
"password": settings.neo4j_password
}
"password": settings.neo4j_password,
},
},
"reranker": {
"provider": "cohere",
"config": {
"api_key": settings.cohere_api_key,
"model": "rerank-english-v3.0",
"top_n": 10
}
}
"top_n": 10,
},
},
}
self.memory = Memory.from_config(config)
self.openai_client = OpenAI(
api_key=settings.openai_api_key,
base_url=settings.openai_base_url
base_url=settings.openai_base_url,
timeout=60.0, # 60 second timeout for LLM calls
max_retries=2, # Retry failed requests up to 2 times
)
logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")
# Pure passthrough methods - no custom logic
@db_retry
@timed("add_memories")
async def add_memories(
self,
@@ -103,14 +135,14 @@ class Mem0Manager:
user_id: Optional[str] = "default",
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Add memories - simplified native Mem0 pattern (10 lines vs 45)."""
try:
# Convert ChatMessage objects to dict if needed
formatted_messages = []
for msg in messages:
if hasattr(msg, 'dict'):
if hasattr(msg, "dict"):
formatted_messages.append(msg.dict())
else:
formatted_messages.append(msg)
@@ -123,26 +155,35 @@ class Mem0Manager:
"timestamp": datetime.now().isoformat(),
"source": "chat_conversation",
"message_count": len(formatted_messages),
"auto_generated": True
"auto_generated": True,
}
# Merge user metadata with auto metadata (user metadata takes precedence)
enhanced_metadata = {**auto_metadata, **combined_metadata}
# Direct Mem0 add with enhanced metadata
result = self.memory.add(formatted_messages, user_id=user_id,
agent_id=agent_id, run_id=run_id,
metadata=enhanced_metadata)
result = self.memory.add(
formatted_messages,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
metadata=enhanced_metadata,
)
return {
"added_memories": result if isinstance(result, list) else [result],
"message": "Memories added successfully",
"hierarchy": {"user_id": user_id, "agent_id": agent_id, "run_id": run_id}
"hierarchy": {
"user_id": user_id,
"agent_id": agent_id,
"run_id": run_id,
},
}
except Exception as e:
logger.error(f"Error adding memories: {e}")
raise e
raise
@db_retry
@timed("search_memories")
async def search_memories(
self,
@@ -155,37 +196,79 @@ class Mem0Manager:
# rerank: bool = False,
# filter_memories: bool = False,
agent_id: Optional[str] = None,
run_id: Optional[str] = None
run_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Search memories - native Mem0 pattern"""
try:
# Minimal empty query protection for API compatibility
if not query or query.strip() == "":
return {"memories": [], "total_count": 0, "query": query, "note": "Empty query provided, no results returned. Use a specific query to search memories."}
return {
"memories": [],
"total_count": 0,
"query": query,
"note": "Empty query provided, no results returned. Use a specific query to search memories.",
}
# Direct Mem0 search - trust native handling
result = self.memory.search(query=query, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit, threshold=threshold, filters=filters)
return {"memories": result.get("results", []), "total_count": len(result.get("results", [])), "query": query}
result = self.memory.search(
query=query,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=limit,
threshold=threshold,
filters=filters,
)
return {
"memories": result.get("results", []),
"total_count": len(result.get("results", [])),
"query": query,
}
except Exception as e:
logger.error(f"Error searching memories: {e}")
raise e
raise
@db_retry
async def get_user_memories(
self,
user_id: str,
limit: int = 10,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None
filters: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""Get all memories for a user - native Mem0 pattern."""
try:
# Direct Mem0 get_all call - trust native parameter handling
result = self.memory.get_all(user_id=user_id, limit=limit, agent_id=agent_id, run_id=run_id, filters=filters)
result = self.memory.get_all(
user_id=user_id,
limit=limit,
agent_id=agent_id,
run_id=run_id,
filters=filters,
)
return result.get("results", [])
except Exception as e:
logger.error(f"Error getting user memories: {e}")
raise e
raise
@db_retry
async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
"""Get a single memory by ID. Returns None if not found."""
try:
result = self.memory.get(memory_id=memory_id)
return result
except Exception as e:
logger.debug(f"Memory {memory_id} not found or error: {e}")
return None
async def verify_memory_ownership(self, memory_id: str, user_id: str) -> bool:
"""Check if a memory belongs to a user. O(1) instead of O(n)."""
memory = await self.get_memory(memory_id)
if memory is None:
return False
return memory.get("user_id") == user_id
@db_retry
@timed("update_memory")
async def update_memory(
self,
@@ -194,15 +277,13 @@ class Mem0Manager:
) -> Dict[str, Any]:
"""Update memory - pure Mem0 passthrough."""
try:
result = self.memory.update(
memory_id=memory_id,
data=content
)
result = self.memory.update(memory_id=memory_id, data=content)
return {"message": "Memory updated successfully", "result": result}
except Exception as e:
logger.error(f"Error updating memory: {e}")
raise e
raise
@db_retry
@timed("delete_memory")
async def delete_memory(self, memory_id: str) -> Dict[str, Any]:
"""Delete memory - pure Mem0 passthrough."""
@@ -211,7 +292,7 @@ class Mem0Manager:
return {"message": "Memory deleted successfully"}
except Exception as e:
logger.error(f"Error deleting memory: {e}")
raise e
raise
async def delete_user_memories(self, user_id: Optional[str]) -> Dict[str, Any]:
"""Delete all user memories - pure Mem0 passthrough."""
@@ -220,7 +301,7 @@ class Mem0Manager:
return {"message": "All user memories deleted successfully"}
except Exception as e:
logger.error(f"Error deleting user memories: {e}")
raise e
raise
async def get_memory_history(self, memory_id: str) -> Dict[str, Any]:
"""Get memory change history - pure Mem0 passthrough."""
@@ -229,22 +310,24 @@ class Mem0Manager:
return {
"memory_id": memory_id,
"history": history,
"message": "Memory history retrieved successfully"
"message": "Memory history retrieved successfully",
}
except Exception as e:
logger.error(f"Error getting memory history: {e}")
raise e
raise
async def get_graph_relationships(self, user_id: Optional[str], agent_id: Optional[str], run_id: Optional[str], limit: int = 50) -> Dict[str, Any]:
async def get_graph_relationships(
self,
user_id: Optional[str],
agent_id: Optional[str],
run_id: Optional[str],
limit: int = 50,
) -> Dict[str, Any]:
"""Get graph relationships - using correct Mem0 get_all() method."""
try:
# Use get_all() to retrieve memories with graph relationships
result = self.memory.get_all(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=limit
user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit
)
# Extract relationships from Mem0's response structure
@@ -272,7 +355,7 @@ class Mem0Manager:
"agent_id": agent_id,
"run_id": run_id,
"total_memories": len(result.get("results", [])),
"total_relationships": len(relationships)
"total_relationships": len(relationships),
}
except Exception as e:
@@ -286,7 +369,7 @@ class Mem0Manager:
"run_id": run_id,
"total_memories": 0,
"total_relationships": 0,
"error": str(e)
"error": str(e),
}
@timed("chat_with_memory")
@@ -304,49 +387,70 @@ class Mem0Manager:
try:
total_start_time = time.time()
print(f"\n🚀 Starting chat request for user: {user_id}")
logger.info("Starting chat request", user_id=user_id)
# Stage 1: Memory Search
search_start_time = time.time()
search_result = self.memory.search(query=message, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=10, threshold=0.3)
search_result = self.memory.search(
query=message,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=10,
threshold=0.3,
)
relevant_memories = search_result.get("results", [])
memories_str = "\n".join(f"- {entry['memory']}" for entry in relevant_memories)
memories_str = "\n".join(
f"- {entry['memory']}" for entry in relevant_memories
)
search_time = time.time() - search_start_time
print(f"🔍 Memory search took: {search_time:.2f}s (found {len(relevant_memories)} memories)")
logger.debug(
"Memory search completed",
search_time_s=round(search_time, 2),
memories_found=len(relevant_memories),
)
# Stage 2: Prepare LLM messages
prep_start_time = time.time()
system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}"
messages = [{"role": "system", "content": system_prompt}]
# Add conversation context if provided (last 50 messages)
if context:
messages.extend(context)
print(f"📝 Added {len(context)} context messages")
logger.debug("Added context messages", context_count=len(context))
# Add current user message
messages.append({"role": "user", "content": message})
prep_time = time.time() - prep_start_time
print(f"📋 Message preparation took: {prep_time:.3f}s")
# Stage 3: LLM Call
llm_start_time = time.time()
response = self.openai_client.chat.completions.create(model=settings.default_model, messages=messages)
response = self.openai_client.chat.completions.create(
model=settings.default_model, messages=messages
)
assistant_response = response.choices[0].message.content
llm_time = time.time() - llm_start_time
print(f"🤖 LLM call took: {llm_time:.2f}s (model: {settings.default_model})")
logger.debug(
"LLM call completed",
llm_time_s=round(llm_time, 2),
model=settings.default_model,
)
# Stage 4: Memory Add
add_start_time = time.time()
memory_messages = [{"role": "user", "content": message}, {"role": "assistant", "content": assistant_response}]
memory_messages = [
{"role": "user", "content": message},
{"role": "assistant", "content": assistant_response},
]
self.memory.add(memory_messages, user_id=user_id)
add_time = time.time() - add_start_time
print(f"💾 Memory add took: {add_time:.2f}s")
# Total timing summary
total_time = time.time() - total_start_time
print(f"⏱️ TOTAL: {total_time:.2f}s | Search: {search_time:.2f}s | LLM: {llm_time:.2f}s | Add: {add_time:.2f}s | Prep: {prep_time:.3f}s")
print(f"📊 Breakdown: Search {(search_time/total_time)*100:.1f}% | LLM {(llm_time/total_time)*100:.1f}% | Add {(add_time/total_time)*100:.1f}%\n")
logger.info(
"Chat request completed",
user_id=user_id,
total_time_s=round(total_time, 2),
search_time_s=round(search_time, 2),
llm_time_s=round(llm_time, 2),
add_time_s=round(add_time, 2),
memories_used=len(relevant_memories),
model=settings.default_model,
)
return {
"response": assistant_response,
@@ -356,17 +460,22 @@ class Mem0Manager:
"total": round(total_time, 2),
"search": round(search_time, 2),
"llm": round(llm_time, 2),
"add": round(add_time, 2)
}
"add": round(add_time, 2),
},
}
except Exception as e:
logger.error(f"Error in chat_with_memory: {e}")
logger.error(
"Error in chat_with_memory",
error=str(e),
user_id=user_id,
exc_info=True,
)
return {
"error": str(e),
"response": "I apologize, but I encountered an error processing your request.",
"memories_used": 0,
"model_used": None
"model_used": None,
}
async def health_check(self) -> Dict[str, str]:
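The structlog migration that the three newest commits iterate on rests on a small API difference worth making explicit: a stdlib logger rejects arbitrary keyword arguments, while structlog binds them as structured event fields. An illustrative contrast, not code from this diff:

import structlog

logger = structlog.get_logger(__name__)
logger.info("Mem0Manager init", embedding_dims=2560)  # kwargs become event fields

# The stdlib equivalent raises instead, which is what those fixes chased:
#   import logging
#   logging.getLogger(__name__).info("Mem0Manager init", embedding_dims=2560)
#   TypeError: ... got an unexpected keyword argument 'embedding_dims'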


@@ -2,53 +2,115 @@
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
import re
# Constants for input validation
MAX_MESSAGE_LENGTH = 50000 # ~12k tokens max per message
MAX_QUERY_LENGTH = 10000 # ~2.5k tokens max per query
MAX_USER_ID_LENGTH = 100 # Reasonable user ID length
MAX_MEMORY_ID_LENGTH = 100 # Memory IDs are typically UUIDs
MAX_CONTEXT_MESSAGES = 100 # Max conversation context messages
USER_ID_PATTERN = r"^[a-zA-Z0-9_\-\.@]+$" # Alphanumeric with common separators
# Request Models
class ChatMessage(BaseModel):
"""Chat message structure."""
role: str = Field(..., description="Message role (user, assistant, system)")
content: str = Field(..., description="Message content")
role: str = Field(
..., max_length=20, description="Message role (user, assistant, system)"
)
content: str = Field(
..., max_length=MAX_MESSAGE_LENGTH, description="Message content"
)
class ChatRequest(BaseModel):
"""Ultra-minimal chat request."""
message: str = Field(..., description="User message")
user_id: Optional[str] = Field("default", description="User identifier")
agent_id: Optional[str] = Field(None, description="Agent identifier")
run_id: Optional[str] = Field(None, description="Run identifier")
context: Optional[List[ChatMessage]] = Field(None, description="Previous conversation context")
message: str = Field(..., max_length=MAX_MESSAGE_LENGTH, description="User message")
user_id: Optional[str] = Field(
"default",
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier (alphanumeric, _, -, ., @)",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
context: Optional[List[ChatMessage]] = Field(
None,
max_length=MAX_CONTEXT_MESSAGES,
description="Previous conversation context (max 100 messages)",
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
class MemoryAddRequest(BaseModel):
"""Request to add memories with hierarchy support - open-source compatible."""
messages: List[ChatMessage] = Field(..., description="Messages to process")
user_id: Optional[str] = Field("default", description="User identifier")
agent_id: Optional[str] = Field(None, description="Agent identifier")
run_id: Optional[str] = Field(None, description="Run identifier")
messages: List[ChatMessage] = Field(
...,
max_length=MAX_CONTEXT_MESSAGES,
description="Messages to process (max 100 messages)",
)
user_id: Optional[str] = Field(
"default",
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
class MemorySearchRequest(BaseModel):
"""Request to search memories with hierarchy filtering."""
query: str = Field(..., description="Search query")
user_id: Optional[str] = Field("default", description="User identifier")
agent_id: Optional[str] = Field(None, description="Agent identifier")
run_id: Optional[str] = Field(None, description="Run identifier")
limit: int = Field(5, description="Maximum number of results")
threshold: Optional[float] = Field(None, description="Minimum relevance score")
filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters")
# Hierarchy filters (open-source compatible)
agent_id: Optional[str] = Field(None, description="Filter by agent identifier")
run_id: Optional[str] = Field(None, description="Filter by run identifier")
query: str = Field(..., max_length=MAX_QUERY_LENGTH, description="Search query")
user_id: Optional[str] = Field(
"default",
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
limit: int = Field(5, ge=1, le=100, description="Maximum number of results (1-100)")
threshold: Optional[float] = Field(
None, ge=0.0, le=1.0, description="Minimum relevance score (0-1)"
)
filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters")
class MemoryUpdateRequest(BaseModel):
"""Request to update a memory."""
memory_id: str = Field(..., description="Memory ID to update")
content: str = Field(..., description="New memory content")
memory_id: str = Field(
..., max_length=MAX_MEMORY_ID_LENGTH, description="Memory ID to update"
)
user_id: str = Field(
...,
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier for ownership verification",
)
content: str = Field(
..., max_length=MAX_MESSAGE_LENGTH, description="New memory content"
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Updated metadata")
@@ -57,19 +119,23 @@ class MemoryUpdateRequest(BaseModel):
class MemoryItem(BaseModel):
"""Individual memory item."""
id: str = Field(..., description="Memory unique identifier")
memory: str = Field(..., description="Memory content")
user_id: Optional[str] = Field(None, description="Associated user ID")
agent_id: Optional[str] = Field(None, description="Associated agent ID")
run_id: Optional[str] = Field(None, description="Associated run ID")
metadata: Optional[Dict[str, Any]] = Field(None, description="Memory metadata")
score: Optional[float] = Field(None, description="Relevance score (for search results)")
score: Optional[float] = Field(
None, description="Relevance score (for search results)"
)
created_at: Optional[str] = Field(None, description="Creation timestamp")
updated_at: Optional[str] = Field(None, description="Last update timestamp")
class MemorySearchResponse(BaseModel):
"""Memory search results - pure Mem0 structure."""
memories: List[MemoryItem] = Field(..., description="Found memories")
total_count: int = Field(..., description="Total number of memories found")
query: str = Field(..., description="Original search query")
@@ -77,27 +143,37 @@ class MemorySearchResponse(BaseModel):
class MemoryAddResponse(BaseModel):
"""Response from adding memories - pure Mem0 structure."""
added_memories: List[Dict[str, Any]] = Field(..., description="Memories that were added")
added_memories: List[Dict[str, Any]] = Field(
..., description="Memories that were added"
)
message: str = Field(..., description="Success message")
class GraphRelationship(BaseModel):
"""Graph relationship structure."""
source: str = Field(..., description="Source entity")
relationship: str = Field(..., description="Relationship type")
target: str = Field(..., description="Target entity")
properties: Optional[Dict[str, Any]] = Field(None, description="Relationship properties")
properties: Optional[Dict[str, Any]] = Field(
None, description="Relationship properties"
)
class GraphResponse(BaseModel):
"""Graph relationships - pure Mem0 structure."""
relationships: List[GraphRelationship] = Field(..., description="Found relationships")
relationships: List[GraphRelationship] = Field(
..., description="Found relationships"
)
entities: List[str] = Field(..., description="Unique entities")
user_id: str = Field(..., description="User identifier")
class HealthResponse(BaseModel):
"""Health check response."""
status: str = Field(..., description="Service status")
services: Dict[str, str] = Field(..., description="Individual service statuses")
timestamp: str = Field(..., description="Health check timestamp")
@@ -105,6 +181,7 @@ class HealthResponse(BaseModel):
class ErrorResponse(BaseModel):
"""Error response structure."""
error: str = Field(..., description="Error message")
detail: Optional[str] = Field(None, description="Detailed error information")
status_code: int = Field(..., description="HTTP status code")
@@ -112,8 +189,10 @@
# Statistics and Monitoring Models
class MemoryOperationStats(BaseModel):
"""Memory operation statistics."""
add: int = Field(..., description="Number of add operations")
search: int = Field(..., description="Number of search operations")
update: int = Field(..., description="Number of update operations")
@@ -122,34 +201,47 @@ class GlobalStatsResponse(BaseModel):
class GlobalStatsResponse(BaseModel):
"""Global application statistics."""
total_memories: int = Field(..., description="Total memories across all users")
total_users: int = Field(..., description="Total number of users")
api_calls_today: int = Field(..., description="Total API calls today")
avg_response_time_ms: float = Field(..., description="Average response time in milliseconds")
memory_operations: MemoryOperationStats = Field(..., description="Memory operation breakdown")
avg_response_time_ms: float = Field(
..., description="Average response time in milliseconds"
)
memory_operations: MemoryOperationStats = Field(
..., description="Memory operation breakdown"
)
uptime_seconds: float = Field(..., description="Application uptime in seconds")
class UserStatsResponse(BaseModel):
"""User-specific statistics."""
user_id: str = Field(..., description="User identifier")
memory_count: int = Field(..., description="Number of memories for this user")
relationship_count: int = Field(..., description="Number of graph relationships for this user")
relationship_count: int = Field(
..., description="Number of graph relationships for this user"
)
last_activity: Optional[str] = Field(None, description="Last activity timestamp")
api_calls_today: int = Field(..., description="API calls made by this user today")
avg_response_time_ms: float = Field(..., description="Average response time for this user's requests")
avg_response_time_ms: float = Field(
..., description="Average response time for this user's requests"
)
# OpenAI-Compatible API Models
class OpenAIMessage(BaseModel):
"""OpenAI message format."""
role: str = Field(..., description="Message role (system, user, assistant)")
content: str = Field(..., description="Message content")
class OpenAIChatCompletionRequest(BaseModel):
"""OpenAI chat completion request format."""
model: str = Field(..., description="Model to use (will use configured default)")
messages: List[Dict[str, str]] = Field(..., description="List of messages")
temperature: Optional[float] = Field(0.7, description="Sampling temperature")
@@ -160,11 +252,14 @@ class OpenAIChatCompletionRequest(BaseModel):
stop: Optional[List[str]] = Field(None, description="Stop sequences")
presence_penalty: Optional[float] = Field(0, description="Presence penalty")
frequency_penalty: Optional[float] = Field(0, description="Frequency penalty")
user: Optional[str] = Field(None, description="User identifier (ignored, uses API key)")
user: Optional[str] = Field(
None, description="User identifier (ignored, uses API key)"
)
class OpenAIUsage(BaseModel):
"""Token usage information."""
prompt_tokens: int = Field(..., description="Tokens in the prompt")
completion_tokens: int = Field(..., description="Tokens in the completion")
total_tokens: int = Field(..., description="Total tokens used")
@@ -172,12 +267,14 @@ class OpenAIUsage(BaseModel):
class OpenAIChoiceMessage(BaseModel):
"""Message in a choice."""
role: str = Field(..., description="Role of the message")
content: str = Field(..., description="Content of the message")
class OpenAIChoice(BaseModel):
"""Individual completion choice."""
index: int = Field(..., description="Choice index")
message: OpenAIChoiceMessage = Field(..., description="Message content")
finish_reason: str = Field(..., description="Reason for completion finish")
@@ -185,6 +282,7 @@ class OpenAIChoice(BaseModel):
class OpenAIChatCompletionResponse(BaseModel):
"""OpenAI chat completion response format."""
id: str = Field(..., description="Unique completion ID")
object: str = Field(default="chat.completion", description="Object type")
created: int = Field(..., description="Unix timestamp of creation")
@@ -195,14 +293,19 @@ class OpenAIChatCompletionResponse(BaseModel):
# Streaming-specific models
class OpenAIStreamDelta(BaseModel):
"""Delta content in a streaming chunk."""
role: Optional[str] = Field(None, description="Role (only in first chunk)")
content: Optional[str] = Field(None, description="Incremental content")
class OpenAIStreamChoice(BaseModel):
"""Individual streaming choice."""
index: int = Field(..., description="Choice index")
delta: OpenAIStreamDelta = Field(..., description="Delta content")
finish_reason: Optional[str] = Field(None, description="Reason for completion finish")
finish_reason: Optional[str] = Field(
None, description="Reason for completion finish"
)
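An assumed usage check of the new bounds above: malformed input now fails at the model layer instead of reaching the memory backend.

from pydantic import ValidationError

from models import MemorySearchRequest  # module name assumed

try:
    # Two violations: "alice!" breaks USER_ID_PATTERN, and 500 exceeds le=100.
    MemorySearchRequest(query="TechCorp", user_id="alice!", limit=500)
except ValidationError as exc:
    print([err["loc"] for err in exc.errors()])  # [('user_id',), ('limit',)]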


@@ -19,6 +19,7 @@ ollama
# Utilities
pydantic
pydantic-settings
tenacity
python-dotenv
httpx
aiofiles
@@ -31,3 +32,9 @@ python-json-logger
# CORS and Security
python-jose[cryptography]
passlib[bcrypt]
# Rate Limiting
slowapi
# MCP Server
mcp[server]>=1.0.0


@@ -5,6 +5,8 @@ services:
container_name: mem0-qdrant
expose:
- "6333"
networks:
- mem0_network
volumes:
- qdrant_data:/qdrant/storage
command: >
@@ -15,8 +17,6 @@ services:
timeout: 5s
retries: 5
restart: unless-stopped
networks:
- mem0_network
# Neo4j with APOC for graph relationships
neo4j:
@@ -51,8 +51,8 @@
# Backend API service
backend:
build:
context: .
dockerfile: ./backend/Dockerfile
context: ./backend
dockerfile: Dockerfile
container_name: mem0-backend
environment:
OPENAI_API_KEY: ${OPENAI_COMPAT_API_KEY}
@@ -69,8 +69,11 @@ services:
CORS_ORIGINS: ${CORS_ORIGINS:-http://localhost:3000}
DEFAULT_MODEL: ${DEFAULT_MODEL:-claude-sonnet-4}
API_KEYS: ${API_KEYS:-{}}
OLLAMA_BASE_URL: ${OLLAMA_BASE_URL:-http://host.docker.internal:11434}
EMBEDDING_MODEL: ${EMBEDDING_MODEL:-nomic-embed-text}
EMBEDDING_DIMS: ${EMBEDDING_DIMS:-2560}
expose:
- 8000
- "8000"
networks:
- npm_network
- mem0_network

setup.sh (new executable file, 122 lines)

@@ -0,0 +1,122 @@
#!/bin/bash
set -e
COMPOSE_PROJECT_NAME="${COMPOSE_PROJECT_NAME:-mem0}"
VOLUMES=("qdrant_data" "neo4j_data" "neo4j_logs" "neo4j_import" "neo4j_plugins")
echo "=========================================="
echo " Mem0 Setup Script"
echo "=========================================="
echo ""
check_volumes_exist() {
local exists=false
for vol in "${VOLUMES[@]}"; do
full_name="${COMPOSE_PROJECT_NAME}_${vol}"
if docker volume ls -q | grep -q "^${full_name}$"; then
exists=true
break
fi
done
echo "$exists"
}
list_existing_volumes() {
echo "Existing volumes:"
for vol in "${VOLUMES[@]}"; do
full_name="${COMPOSE_PROJECT_NAME}_${vol}"
if docker volume ls -q | grep -q "^${full_name}$"; then
size=$(docker system df -v 2>/dev/null | grep "$full_name" | awk '{print $4}' || echo "unknown")
echo " - $full_name ($size)"
fi
done
}
remove_volumes() {
echo "Stopping containers..."
docker compose down 2>/dev/null || true
echo "Removing volumes..."
for vol in "${VOLUMES[@]}"; do
full_name="${COMPOSE_PROJECT_NAME}_${vol}"
if docker volume ls -q | grep -q "^${full_name}$"; then
echo " Removing $full_name..."
docker volume rm "$full_name" 2>/dev/null || true
fi
done
echo "Volumes removed."
}
build_and_start() {
echo ""
echo "Building containers..."
docker compose build
echo ""
echo "Starting services..."
docker compose up -d
echo ""
echo "Waiting for services to be healthy..."
sleep 5
echo ""
echo "Checking health..."
curl -s http://localhost:8000/health 2>/dev/null | jq . || echo "Health check not available yet. Services may still be starting."
echo ""
echo "=========================================="
echo " Setup Complete!"
echo "=========================================="
echo ""
echo "Services:"
echo " - Backend API: http://localhost:8000"
echo " - API Docs: http://localhost:8000/docs"
echo " - Health Check: http://localhost:8000/health"
echo ""
echo "Logs: docker compose logs -f backend"
}
if [ "$(check_volumes_exist)" = "true" ]; then
echo "Existing data volumes detected!"
echo ""
list_existing_volumes
echo ""
echo "Options:"
echo " 1) Keep existing data and start services"
echo " 2) Reset everything (DELETE ALL DATA) and start fresh"
echo " 3) Exit"
echo ""
read -p "Choose an option [1/2/3]: " choice
case $choice in
1)
echo ""
echo "Keeping existing data..."
build_and_start
;;
2)
echo ""
read -p "Are you sure you want to DELETE ALL DATA? Type 'yes' to confirm: " confirm
if [ "$confirm" = "yes" ]; then
remove_volumes
build_and_start
else
echo "Aborted."
exit 1
fi
;;
3)
echo "Exiting."
exit 0
;;
*)
echo "Invalid option. Exiting."
exit 1
;;
esac
else
echo "No existing volumes found. Starting fresh setup..."
build_and_start
fi


@@ -19,13 +19,24 @@ import time
BASE_URL = "http://localhost:8000"
TEST_USER = f"test_user_{int(datetime.now().timestamp())}"
# API Key for authentication - set via environment or use default test key
import os
API_KEY = os.environ.get("MEM0_API_KEY", "test-api-key")
AUTH_HEADERS = {"X-API-Key": API_KEY}
def main():
parser = argparse.ArgumentParser(
description="Mem0 Integration Tests - Real API Testing (Zero Mocking)",
formatter_class=argparse.RawDescriptionHelpFormatter
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="Show detailed output and API responses",
)
parser.add_argument("--verbose", "-v", action="store_true",
help="Show detailed output and API responses")
args = parser.parse_args()
verbose = args.verbose
@@ -39,6 +50,9 @@ def main():
# Test sequence - order matters for data dependencies
tests = [
test_health_check,
test_auth_required_endpoints,
test_ownership_verification,
test_request_size_limit,
test_empty_search_protection,
test_add_memories_with_hierarchy,
test_search_memories_basic,
@ -51,7 +65,7 @@ def main():
test_graph_relationships,
test_delete_specific_memory,
test_delete_all_user_memories,
test_cleanup_verification
test_cleanup_verification,
]
results = []
@@ -82,6 +96,7 @@ def main():
print("❌ Some tests failed! Check the output above.")
sys.exit(1)
def run_test(name, test_func, verbose):
"""Run a single test with error handling"""
try:
@@ -102,6 +117,7 @@ def run_test(name, test_func, verbose):
print(f"{name}: {e}")
return False
def log_response(response, verbose, context=""):
"""Log API response details if verbose"""
if verbose:
@@ -111,22 +127,30 @@ def log_response(response, verbose, context=""):
if isinstance(data, dict) and len(data) < 5:
print(f" {context} Response: {data}")
else:
print(f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}")
print(
f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}"
)
except:
print(f" {context} Response: {response.text[:100]}...")
# ================== TEST FUNCTIONS ==================
def test_health_check(verbose):
"""Test service health endpoint"""
response = requests.get(f"{BASE_URL}/health", timeout=10)
response = requests.get(
f"{BASE_URL}/health", timeout=10
) # Health doesn't require auth
log_response(response, verbose, "Health")
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert "status" in data, "Health response missing 'status' field"
assert data["status"] in ["healthy", "degraded"], f"Invalid status: {data['status']}"
assert data["status"] in ["healthy", "degraded"], (
f"Invalid status: {data['status']}"
)
# Check individual services
assert "services" in data, "Health response missing 'services' field"
@@ -136,18 +160,19 @@ def test_health_check(verbose):
for service, status in data["services"].items():
print(f" {service}: {status}")
def test_empty_search_protection(verbose):
"""Test empty query protection (should not return 500 error)"""
payload = {
"query": "",
"user_id": TEST_USER,
"limit": 5
}
payload = {"query": "", "user_id": TEST_USER, "limit": 5}
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=10)
response = requests.post(
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Empty Search")
assert response.status_code == 200, f"Empty query failed with {response.status_code}"
assert response.status_code == 200, (
f"Empty query failed with {response.status_code}"
)
data = response.json()
assert data["memories"] == [], "Empty query should return empty memories list"
@@ -158,25 +183,39 @@ def test_empty_search_protection(verbose):
print(f" Empty search note: {data['note']}")
print(f" Total count: {data.get('total_count', 0)}")
def test_add_memories_with_hierarchy(verbose):
"""Test adding memories with multi-level hierarchy support"""
payload = {
"messages": [
{"role": "user", "content": "I work at TechCorp as a Senior Software Engineer"},
{"role": "user", "content": "My colleague Sarah from Marketing team helped with Q3 presentation"},
{"role": "user", "content": "Meeting with John the Product Manager tomorrow about new feature development"}
{
"role": "user",
"content": "I work at TechCorp as a Senior Software Engineer",
},
{
"role": "user",
"content": "My colleague Sarah from Marketing team helped with Q3 presentation",
},
{
"role": "user",
"content": "Meeting with John the Product Manager tomorrow about new feature development",
},
],
"user_id": TEST_USER,
"agent_id": "test_agent",
"run_id": "test_run_001",
"session_id": "test_session_001",
"metadata": {"test": "integration", "scenario": "work_context"}
"metadata": {"test": "integration", "scenario": "work_context"},
}
response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60)
response = requests.post(
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60
)
log_response(response, verbose, "Add Memories")
assert response.status_code == 200, f"Add memories failed with {response.status_code}"
assert response.status_code == 200, (
f"Add memories failed with {response.status_code}"
)
data = response.json()
assert "added_memories" in data, "Response missing 'added_memories'"
@@ -191,23 +230,26 @@ def test_add_memories_with_hierarchy(verbose):
relations = first_memory["relations"]
if "added_entities" in relations and relations["added_entities"]:
if verbose:
print(f" Graph extracted: {len(relations['added_entities'])} relationships")
print(
f" Graph extracted: {len(relations['added_entities'])} relationships"
)
print(f" Sample relations: {relations['added_entities'][:3]}")
if verbose:
print(f" Added {len(memories)} memory blocks")
print(f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001")
print(
f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001"
)
def test_search_memories_basic(verbose):
"""Test basic memory search functionality"""
# Test meaningful search
payload = {
"query": "TechCorp",
"user_id": TEST_USER,
"limit": 10
}
payload = {"query": "TechCorp", "user_id": TEST_USER, "limit": 10}
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15)
response = requests.post(
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Search")
assert response.status_code == 200, f"Search failed with {response.status_code}"
@@ -232,6 +274,7 @@ def test_search_memories_basic(verbose):
print(f" Found {data['total_count']} memories")
print(f" First memory: {memory['memory'][:50]}...")
def test_search_memories_hierarchy_filters(verbose):
"""Test multi-level hierarchy filtering in search"""
# Test with hierarchy filters
@@ -241,13 +284,17 @@ def test_search_memories_basic(verbose):
"agent_id": "test_agent",
"run_id": "test_run_001",
"session_id": "test_session_001",
"limit": 10
"limit": 10,
}
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15)
response = requests.post(
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Hierarchy Search")
assert response.status_code == 200, f"Hierarchy search failed with {response.status_code}"
assert response.status_code == 200, (
f"Hierarchy search failed with {response.status_code}"
)
data = response.json()
assert "memories" in data, "Hierarchy search response missing 'memories'"
@@ -257,7 +304,10 @@ def test_search_memories_hierarchy_filters(verbose):
if verbose:
print(f" Found {len(data['memories'])} memories with hierarchy filters")
print(f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001")
print(
f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001"
)
def test_get_user_memories_with_hierarchy(verbose):
"""Test retrieving user memories with hierarchy filtering"""
@@ -266,13 +316,20 @@ def test_get_user_memories_with_hierarchy(verbose):
"limit": 20,
"agent_id": "test_agent",
"run_id": "test_run_001",
"session_id": "test_session_001"
"session_id": "test_session_001",
}
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}", params=params, timeout=15)
response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}",
params=params,
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Get User Memories with Hierarchy")
assert response.status_code == 200, f"Get user memories with hierarchy failed with {response.status_code}"
assert response.status_code == 200, (
f"Get user memories with hierarchy failed with {response.status_code}"
)
memories = response.json()
assert isinstance(memories, list), "User memories should return a list"
@@ -290,10 +347,13 @@ def test_get_user_memories_with_hierarchy(verbose):
if verbose:
print(" No memories found with hierarchy filters (may be expected)")
def test_memory_history(verbose):
"""Test memory history endpoint"""
# First get a memory to check history for
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10)
response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for history test"
memories = response.json()
@@ -305,27 +365,38 @@ def test_memory_history(verbose):
memory_id = memories[0]["id"]
# Test memory history endpoint
response = requests.get(f"{BASE_URL}/memories/{memory_id}/history", timeout=15)
response = requests.get(
f"{BASE_URL}/memories/{memory_id}/history?user_id={TEST_USER}",
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Memory History")
assert response.status_code == 200, f"Memory history failed with {response.status_code}"
assert response.status_code == 200, (
f"Memory history failed with {response.status_code}"
)
data = response.json()
assert "memory_id" in data, "History response missing 'memory_id'"
assert "history" in data, "History response missing 'history'"
assert "message" in data, "History response missing success message"
assert data["memory_id"] == memory_id, f"Wrong memory_id in response: {data['memory_id']}"
assert data["memory_id"] == memory_id, (
f"Wrong memory_id in response: {data['memory_id']}"
)
if verbose:
print(f" Retrieved history for memory {memory_id}")
print(f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}")
print(
f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}"
)
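# Per the assertions above, the history endpoint is expected to return a
# payload shaped roughly like this (key names from the asserts; the entry
# contents are an assumption):
#
# {
#     "memory_id": "<uuid>",
#     "history": [{"event": "ADD", "old_memory": None, "new_memory": "..."}],
#     "message": "History retrieved successfully"
# }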
def test_update_memory(verbose):
"""Test updating a specific memory"""
# First get a memory to update
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10)
response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for update test"
memories = response.json()
@@ -337,10 +408,13 @@ def test_update_memory(verbose):
# Update the memory
payload = {
"memory_id": memory_id,
"content": f"UPDATED: {original_content}"
"user_id": TEST_USER,
"content": f"UPDATED: {original_content}",
}
response = requests.put(f"{BASE_URL}/memories", json=payload, timeout=10)
response = requests.put(
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Update")
assert response.status_code == 200, f"Update failed with {response.status_code}"
@@ -352,15 +426,15 @@ def test_update_memory(verbose):
print(f" Updated memory {memory_id}")
print(f" Original: {original_content[:30]}...")
def test_chat_with_memory(verbose):
"""Test memory-enhanced chat functionality"""
payload = {
"message": "What company do I work for?",
"user_id": TEST_USER
}
payload = {"message": "What company do I work for?", "user_id": TEST_USER}
try:
response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=90)
response = requests.post(
f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=90
)
log_response(response, verbose, "Chat")
assert response.status_code == 200, f"Chat failed with {response.status_code}"
@@ -383,12 +457,15 @@ def test_chat_with_memory(verbose):
print(" Chat endpoint timed out (LLM API may be slow)")
# Still test that the endpoint exists and accepts requests
try:
response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=5)
response = requests.post(
f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=5
)
except requests.exceptions.ReadTimeout:
# This is expected - endpoint exists but processing is slow
if verbose:
print(" Chat endpoint confirmed active (processing timeout expected)")
def test_graph_relationships_creation(verbose):
"""Test graph relationships creation with entity-rich memories"""
# Create a separate test user for graph relationship testing
@@ -397,41 +474,67 @@ def test_graph_relationships_creation(verbose):
# Add memories with clear entity relationships
payload = {
"messages": [
{"role": "user", "content": "John Smith works at Microsoft as a Senior Software Engineer"},
{"role": "user", "content": "John Smith is friends with Sarah Johnson who works at Google"},
{"role": "user", "content": "Sarah Johnson lives in Seattle and loves hiking"},
{
"role": "user",
"content": "John Smith works at Microsoft as a Senior Software Engineer",
},
{
"role": "user",
"content": "John Smith is friends with Sarah Johnson who works at Google",
},
{
"role": "user",
"content": "Sarah Johnson lives in Seattle and loves hiking",
},
{"role": "user", "content": "Microsoft is located in Redmond, Washington"},
{"role": "user", "content": "John Smith and Sarah Johnson both graduated from Stanford University"}
{
"role": "user",
"content": "John Smith and Sarah Johnson both graduated from Stanford University",
},
],
"user_id": graph_test_user,
"metadata": {"test": "graph_relationships", "scenario": "entity_creation"}
"metadata": {"test": "graph_relationships", "scenario": "entity_creation"},
}
response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60)
response = requests.post(
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60
)
log_response(response, verbose, "Add Graph Memories")
assert response.status_code == 200, f"Add graph memories failed with {response.status_code}"
assert response.status_code == 200, (
f"Add graph memories failed with {response.status_code}"
)
data = response.json()
assert "added_memories" in data, "Response missing 'added_memories'"
if verbose:
print(f" Added {len(data['added_memories'])} memories for graph relationship testing")
print(
f" Added {len(data['added_memories'])} memories for graph relationship testing"
)
# Wait a moment for graph processing (Mem0 graph extraction can be async)
time.sleep(2)
# Test graph relationships endpoint
response = requests.get(f"{BASE_URL}/graph/relationships/{graph_test_user}", timeout=15)
response = requests.get(
f"{BASE_URL}/graph/relationships/{graph_test_user}",
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Graph Relationships")
assert response.status_code == 200, f"Graph relationships failed with {response.status_code}"
assert response.status_code == 200, (
f"Graph relationships failed with {response.status_code}"
)
graph_data = response.json()
assert "relationships" in graph_data, "Graph response missing 'relationships'"
assert "entities" in graph_data, "Graph response missing 'entities'"
assert "user_id" in graph_data, "Graph response missing 'user_id'"
assert graph_data["user_id"] == graph_test_user, f"Wrong user_id in graph: {graph_data['user_id']}"
assert graph_data["user_id"] == graph_test_user, (
f"Wrong user_id in graph: {graph_data['user_id']}"
)
relationships = graph_data["relationships"]
entities = graph_data["entities"]
@@ -447,20 +550,28 @@ def test_graph_relationships_creation(verbose):
source = rel.get("source", "unknown")
target = rel.get("target", "unknown")
relationship = rel.get("relationship", "unknown")
print(f" {i+1}. {source} --{relationship}--> {target}")
print(f" {i + 1}. {source} --{relationship}--> {target}")
# Print sample entities if they exist
if entities:
print(f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}")
print(
f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}"
)
# Verify relationship structure (if relationships exist)
for rel in relationships:
assert "source" in rel or "from" in rel, f"Relationship missing source/from: {rel}"
assert "source" in rel or "from" in rel, (
f"Relationship missing source/from: {rel}"
)
assert "target" in rel or "to" in rel, f"Relationship missing target/to: {rel}"
assert "relationship" in rel or "type" in rel, f"Relationship missing type: {rel}"
assert "relationship" in rel or "type" in rel, (
f"Relationship missing type: {rel}"
)
# Clean up graph test user memories
cleanup_response = requests.delete(f"{BASE_URL}/memories/user/{graph_test_user}", timeout=15)
cleanup_response = requests.delete(
f"{BASE_URL}/memories/user/{graph_test_user}", headers=AUTH_HEADERS, timeout=15
)
assert cleanup_response.status_code == 200, "Failed to cleanup graph test memories"
if verbose:
@@ -469,12 +580,17 @@ def test_graph_relationships_creation(verbose):
# Note: We expect some relationships even if graph extraction is basic
# The test passes if the endpoint works and returns proper structure
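# The structure checks above imply a graph response of roughly this shape
# (key names from the asserts; entity and relationship values are
# illustrative only):
#
# {
#     "user_id": "graph_test_user",
#     "relationships": [
#         {"source": "john_smith", "relationship": "works_at", "target": "microsoft"}
#     ],
#     "entities": [{"name": "john_smith"}, {"name": "microsoft"}]
# }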
def test_graph_relationships(verbose):
"""Test graph relationships endpoint"""
response = requests.get(f"{BASE_URL}/graph/relationships/{TEST_USER}", timeout=15)
response = requests.get(
f"{BASE_URL}/graph/relationships/{TEST_USER}", headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Graph")
assert response.status_code == 200, f"Graph endpoint failed with {response.status_code}"
assert response.status_code == 200, (
f"Graph endpoint failed with {response.status_code}"
)
data = response.json()
assert "relationships" in data, "Graph response missing 'relationships'"
@@ -486,10 +602,13 @@ def test_graph_relationships(verbose):
print(f" Relationships: {len(data['relationships'])}")
print(f" Entities: {len(data['entities'])}")
def test_delete_specific_memory(verbose):
"""Test deleting a specific memory"""
# Get a memory to delete
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10)
response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for deletion test"
memories = response.json()
@@ -498,7 +617,9 @@ def test_delete_specific_memory(verbose):
memory_id = memories[0]["id"]
# Delete the memory
response = requests.delete(f"{BASE_URL}/memories/{memory_id}", timeout=10)
response = requests.delete(
f"{BASE_URL}/memories/{memory_id}", headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Delete")
assert response.status_code == 200, f"Delete failed with {response.status_code}"
@@ -509,9 +630,12 @@ def test_delete_specific_memory(verbose):
if verbose:
print(f" Deleted memory {memory_id}")
def test_delete_all_user_memories(verbose):
"""Test deleting all memories for a user"""
response = requests.delete(f"{BASE_URL}/memories/user/{TEST_USER}", timeout=15)
response = requests.delete(
f"{BASE_URL}/memories/user/{TEST_USER}", headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Delete All")
assert response.status_code == 200, f"Delete all failed with {response.status_code}"
@@ -522,12 +646,17 @@ def test_cleanup_verification(verbose):
if verbose:
print(f"Deleted all memories for {TEST_USER}")
def test_cleanup_verification(verbose):
"""Verify cleanup was successful"""
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=10", timeout=10)
response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=10", headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Cleanup Check")
assert response.status_code == 200, f"Cleanup verification failed with {response.status_code}"
assert response.status_code == 200, (
f"Cleanup verification failed with {response.status_code}"
)
memories = response.json()
assert isinstance(memories, list), "Should return list even if empty"
@@ -539,5 +668,79 @@ def test_cleanup_verification(verbose):
if verbose:
print(" Cleanup successful - no memories remain")
# ================== SECURITY TEST FUNCTIONS ==================
def test_auth_required_endpoints(verbose):
"""Test that protected endpoints require authentication"""
endpoints_requiring_auth = [
("GET", f"{BASE_URL}/memories/{TEST_USER}"),
("POST", f"{BASE_URL}/memories/search"),
("GET", f"{BASE_URL}/stats"),
("GET", f"{BASE_URL}/models"),
("GET", f"{BASE_URL}/users"),
]
for method, url in endpoints_requiring_auth:
if method == "GET":
response = requests.get(url, timeout=5)
else:
response = requests.post(
url, json={"query": "test", "user_id": TEST_USER}, timeout=5
)
assert response.status_code in [401, 403], (
f"{method} {url} should require auth, got {response.status_code}"
)
if verbose:
print(f" {method} {url}: {response.status_code} (auth required)")
def test_ownership_verification(verbose):
"""Test that users can only access their own data"""
other_user = "other_user_not_me"
response = requests.get(
f"{BASE_URL}/memories/{other_user}", headers=AUTH_HEADERS, timeout=5
)
assert response.status_code in [403, 404], (
f"Accessing other user's memories should be denied, got {response.status_code}"
)
if verbose:
print(f" Ownership check passed: {response.status_code}")
def test_request_size_limit(verbose):
"""Test request size limit enforcement (10MB max)"""
large_payload = {
"messages": [{"role": "user", "content": "x" * (11 * 1024 * 1024)}],
"user_id": TEST_USER,
}
try:
response = requests.post(
f"{BASE_URL}/memories",
json=large_payload,
headers={**AUTH_HEADERS, "Content-Length": str(11 * 1024 * 1024)},
timeout=5,
)
assert response.status_code == 413, (
f"Large request should return 413, got {response.status_code}"
)
if verbose:
print(f" Request size limit enforced: {response.status_code}")
except requests.exceptions.RequestException as e:
if verbose:
print(
f" Request size limit test: connection issue (expected for large payload)"
)
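# A minimal sketch of how the 10 MB cap that yields the 413 above could be
# enforced, assuming a Starlette middleware on the FastAPI app (illustrative
# only; not the committed implementation):
#
# from starlette.middleware.base import BaseHTTPMiddleware
# from starlette.responses import JSONResponse
#
# MAX_BODY_BYTES = 10 * 1024 * 1024
#
# class BodySizeLimitMiddleware(BaseHTTPMiddleware):
#     async def dispatch(self, request, call_next):
#         declared = request.headers.get("content-length")
#         if declared and int(declared) > MAX_BODY_BYTES:
#             return JSONResponse({"detail": "Request too large"}, status_code=413)
#         return await call_next(request)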
if __name__ == "__main__":
main()