Compare commits

..

1 commit

Author SHA1 Message Date
2045a042eb Working frontend and openai compatible endpoint 2025-10-27 15:29:55 +00:00
11 changed files with 615 additions and 1598 deletions

View file

@ -11,11 +11,14 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Copy requirements and install Python dependencies # Copy requirements and install Python dependencies
COPY requirements.txt . COPY backend/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt RUN pip install --no-cache-dir -r requirements.txt
# Copy application code # Copy backend code
COPY . . COPY backend/ .
# Copy frontend directory
COPY frontend/ /app/frontend/
# Set Python path # Set Python path
ENV PYTHONPATH=/app ENV PYTHONPATH=/app

View file

@ -19,9 +19,7 @@ class AuthService:
def __init__(self): def __init__(self):
"""Initialize auth service with API key to user mapping.""" """Initialize auth service with API key to user mapping."""
self.api_key_to_user = settings.api_key_mapping self.api_key_to_user = settings.api_key_mapping
logger.info( logger.info(f"Auth service initialized with {len(self.api_key_to_user)} API keys")
f"Auth service initialized with {len(self.api_key_to_user)} API keys"
)
def verify_api_key(self, api_key: str) -> str: def verify_api_key(self, api_key: str) -> str:
""" """
@ -39,7 +37,8 @@ class AuthService:
if api_key not in self.api_key_to_user: if api_key not in self.api_key_to_user:
logger.warning(f"Invalid API key attempted: {api_key[:10]}...") logger.warning(f"Invalid API key attempted: {api_key[:10]}...")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key" status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key"
) )
user_id = self.api_key_to_user[api_key] user_id = self.api_key_to_user[api_key]
@ -69,7 +68,7 @@ class AuthService:
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied: You can only access your own memories", detail=f"Access denied: You can only access your own memories"
) )
return authenticated_user_id return authenticated_user_id
@ -94,7 +93,7 @@ async def get_current_user(api_key: str = Security(api_key_header)) -> str:
async def get_current_user_openai( async def get_current_user_openai(
authorization: Optional[str] = Header(None), authorization: Optional[str] = Header(None),
x_api_key: Optional[str] = Header(None, alias="X-API-Key"), x_api_key: Optional[str] = Header(None, alias="X-API-Key")
) -> str: ) -> str:
""" """
FastAPI dependency for OpenAI-compatible authentication. FastAPI dependency for OpenAI-compatible authentication.
@ -115,23 +114,24 @@ async def get_current_user_openai(
# Try Bearer token first (OpenAI standard) # Try Bearer token first (OpenAI standard)
if authorization and authorization.startswith("Bearer "): if authorization and authorization.startswith("Bearer "):
api_key = authorization[7:] # Remove "Bearer " prefix api_key = authorization[7:] # Remove "Bearer " prefix
logger.debug("Extracted API key from Authorization Bearer token") logger.debug(f"Extracted API key from Authorization Bearer token")
# Fall back to X-API-Key header # Fall back to X-API-Key header
elif x_api_key: elif x_api_key:
api_key = x_api_key api_key = x_api_key
logger.debug("Extracted API key from X-API-Key header") logger.debug(f"Extracted API key from X-API-Key header")
else: else:
logger.warning("No API key provided in Authorization or X-API-Key headers") logger.warning("No API key provided in Authorization or X-API-Key headers")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing API key. Provide either 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header", detail="Missing API key. Provide either 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header"
) )
return auth_service.verify_api_key(api_key) return auth_service.verify_api_key(api_key)
async def verify_user_access( async def verify_user_access(
api_key: str = Security(api_key_header), user_id: Optional[str] = None api_key: str = Security(api_key_header),
user_id: Optional[str] = None
) -> str: ) -> str:
""" """
FastAPI dependency to verify user can access the requested user_id. FastAPI dependency to verify user can access the requested user_id.
@ -152,7 +152,7 @@ async def verify_user_access(
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied: You can only access your own memories", detail="Access denied: You can only access your own memories"
) )
return authenticated_user_id return authenticated_user_id

View file

@ -11,87 +11,39 @@ class Settings(BaseSettings):
"""Application settings loaded from environment variables.""" """Application settings loaded from environment variables."""
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
env_file=".env", case_sensitive=False, extra="ignore" env_file=".env",
case_sensitive=False,
extra='ignore'
) )
# API Configuration # API Configuration
# Accept both OPENAI_API_KEY (from docker-compose) and OPENAI_COMPAT_API_KEY (from direct .env) # Accept both OPENAI_API_KEY (from docker-compose) and OPENAI_COMPAT_API_KEY (from direct .env)
openai_api_key: str = Field( openai_api_key: str = Field(validation_alias=AliasChoices('OPENAI_API_KEY', 'OPENAI_COMPAT_API_KEY', 'openai_api_key'))
validation_alias=AliasChoices( openai_base_url: str = Field(validation_alias=AliasChoices('OPENAI_BASE_URL', 'OPENAI_COMPAT_BASE_URL', 'openai_base_url'))
"OPENAI_API_KEY", "OPENAI_COMPAT_API_KEY", "openai_api_key" cohere_api_key: str = Field(validation_alias=AliasChoices('COHERE_API_KEY', 'cohere_api_key'))
)
)
openai_base_url: str = Field(
validation_alias=AliasChoices(
"OPENAI_BASE_URL", "OPENAI_COMPAT_BASE_URL", "openai_base_url"
)
)
cohere_api_key: str = Field(
validation_alias=AliasChoices("COHERE_API_KEY", "cohere_api_key")
)
# Database Configuration # Database Configuration
qdrant_host: str = Field( qdrant_host: str = Field(default="localhost", validation_alias=AliasChoices('QDRANT_HOST', 'qdrant_host'))
default="localhost", validation_alias=AliasChoices("QDRANT_HOST", "qdrant_host") qdrant_port: int = Field(default=6333, validation_alias=AliasChoices('QDRANT_PORT', 'qdrant_port'))
) qdrant_collection_name: str = Field(default="mem0", validation_alias=AliasChoices('QDRANT_COLLECTION_NAME', 'qdrant_collection_name'))
qdrant_port: int = Field(
default=6333, validation_alias=AliasChoices("QDRANT_PORT", "qdrant_port")
)
qdrant_collection_name: str = Field(
default="mem0",
validation_alias=AliasChoices(
"QDRANT_COLLECTION_NAME", "qdrant_collection_name"
),
)
# Neo4j Configuration # Neo4j Configuration
neo4j_uri: str = Field( neo4j_uri: str = Field(default="bolt://localhost:7687", validation_alias=AliasChoices('NEO4J_URI', 'neo4j_uri'))
default="bolt://localhost:7687", neo4j_username: str = Field(default="neo4j", validation_alias=AliasChoices('NEO4J_USERNAME', 'neo4j_username'))
validation_alias=AliasChoices("NEO4J_URI", "neo4j_uri"), neo4j_password: str = Field(default="mem0_neo4j_password", validation_alias=AliasChoices('NEO4J_PASSWORD', 'neo4j_password'))
)
neo4j_username: str = Field(
default="neo4j",
validation_alias=AliasChoices("NEO4J_USERNAME", "neo4j_username"),
)
neo4j_password: str = Field(
default="mem0_neo4j_password",
validation_alias=AliasChoices("NEO4J_PASSWORD", "neo4j_password"),
)
# Application Configuration # Application Configuration
log_level: str = Field( log_level: str = Field(default="INFO", validation_alias=AliasChoices('LOG_LEVEL', 'log_level'))
default="INFO", validation_alias=AliasChoices("LOG_LEVEL", "log_level") cors_origins: str = Field(default="http://localhost:3000", validation_alias=AliasChoices('CORS_ORIGINS', 'cors_origins'))
)
cors_origins: str = Field(
default="http://localhost:3000",
validation_alias=AliasChoices("CORS_ORIGINS", "cors_origins"),
)
# Model Configuration - Ultra-minimal (single model) # Model Configuration - Ultra-minimal (single model)
default_model: str = Field( default_model: str = Field(default="claude-sonnet-4", validation_alias=AliasChoices('DEFAULT_MODEL', 'default_model'))
default="claude-sonnet-4",
validation_alias=AliasChoices("DEFAULT_MODEL", "default_model"),
)
# Embedder Configuration
ollama_base_url: str = Field(
default="http://host.docker.internal:11434",
validation_alias=AliasChoices("OLLAMA_BASE_URL", "ollama_base_url"),
)
embedding_model: str = Field(
default="qwen3-embedding:4b-q8_0",
validation_alias=AliasChoices("EMBEDDING_MODEL", "embedding_model"),
)
embedding_dims: int = Field(
default=2560, validation_alias=AliasChoices("EMBEDDING_DIMS", "embedding_dims")
)
# Authentication Configuration # Authentication Configuration
# Format: JSON string mapping API keys to user IDs # Format: JSON string mapping API keys to user IDs
# Example: {"api_key_123": "alice", "api_key_456": "bob"} # Example: {"api_key_123": "alice", "api_key_456": "bob"}
api_keys: str = Field( api_keys: str = Field(default="{}", validation_alias=AliasChoices('API_KEYS', 'api_keys'))
default="{}", validation_alias=AliasChoices("API_KEYS", "api_keys")
)
@property @property
def cors_origins_list(self) -> List[str]: def cors_origins_list(self) -> List[str]:

File diff suppressed because it is too large Load diff

View file

@ -1,240 +0,0 @@
"""MCP Server for Mem0 Memory Service.
Exposes memory operations as MCP tools over HTTP with API key authentication.
"""
import contextlib
import logging
from contextvars import ContextVar
from typing import Optional
from pydantic import Field
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse
from starlette.routing import Mount
from mcp.server.fastmcp import FastMCP
from config import settings
logger = logging.getLogger(__name__)
# Context variable for authenticated user_id
# Set by middleware, read by tools
current_user_id: ContextVar[str] = ContextVar("current_user_id", default="")
def get_authenticated_user() -> str:
"""Get the authenticated user_id from context.
Raises:
ValueError: If no authenticated user in context.
"""
user_id = current_user_id.get()
if not user_id:
raise ValueError("No authenticated user in context")
return user_id
class MCPAuthMiddleware(BaseHTTPMiddleware):
"""Middleware to authenticate MCP requests using API key."""
async def dispatch(self, request, call_next):
# Extract API key from headers
api_key = (
request.headers.get("x-api-key") or
request.headers.get("X-API-Key") or
request.headers.get("Authorization", "").replace("Bearer ", "")
)
if not api_key:
return JSONResponse(
{"error": "Missing API key. Provide x-api-key or Authorization header."},
status_code=401
)
# Map API key to user_id
user_id = settings.api_key_mapping.get(api_key)
if not user_id:
return JSONResponse(
{"error": "Invalid API key"},
status_code=401
)
# Store user_id in context for tools to access
token = current_user_id.set(user_id)
try:
response = await call_next(request)
return response
finally:
current_user_id.reset(token)
# Create FastMCP server
# streamable_http_path="/" since we mount at /mcp in main.py
mcp = FastMCP(
"Mem0 Memory Service",
stateless_http=True,
json_response=True,
streamable_http_path="/"
)
@mcp.tool()
async def add_memory(
content: str = Field(description="Content to add to memory"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
metadata: Optional[dict] = Field(default=None, description="Optional metadata to attach to the memory"),
) -> dict:
"""Add content to the user's memory.
The user_id is automatically determined from the API key authentication.
Use agent_id and run_id for multi-agent or session-based memory organization.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP add_memory: user={user_id}, agent={agent_id}, run={run_id}")
result = await mem0_manager.add_memories(
messages=[{"role": "user", "content": content}],
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
metadata=metadata
)
return result
@mcp.tool()
async def search_memory(
query: str = Field(description="Search query to find relevant memories"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier to filter memories"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier to filter memories"),
limit: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return"),
) -> dict:
"""Search the user's memories.
Returns memories most relevant to the query. The user_id is automatically
determined from API key authentication.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP search_memory: user={user_id}, query={query[:50]}..., limit={limit}")
result = await mem0_manager.search_memories(
query=query,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=limit
)
return result
@mcp.tool()
async def remove_memory(
memory_id: str = Field(description="The ID of the memory to remove"),
) -> dict:
"""Remove a specific memory by its ID.
Only memories belonging to the authenticated user can be deleted.
Verifies ownership before deletion.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP remove_memory: user={user_id}, memory_id={memory_id}")
# Verify ownership: get user's memories and check if memory_id exists
user_memories = await mem0_manager.get_user_memories(
user_id=user_id,
limit=10000 # Get all to check ownership
)
memory_ids = {m.get("id") for m in user_memories if m.get("id")}
if memory_id not in memory_ids:
raise ValueError(f"Memory '{memory_id}' not found or access denied")
result = await mem0_manager.delete_memory(memory_id=memory_id)
return result
@mcp.tool()
async def chat(
message: str = Field(description="The user's message to chat with"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
) -> str:
"""Chat with memory context.
Retrieves relevant memories based on the message, generates a response
using the configured LLM, and stores the conversation in memory.
The user_id is automatically determined from API key authentication.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP chat: user={user_id}, agent={agent_id}, message={message[:50]}...")
result = await mem0_manager.chat_with_memory(
message=message,
user_id=user_id,
agent_id=agent_id,
run_id=run_id
)
return result.get("response", "")
@contextlib.asynccontextmanager
async def mcp_lifespan():
"""Context manager for MCP session lifecycle.
Must be used in the main FastAPI lifespan since mounted app
lifespans don't run automatically.
"""
async with mcp.session_manager.run():
logger.info("MCP session manager started")
yield
logger.info("MCP session manager stopped")
def create_mcp_app() -> Starlette:
"""Create and configure the MCP Starlette application.
Returns a Starlette app with MCP endpoints, authentication middleware,
and CORS support.
Note: The MCP session manager must be started via mcp_lifespan() in
the main FastAPI lifespan, not here.
"""
# Get the StreamableHTTP app - it handles requests at "/" since we mount at /mcp
streamable_app = mcp.streamable_http_app()
# Create Starlette app with MCP routes
app = Starlette(
routes=[Mount("/", app=streamable_app)],
)
# Add authentication middleware
app.add_middleware(MCPAuthMiddleware)
# Add CORS middleware for browser clients
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False,
allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
allow_headers=["*"],
expose_headers=["Mcp-Session-Id"],
)
return app

View file

@ -5,48 +5,24 @@ from typing import Dict, List, Optional, Any
from datetime import datetime from datetime import datetime
from mem0 import Memory from mem0 import Memory
from openai import OpenAI from openai import OpenAI
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
before_sleep_log,
)
import structlog
from config import settings from config import settings
from monitoring import timed from monitoring import timed
logger = structlog.get_logger(__name__) logger = logging.getLogger(__name__)
# Retry decorator for database operations (Qdrant, Neo4j)
db_retry = retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((ConnectionError, TimeoutError, OSError)),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
# Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility # Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility
from mem0.llms.openai import OpenAILLM from mem0.llms.openai import OpenAILLM
_original_generate_response = OpenAILLM.generate_response _original_generate_response = OpenAILLM.generate_response
def patched_generate_response(self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs):
def patched_generate_response(
self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs
):
# Remove 'store' parameter as LiteLLM doesn't support it # Remove 'store' parameter as LiteLLM doesn't support it
if hasattr(self.config, "store"): if hasattr(self.config, 'store'):
self.config.store = None self.config.store = None
# Remove 'top_p' to avoid conflict with temperature for Claude models # Remove 'top_p' to avoid conflict with temperature for Claude models
if hasattr(self.config, "top_p"): if hasattr(self.config, 'top_p'):
self.config.top_p = None self.config.top_p = None
return _original_generate_response( return _original_generate_response(self, messages, response_format, tools, tool_choice, **kwargs)
self, messages, response_format, tools, tool_choice, **kwargs
)
OpenAILLM.generate_response = patched_generate_response OpenAILLM.generate_response = patched_generate_response
logger.info("Applied LiteLLM compatibility patch: disabled 'store' parameter") logger.info("Applied LiteLLM compatibility patch: disabled 'store' parameter")
@ -60,16 +36,8 @@ class Mem0Manager:
def __init__(self): def __init__(self):
# Custom endpoint configuration with graph memory enabled # Custom endpoint configuration with graph memory enabled
logger.info( logger.info("Initializing ultra-minimal Mem0Manager with custom endpoint with settings:", settings)
"Initializing Mem0Manager with custom endpoint",
model=settings.default_model,
embedding_model=settings.embedding_model,
embedding_dims=settings.embedding_dims,
qdrant_host=settings.qdrant_host,
neo4j_uri=settings.neo4j_uri,
)
config = { config = {
"version": "v1.1",
"enable_graph": True, "enable_graph": True,
"llm": { "llm": {
"provider": "openai", "provider": "openai",
@ -78,16 +46,17 @@ class Mem0Manager:
"api_key": settings.openai_api_key, "api_key": settings.openai_api_key,
"openai_base_url": settings.openai_base_url, "openai_base_url": settings.openai_base_url,
"temperature": 0.1, "temperature": 0.1,
"top_p": None, "top_p": None # Don't use top_p with Claude models
}, }
}, },
"embedder": { "embedder": {
"provider": "ollama", "provider": "ollama",
"config": { "config": {
"model": settings.embedding_model, "model": "hf.co/Qwen/Qwen3-Embedding-0.6B-GGUF:Q8_0",
"ollama_base_url": settings.ollama_base_url, # "api_key": settings.embedder_api_key,
"embedding_dims": settings.embedding_dims, "ollama_base_url": "https://models.breezehq.dev",
}, "embedding_dims": 1024
}
}, },
"vector_store": { "vector_store": {
"provider": "qdrant", "provider": "qdrant",
@ -95,39 +64,38 @@ class Mem0Manager:
"collection_name": settings.qdrant_collection_name, "collection_name": settings.qdrant_collection_name,
"host": settings.qdrant_host, "host": settings.qdrant_host,
"port": settings.qdrant_port, "port": settings.qdrant_port,
"embedding_model_dims": settings.embedding_dims, "embedding_model_dims": 1024,
"on_disk": True, "on_disk": True
}, }
}, },
"graph_store": { "graph_store": {
"provider": "neo4j", "provider": "neo4j",
"config": { "config": {
"url": settings.neo4j_uri, "url": settings.neo4j_uri,
"username": settings.neo4j_username, "username": settings.neo4j_username,
"password": settings.neo4j_password, "password": settings.neo4j_password
}, }
}, },
"reranker": { "reranker": {
"provider": "cohere", "provider": "cohere",
"config": { "config": {
"api_key": settings.cohere_api_key, "api_key": settings.cohere_api_key,
"model": "rerank-english-v3.0", "model": "rerank-english-v3.0",
"top_n": 10, "top_n": 10
}, }
}, }
} }
self.memory = Memory.from_config(config) self.memory = Memory.from_config(config)
self.openai_client = OpenAI( self.openai_client = OpenAI(
api_key=settings.openai_api_key, api_key=settings.openai_api_key,
base_url=settings.openai_base_url, base_url=settings.openai_base_url
timeout=60.0, # 60 second timeout for LLM calls
max_retries=2, # Retry failed requests up to 2 times
) )
logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint") logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")
# Pure passthrough methods - no custom logic # Pure passthrough methods - no custom logic
@db_retry
@timed("add_memories") @timed("add_memories")
async def add_memories( async def add_memories(
self, self,
@ -135,14 +103,14 @@ class Mem0Manager:
user_id: Optional[str] = "default", user_id: Optional[str] = "default",
agent_id: Optional[str] = None, agent_id: Optional[str] = None,
run_id: Optional[str] = None, run_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Add memories - simplified native Mem0 pattern (10 lines vs 45).""" """Add memories - simplified native Mem0 pattern (10 lines vs 45)."""
try: try:
# Convert ChatMessage objects to dict if needed # Convert ChatMessage objects to dict if needed
formatted_messages = [] formatted_messages = []
for msg in messages: for msg in messages:
if hasattr(msg, "dict"): if hasattr(msg, 'dict'):
formatted_messages.append(msg.dict()) formatted_messages.append(msg.dict())
else: else:
formatted_messages.append(msg) formatted_messages.append(msg)
@ -155,35 +123,26 @@ class Mem0Manager:
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"source": "chat_conversation", "source": "chat_conversation",
"message_count": len(formatted_messages), "message_count": len(formatted_messages),
"auto_generated": True, "auto_generated": True
} }
# Merge user metadata with auto metadata (user metadata takes precedence) # Merge user metadata with auto metadata (user metadata takes precedence)
enhanced_metadata = {**auto_metadata, **combined_metadata} enhanced_metadata = {**auto_metadata, **combined_metadata}
# Direct Mem0 add with enhanced metadata # Direct Mem0 add with enhanced metadata
result = self.memory.add( result = self.memory.add(formatted_messages, user_id=user_id,
formatted_messages, agent_id=agent_id, run_id=run_id,
user_id=user_id, metadata=enhanced_metadata)
agent_id=agent_id,
run_id=run_id,
metadata=enhanced_metadata,
)
return { return {
"added_memories": result if isinstance(result, list) else [result], "added_memories": result if isinstance(result, list) else [result],
"message": "Memories added successfully", "message": "Memories added successfully",
"hierarchy": { "hierarchy": {"user_id": user_id, "agent_id": agent_id, "run_id": run_id}
"user_id": user_id,
"agent_id": agent_id,
"run_id": run_id,
},
} }
except Exception as e: except Exception as e:
logger.error(f"Error adding memories: {e}") logger.error(f"Error adding memories: {e}")
raise raise e
@db_retry
@timed("search_memories") @timed("search_memories")
async def search_memories( async def search_memories(
self, self,
@ -196,79 +155,37 @@ class Mem0Manager:
# rerank: bool = False, # rerank: bool = False,
# filter_memories: bool = False, # filter_memories: bool = False,
agent_id: Optional[str] = None, agent_id: Optional[str] = None,
run_id: Optional[str] = None, run_id: Optional[str] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Search memories - native Mem0 pattern""" """Search memories - native Mem0 pattern"""
try: try:
# Minimal empty query protection for API compatibility # Minimal empty query protection for API compatibility
if not query or query.strip() == "": if not query or query.strip() == "":
return { return {"memories": [], "total_count": 0, "query": query, "note": "Empty query provided, no results returned. Use a specific query to search memories."}
"memories": [],
"total_count": 0,
"query": query,
"note": "Empty query provided, no results returned. Use a specific query to search memories.",
}
# Direct Mem0 search - trust native handling # Direct Mem0 search - trust native handling
result = self.memory.search( result = self.memory.search(query=query, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit, threshold=threshold, filters=filters)
query=query, return {"memories": result.get("results", []), "total_count": len(result.get("results", [])), "query": query}
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=limit,
threshold=threshold,
filters=filters,
)
return {
"memories": result.get("results", []),
"total_count": len(result.get("results", [])),
"query": query,
}
except Exception as e: except Exception as e:
logger.error(f"Error searching memories: {e}") logger.error(f"Error searching memories: {e}")
raise raise e
@db_retry
async def get_user_memories( async def get_user_memories(
self, self,
user_id: str, user_id: str,
limit: int = 10, limit: int = 10,
agent_id: Optional[str] = None, agent_id: Optional[str] = None,
run_id: Optional[str] = None, run_id: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None, filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Get all memories for a user - native Mem0 pattern.""" """Get all memories for a user - native Mem0 pattern."""
try: try:
# Direct Mem0 get_all call - trust native parameter handling # Direct Mem0 get_all call - trust native parameter handling
result = self.memory.get_all( result = self.memory.get_all(user_id=user_id, limit=limit, agent_id=agent_id, run_id=run_id, filters=filters)
user_id=user_id,
limit=limit,
agent_id=agent_id,
run_id=run_id,
filters=filters,
)
return result.get("results", []) return result.get("results", [])
except Exception as e: except Exception as e:
logger.error(f"Error getting user memories: {e}") logger.error(f"Error getting user memories: {e}")
raise raise e
@db_retry
async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
"""Get a single memory by ID. Returns None if not found."""
try:
result = self.memory.get(memory_id=memory_id)
return result
except Exception as e:
logger.debug(f"Memory {memory_id} not found or error: {e}")
return None
async def verify_memory_ownership(self, memory_id: str, user_id: str) -> bool:
"""Check if a memory belongs to a user. O(1) instead of O(n)."""
memory = await self.get_memory(memory_id)
if memory is None:
return False
return memory.get("user_id") == user_id
@db_retry
@timed("update_memory") @timed("update_memory")
async def update_memory( async def update_memory(
self, self,
@ -277,13 +194,15 @@ class Mem0Manager:
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Update memory - pure Mem0 passthrough.""" """Update memory - pure Mem0 passthrough."""
try: try:
result = self.memory.update(memory_id=memory_id, data=content) result = self.memory.update(
memory_id=memory_id,
data=content
)
return {"message": "Memory updated successfully", "result": result} return {"message": "Memory updated successfully", "result": result}
except Exception as e: except Exception as e:
logger.error(f"Error updating memory: {e}") logger.error(f"Error updating memory: {e}")
raise raise e
@db_retry
@timed("delete_memory") @timed("delete_memory")
async def delete_memory(self, memory_id: str) -> Dict[str, Any]: async def delete_memory(self, memory_id: str) -> Dict[str, Any]:
"""Delete memory - pure Mem0 passthrough.""" """Delete memory - pure Mem0 passthrough."""
@ -292,7 +211,7 @@ class Mem0Manager:
return {"message": "Memory deleted successfully"} return {"message": "Memory deleted successfully"}
except Exception as e: except Exception as e:
logger.error(f"Error deleting memory: {e}") logger.error(f"Error deleting memory: {e}")
raise raise e
async def delete_user_memories(self, user_id: Optional[str]) -> Dict[str, Any]: async def delete_user_memories(self, user_id: Optional[str]) -> Dict[str, Any]:
"""Delete all user memories - pure Mem0 passthrough.""" """Delete all user memories - pure Mem0 passthrough."""
@ -301,7 +220,7 @@ class Mem0Manager:
return {"message": "All user memories deleted successfully"} return {"message": "All user memories deleted successfully"}
except Exception as e: except Exception as e:
logger.error(f"Error deleting user memories: {e}") logger.error(f"Error deleting user memories: {e}")
raise raise e
async def get_memory_history(self, memory_id: str) -> Dict[str, Any]: async def get_memory_history(self, memory_id: str) -> Dict[str, Any]:
"""Get memory change history - pure Mem0 passthrough.""" """Get memory change history - pure Mem0 passthrough."""
@ -310,24 +229,22 @@ class Mem0Manager:
return { return {
"memory_id": memory_id, "memory_id": memory_id,
"history": history, "history": history,
"message": "Memory history retrieved successfully", "message": "Memory history retrieved successfully"
} }
except Exception as e: except Exception as e:
logger.error(f"Error getting memory history: {e}") logger.error(f"Error getting memory history: {e}")
raise raise e
async def get_graph_relationships(
self, async def get_graph_relationships(self, user_id: Optional[str], agent_id: Optional[str], run_id: Optional[str], limit: int = 50) -> Dict[str, Any]:
user_id: Optional[str],
agent_id: Optional[str],
run_id: Optional[str],
limit: int = 50,
) -> Dict[str, Any]:
"""Get graph relationships - using correct Mem0 get_all() method.""" """Get graph relationships - using correct Mem0 get_all() method."""
try: try:
# Use get_all() to retrieve memories with graph relationships # Use get_all() to retrieve memories with graph relationships
result = self.memory.get_all( result = self.memory.get_all(
user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=limit
) )
# Extract relationships from Mem0's response structure # Extract relationships from Mem0's response structure
@ -355,7 +272,7 @@ class Mem0Manager:
"agent_id": agent_id, "agent_id": agent_id,
"run_id": run_id, "run_id": run_id,
"total_memories": len(result.get("results", [])), "total_memories": len(result.get("results", [])),
"total_relationships": len(relationships), "total_relationships": len(relationships)
} }
except Exception as e: except Exception as e:
@ -369,7 +286,7 @@ class Mem0Manager:
"run_id": run_id, "run_id": run_id,
"total_memories": 0, "total_memories": 0,
"total_relationships": 0, "total_relationships": 0,
"error": str(e), "error": str(e)
} }
@timed("chat_with_memory") @timed("chat_with_memory")
@ -387,70 +304,49 @@ class Mem0Manager:
try: try:
total_start_time = time.time() total_start_time = time.time()
logger.info("Starting chat request", user_id=user_id) print(f"\n🚀 Starting chat request for user: {user_id}")
# Stage 1: Memory Search
search_start_time = time.time() search_start_time = time.time()
search_result = self.memory.search( search_result = self.memory.search(query=message, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=10, threshold=0.3)
query=message,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=10,
threshold=0.3,
)
relevant_memories = search_result.get("results", []) relevant_memories = search_result.get("results", [])
memories_str = "\n".join( memories_str = "\n".join(f"- {entry['memory']}" for entry in relevant_memories)
f"- {entry['memory']}" for entry in relevant_memories
)
search_time = time.time() - search_start_time search_time = time.time() - search_start_time
logger.debug( print(f"🔍 Memory search took: {search_time:.2f}s (found {len(relevant_memories)} memories)")
"Memory search completed",
search_time_s=round(search_time, 2),
memories_found=len(relevant_memories),
)
# Stage 2: Prepare LLM messages
prep_start_time = time.time() prep_start_time = time.time()
system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}" system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}"
messages = [{"role": "system", "content": system_prompt}] messages = [{"role": "system", "content": system_prompt}]
# Add conversation context if provided (last 50 messages)
if context: if context:
messages.extend(context) messages.extend(context)
logger.debug("Added context messages", context_count=len(context)) print(f"📝 Added {len(context)} context messages")
# Add current user message
messages.append({"role": "user", "content": message}) messages.append({"role": "user", "content": message})
prep_time = time.time() - prep_start_time prep_time = time.time() - prep_start_time
print(f"📋 Message preparation took: {prep_time:.3f}s")
# Stage 3: LLM Call
llm_start_time = time.time() llm_start_time = time.time()
response = self.openai_client.chat.completions.create( response = self.openai_client.chat.completions.create(model=settings.default_model, messages=messages)
model=settings.default_model, messages=messages
)
assistant_response = response.choices[0].message.content assistant_response = response.choices[0].message.content
llm_time = time.time() - llm_start_time llm_time = time.time() - llm_start_time
logger.debug( print(f"🤖 LLM call took: {llm_time:.2f}s (model: {settings.default_model})")
"LLM call completed",
llm_time_s=round(llm_time, 2),
model=settings.default_model,
)
# Stage 4: Memory Add
add_start_time = time.time() add_start_time = time.time()
memory_messages = [ memory_messages = [{"role": "user", "content": message}, {"role": "assistant", "content": assistant_response}]
{"role": "user", "content": message},
{"role": "assistant", "content": assistant_response},
]
self.memory.add(memory_messages, user_id=user_id) self.memory.add(memory_messages, user_id=user_id)
add_time = time.time() - add_start_time add_time = time.time() - add_start_time
print(f"💾 Memory add took: {add_time:.2f}s")
# Total timing summary
total_time = time.time() - total_start_time total_time = time.time() - total_start_time
logger.info( print(f"⏱️ TOTAL: {total_time:.2f}s | Search: {search_time:.2f}s | LLM: {llm_time:.2f}s | Add: {add_time:.2f}s | Prep: {prep_time:.3f}s")
"Chat request completed", print(f"📊 Breakdown: Search {(search_time/total_time)*100:.1f}% | LLM {(llm_time/total_time)*100:.1f}% | Add {(add_time/total_time)*100:.1f}%\n")
user_id=user_id,
total_time_s=round(total_time, 2),
search_time_s=round(search_time, 2),
llm_time_s=round(llm_time, 2),
add_time_s=round(add_time, 2),
memories_used=len(relevant_memories),
model=settings.default_model,
)
return { return {
"response": assistant_response, "response": assistant_response,
@ -460,22 +356,17 @@ class Mem0Manager:
"total": round(total_time, 2), "total": round(total_time, 2),
"search": round(search_time, 2), "search": round(search_time, 2),
"llm": round(llm_time, 2), "llm": round(llm_time, 2),
"add": round(add_time, 2), "add": round(add_time, 2)
}, }
} }
except Exception as e: except Exception as e:
logger.error( logger.error(f"Error in chat_with_memory: {e}")
"Error in chat_with_memory",
error=str(e),
user_id=user_id,
exc_info=True,
)
return { return {
"error": str(e), "error": str(e),
"response": "I apologize, but I encountered an error processing your request.", "response": "I apologize, but I encountered an error processing your request.",
"memories_used": 0, "memories_used": 0,
"model_used": None, "model_used": None
} }
async def health_check(self) -> Dict[str, str]: async def health_check(self) -> Dict[str, str]:

View file

@ -2,115 +2,53 @@
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import re
# Constants for input validation
MAX_MESSAGE_LENGTH = 50000 # ~12k tokens max per message
MAX_QUERY_LENGTH = 10000 # ~2.5k tokens max per query
MAX_USER_ID_LENGTH = 100 # Reasonable user ID length
MAX_MEMORY_ID_LENGTH = 100 # Memory IDs are typically UUIDs
MAX_CONTEXT_MESSAGES = 100 # Max conversation context messages
USER_ID_PATTERN = r"^[a-zA-Z0-9_\-\.@]+$" # Alphanumeric with common separators
# Request Models # Request Models
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
"""Chat message structure.""" """Chat message structure."""
role: str = Field(..., description="Message role (user, assistant, system)")
role: str = Field( content: str = Field(..., description="Message content")
..., max_length=20, description="Message role (user, assistant, system)"
)
content: str = Field(
..., max_length=MAX_MESSAGE_LENGTH, description="Message content"
)
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
"""Ultra-minimal chat request.""" """Ultra-minimal chat request."""
message: str = Field(..., description="User message")
message: str = Field(..., max_length=MAX_MESSAGE_LENGTH, description="User message") user_id: Optional[str] = Field("default", description="User identifier")
user_id: Optional[str] = Field( agent_id: Optional[str] = Field(None, description="Agent identifier")
"default", run_id: Optional[str] = Field(None, description="Run identifier")
max_length=MAX_USER_ID_LENGTH, context: Optional[List[ChatMessage]] = Field(None, description="Previous conversation context")
pattern=USER_ID_PATTERN,
description="User identifier (alphanumeric, _, -, ., @)",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
context: Optional[List[ChatMessage]] = Field(
None,
max_length=MAX_CONTEXT_MESSAGES,
description="Previous conversation context (max 100 messages)",
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata") metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
class MemoryAddRequest(BaseModel): class MemoryAddRequest(BaseModel):
"""Request to add memories with hierarchy support - open-source compatible.""" """Request to add memories with hierarchy support - open-source compatible."""
messages: List[ChatMessage] = Field(..., description="Messages to process")
messages: List[ChatMessage] = Field( user_id: Optional[str] = Field("default", description="User identifier")
..., agent_id: Optional[str] = Field(None, description="Agent identifier")
max_length=MAX_CONTEXT_MESSAGES, run_id: Optional[str] = Field(None, description="Run identifier")
description="Messages to process (max 100 messages)",
)
user_id: Optional[str] = Field(
"default",
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata") metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
class MemorySearchRequest(BaseModel): class MemorySearchRequest(BaseModel):
"""Request to search memories with hierarchy filtering.""" """Request to search memories with hierarchy filtering."""
query: str = Field(..., description="Search query")
query: str = Field(..., max_length=MAX_QUERY_LENGTH, description="Search query") user_id: Optional[str] = Field("default", description="User identifier")
user_id: Optional[str] = Field( agent_id: Optional[str] = Field(None, description="Agent identifier")
"default", run_id: Optional[str] = Field(None, description="Run identifier")
max_length=MAX_USER_ID_LENGTH, limit: int = Field(5, description="Maximum number of results")
pattern=USER_ID_PATTERN, threshold: Optional[float] = Field(None, description="Minimum relevance score")
description="User identifier",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
limit: int = Field(5, ge=1, le=100, description="Maximum number of results (1-100)")
threshold: Optional[float] = Field(
None, ge=0.0, le=1.0, description="Minimum relevance score (0-1)"
)
filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters") filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters")
# Hierarchy filters (open-source compatible)
agent_id: Optional[str] = Field(None, description="Filter by agent identifier")
run_id: Optional[str] = Field(None, description="Filter by run identifier")
class MemoryUpdateRequest(BaseModel): class MemoryUpdateRequest(BaseModel):
"""Request to update a memory.""" """Request to update a memory."""
memory_id: str = Field(..., description="Memory ID to update")
memory_id: str = Field( content: str = Field(..., description="New memory content")
..., max_length=MAX_MEMORY_ID_LENGTH, description="Memory ID to update"
)
user_id: str = Field(
...,
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier for ownership verification",
)
content: str = Field(
..., max_length=MAX_MESSAGE_LENGTH, description="New memory content"
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Updated metadata") metadata: Optional[Dict[str, Any]] = Field(None, description="Updated metadata")
@ -119,23 +57,19 @@ class MemoryUpdateRequest(BaseModel):
class MemoryItem(BaseModel): class MemoryItem(BaseModel):
"""Individual memory item.""" """Individual memory item."""
id: str = Field(..., description="Memory unique identifier") id: str = Field(..., description="Memory unique identifier")
memory: str = Field(..., description="Memory content") memory: str = Field(..., description="Memory content")
user_id: Optional[str] = Field(None, description="Associated user ID") user_id: Optional[str] = Field(None, description="Associated user ID")
agent_id: Optional[str] = Field(None, description="Associated agent ID") agent_id: Optional[str] = Field(None, description="Associated agent ID")
run_id: Optional[str] = Field(None, description="Associated run ID") run_id: Optional[str] = Field(None, description="Associated run ID")
metadata: Optional[Dict[str, Any]] = Field(None, description="Memory metadata") metadata: Optional[Dict[str, Any]] = Field(None, description="Memory metadata")
score: Optional[float] = Field( score: Optional[float] = Field(None, description="Relevance score (for search results)")
None, description="Relevance score (for search results)"
)
created_at: Optional[str] = Field(None, description="Creation timestamp") created_at: Optional[str] = Field(None, description="Creation timestamp")
updated_at: Optional[str] = Field(None, description="Last update timestamp") updated_at: Optional[str] = Field(None, description="Last update timestamp")
class MemorySearchResponse(BaseModel): class MemorySearchResponse(BaseModel):
"""Memory search results - pure Mem0 structure.""" """Memory search results - pure Mem0 structure."""
memories: List[MemoryItem] = Field(..., description="Found memories") memories: List[MemoryItem] = Field(..., description="Found memories")
total_count: int = Field(..., description="Total number of memories found") total_count: int = Field(..., description="Total number of memories found")
query: str = Field(..., description="Original search query") query: str = Field(..., description="Original search query")
@ -143,37 +77,27 @@ class MemorySearchResponse(BaseModel):
class MemoryAddResponse(BaseModel): class MemoryAddResponse(BaseModel):
"""Response from adding memories - pure Mem0 structure.""" """Response from adding memories - pure Mem0 structure."""
added_memories: List[Dict[str, Any]] = Field(..., description="Memories that were added")
added_memories: List[Dict[str, Any]] = Field(
..., description="Memories that were added"
)
message: str = Field(..., description="Success message") message: str = Field(..., description="Success message")
class GraphRelationship(BaseModel): class GraphRelationship(BaseModel):
"""Graph relationship structure.""" """Graph relationship structure."""
source: str = Field(..., description="Source entity") source: str = Field(..., description="Source entity")
relationship: str = Field(..., description="Relationship type") relationship: str = Field(..., description="Relationship type")
target: str = Field(..., description="Target entity") target: str = Field(..., description="Target entity")
properties: Optional[Dict[str, Any]] = Field( properties: Optional[Dict[str, Any]] = Field(None, description="Relationship properties")
None, description="Relationship properties"
)
class GraphResponse(BaseModel): class GraphResponse(BaseModel):
"""Graph relationships - pure Mem0 structure.""" """Graph relationships - pure Mem0 structure."""
relationships: List[GraphRelationship] = Field(..., description="Found relationships")
relationships: List[GraphRelationship] = Field(
..., description="Found relationships"
)
entities: List[str] = Field(..., description="Unique entities") entities: List[str] = Field(..., description="Unique entities")
user_id: str = Field(..., description="User identifier") user_id: str = Field(..., description="User identifier")
class HealthResponse(BaseModel): class HealthResponse(BaseModel):
"""Health check response.""" """Health check response."""
status: str = Field(..., description="Service status") status: str = Field(..., description="Service status")
services: Dict[str, str] = Field(..., description="Individual service statuses") services: Dict[str, str] = Field(..., description="Individual service statuses")
timestamp: str = Field(..., description="Health check timestamp") timestamp: str = Field(..., description="Health check timestamp")
@ -181,7 +105,6 @@ class HealthResponse(BaseModel):
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
"""Error response structure.""" """Error response structure."""
error: str = Field(..., description="Error message") error: str = Field(..., description="Error message")
detail: Optional[str] = Field(None, description="Detailed error information") detail: Optional[str] = Field(None, description="Detailed error information")
status_code: int = Field(..., description="HTTP status code") status_code: int = Field(..., description="HTTP status code")
@ -189,10 +112,8 @@ class ErrorResponse(BaseModel):
# Statistics and Monitoring Models # Statistics and Monitoring Models
class MemoryOperationStats(BaseModel): class MemoryOperationStats(BaseModel):
"""Memory operation statistics.""" """Memory operation statistics."""
add: int = Field(..., description="Number of add operations") add: int = Field(..., description="Number of add operations")
search: int = Field(..., description="Number of search operations") search: int = Field(..., description="Number of search operations")
update: int = Field(..., description="Number of update operations") update: int = Field(..., description="Number of update operations")
@ -201,47 +122,34 @@ class MemoryOperationStats(BaseModel):
class GlobalStatsResponse(BaseModel): class GlobalStatsResponse(BaseModel):
"""Global application statistics.""" """Global application statistics."""
total_memories: int = Field(..., description="Total memories across all users") total_memories: int = Field(..., description="Total memories across all users")
total_users: int = Field(..., description="Total number of users") total_users: int = Field(..., description="Total number of users")
api_calls_today: int = Field(..., description="Total API calls today") api_calls_today: int = Field(..., description="Total API calls today")
avg_response_time_ms: float = Field( avg_response_time_ms: float = Field(..., description="Average response time in milliseconds")
..., description="Average response time in milliseconds" memory_operations: MemoryOperationStats = Field(..., description="Memory operation breakdown")
)
memory_operations: MemoryOperationStats = Field(
..., description="Memory operation breakdown"
)
uptime_seconds: float = Field(..., description="Application uptime in seconds") uptime_seconds: float = Field(..., description="Application uptime in seconds")
class UserStatsResponse(BaseModel): class UserStatsResponse(BaseModel):
"""User-specific statistics.""" """User-specific statistics."""
user_id: str = Field(..., description="User identifier") user_id: str = Field(..., description="User identifier")
memory_count: int = Field(..., description="Number of memories for this user") memory_count: int = Field(..., description="Number of memories for this user")
relationship_count: int = Field( relationship_count: int = Field(..., description="Number of graph relationships for this user")
..., description="Number of graph relationships for this user"
)
last_activity: Optional[str] = Field(None, description="Last activity timestamp") last_activity: Optional[str] = Field(None, description="Last activity timestamp")
api_calls_today: int = Field(..., description="API calls made by this user today") api_calls_today: int = Field(..., description="API calls made by this user today")
avg_response_time_ms: float = Field( avg_response_time_ms: float = Field(..., description="Average response time for this user's requests")
..., description="Average response time for this user's requests"
)
# OpenAI-Compatible API Models # OpenAI-Compatible API Models
class OpenAIMessage(BaseModel): class OpenAIMessage(BaseModel):
"""OpenAI message format.""" """OpenAI message format."""
role: str = Field(..., description="Message role (system, user, assistant)") role: str = Field(..., description="Message role (system, user, assistant)")
content: str = Field(..., description="Message content") content: str = Field(..., description="Message content")
class OpenAIChatCompletionRequest(BaseModel): class OpenAIChatCompletionRequest(BaseModel):
"""OpenAI chat completion request format.""" """OpenAI chat completion request format."""
model: str = Field(..., description="Model to use (will use configured default)") model: str = Field(..., description="Model to use (will use configured default)")
messages: List[Dict[str, str]] = Field(..., description="List of messages") messages: List[Dict[str, str]] = Field(..., description="List of messages")
temperature: Optional[float] = Field(0.7, description="Sampling temperature") temperature: Optional[float] = Field(0.7, description="Sampling temperature")
@ -252,14 +160,11 @@ class OpenAIChatCompletionRequest(BaseModel):
stop: Optional[List[str]] = Field(None, description="Stop sequences") stop: Optional[List[str]] = Field(None, description="Stop sequences")
presence_penalty: Optional[float] = Field(0, description="Presence penalty") presence_penalty: Optional[float] = Field(0, description="Presence penalty")
frequency_penalty: Optional[float] = Field(0, description="Frequency penalty") frequency_penalty: Optional[float] = Field(0, description="Frequency penalty")
user: Optional[str] = Field( user: Optional[str] = Field(None, description="User identifier (ignored, uses API key)")
None, description="User identifier (ignored, uses API key)"
)
class OpenAIUsage(BaseModel): class OpenAIUsage(BaseModel):
"""Token usage information.""" """Token usage information."""
prompt_tokens: int = Field(..., description="Tokens in the prompt") prompt_tokens: int = Field(..., description="Tokens in the prompt")
completion_tokens: int = Field(..., description="Tokens in the completion") completion_tokens: int = Field(..., description="Tokens in the completion")
total_tokens: int = Field(..., description="Total tokens used") total_tokens: int = Field(..., description="Total tokens used")
@ -267,14 +172,12 @@ class OpenAIUsage(BaseModel):
class OpenAIChoiceMessage(BaseModel): class OpenAIChoiceMessage(BaseModel):
"""Message in a choice.""" """Message in a choice."""
role: str = Field(..., description="Role of the message") role: str = Field(..., description="Role of the message")
content: str = Field(..., description="Content of the message") content: str = Field(..., description="Content of the message")
class OpenAIChoice(BaseModel): class OpenAIChoice(BaseModel):
"""Individual completion choice.""" """Individual completion choice."""
index: int = Field(..., description="Choice index") index: int = Field(..., description="Choice index")
message: OpenAIChoiceMessage = Field(..., description="Message content") message: OpenAIChoiceMessage = Field(..., description="Message content")
finish_reason: str = Field(..., description="Reason for completion finish") finish_reason: str = Field(..., description="Reason for completion finish")
@ -282,7 +185,6 @@ class OpenAIChoice(BaseModel):
class OpenAIChatCompletionResponse(BaseModel): class OpenAIChatCompletionResponse(BaseModel):
"""OpenAI chat completion response format.""" """OpenAI chat completion response format."""
id: str = Field(..., description="Unique completion ID") id: str = Field(..., description="Unique completion ID")
object: str = Field(default="chat.completion", description="Object type") object: str = Field(default="chat.completion", description="Object type")
created: int = Field(..., description="Unix timestamp of creation") created: int = Field(..., description="Unix timestamp of creation")
@ -293,19 +195,14 @@ class OpenAIChatCompletionResponse(BaseModel):
# Streaming-specific models # Streaming-specific models
class OpenAIStreamDelta(BaseModel): class OpenAIStreamDelta(BaseModel):
"""Delta content in a streaming chunk.""" """Delta content in a streaming chunk."""
role: Optional[str] = Field(None, description="Role (only in first chunk)") role: Optional[str] = Field(None, description="Role (only in first chunk)")
content: Optional[str] = Field(None, description="Incremental content") content: Optional[str] = Field(None, description="Incremental content")
class OpenAIStreamChoice(BaseModel): class OpenAIStreamChoice(BaseModel):
"""Individual streaming choice.""" """Individual streaming choice."""
index: int = Field(..., description="Choice index") index: int = Field(..., description="Choice index")
delta: OpenAIStreamDelta = Field(..., description="Delta content") delta: OpenAIStreamDelta = Field(..., description="Delta content")
finish_reason: Optional[str] = Field( finish_reason: Optional[str] = Field(None, description="Reason for completion finish")
None, description="Reason for completion finish"
)

View file

@ -19,7 +19,6 @@ ollama
# Utilities # Utilities
pydantic pydantic
pydantic-settings pydantic-settings
tenacity
python-dotenv python-dotenv
httpx httpx
aiofiles aiofiles
@ -32,9 +31,3 @@ python-json-logger
# CORS and Security # CORS and Security
python-jose[cryptography] python-jose[cryptography]
passlib[bcrypt] passlib[bcrypt]
# Rate Limiting
slowapi
# MCP Server
mcp[server]>=1.0.0

View file

@ -5,8 +5,6 @@ services:
container_name: mem0-qdrant container_name: mem0-qdrant
expose: expose:
- "6333" - "6333"
networks:
- mem0_network
volumes: volumes:
- qdrant_data:/qdrant/storage - qdrant_data:/qdrant/storage
command: > command: >
@ -17,6 +15,8 @@ services:
timeout: 5s timeout: 5s
retries: 5 retries: 5
restart: unless-stopped restart: unless-stopped
networks:
- mem0_network
# Neo4j with APOC for graph relationships # Neo4j with APOC for graph relationships
neo4j: neo4j:
@ -51,8 +51,8 @@ services:
# Backend API service # Backend API service
backend: backend:
build: build:
context: ./backend context: .
dockerfile: Dockerfile dockerfile: ./backend/Dockerfile
container_name: mem0-backend container_name: mem0-backend
environment: environment:
OPENAI_API_KEY: ${OPENAI_COMPAT_API_KEY} OPENAI_API_KEY: ${OPENAI_COMPAT_API_KEY}
@ -69,11 +69,8 @@ services:
CORS_ORIGINS: ${CORS_ORIGINS:-http://localhost:3000} CORS_ORIGINS: ${CORS_ORIGINS:-http://localhost:3000}
DEFAULT_MODEL: ${DEFAULT_MODEL:-claude-sonnet-4} DEFAULT_MODEL: ${DEFAULT_MODEL:-claude-sonnet-4}
API_KEYS: ${API_KEYS:-{}} API_KEYS: ${API_KEYS:-{}}
OLLAMA_BASE_URL: ${OLLAMA_BASE_URL:-http://host.docker.internal:11434}
EMBEDDING_MODEL: ${EMBEDDING_MODEL:-nomic-embed-text}
EMBEDDING_DIMS: ${EMBEDDING_DIMS:-2560}
expose: expose:
- "8000" - 8000
networks: networks:
- npm_network - npm_network
- mem0_network - mem0_network

122
setup.sh
View file

@ -1,122 +0,0 @@
#!/bin/bash
set -e
COMPOSE_PROJECT_NAME="${COMPOSE_PROJECT_NAME:-mem0}"
VOLUMES=("qdrant_data" "neo4j_data" "neo4j_logs" "neo4j_import" "neo4j_plugins")
echo "=========================================="
echo " Mem0 Setup Script"
echo "=========================================="
echo ""
check_volumes_exist() {
local exists=false
for vol in "${VOLUMES[@]}"; do
full_name="${COMPOSE_PROJECT_NAME}_${vol}"
if docker volume ls -q | grep -q "^${full_name}$"; then
exists=true
break
fi
done
echo "$exists"
}
list_existing_volumes() {
echo "Existing volumes:"
for vol in "${VOLUMES[@]}"; do
full_name="${COMPOSE_PROJECT_NAME}_${vol}"
if docker volume ls -q | grep -q "^${full_name}$"; then
size=$(docker system df -v 2>/dev/null | grep "$full_name" | awk '{print $4}' || echo "unknown")
echo " - $full_name ($size)"
fi
done
}
remove_volumes() {
echo "Stopping containers..."
docker compose down 2>/dev/null || true
echo "Removing volumes..."
for vol in "${VOLUMES[@]}"; do
full_name="${COMPOSE_PROJECT_NAME}_${vol}"
if docker volume ls -q | grep -q "^${full_name}$"; then
echo " Removing $full_name..."
docker volume rm "$full_name" 2>/dev/null || true
fi
done
echo "Volumes removed."
}
build_and_start() {
echo ""
echo "Building containers..."
docker compose build
echo ""
echo "Starting services..."
docker compose up -d
echo ""
echo "Waiting for services to be healthy..."
sleep 5
echo ""
echo "Checking health..."
curl -s http://localhost:8000/health 2>/dev/null | jq . || echo "Health check not available yet. Services may still be starting."
echo ""
echo "=========================================="
echo " Setup Complete!"
echo "=========================================="
echo ""
echo "Services:"
echo " - Backend API: http://localhost:8000"
echo " - API Docs: http://localhost:8000/docs"
echo " - Health Check: http://localhost:8000/health"
echo ""
echo "Logs: docker compose logs -f backend"
}
if [ "$(check_volumes_exist)" = "true" ]; then
echo "Existing data volumes detected!"
echo ""
list_existing_volumes
echo ""
echo "Options:"
echo " 1) Keep existing data and start services"
echo " 2) Reset everything (DELETE ALL DATA) and start fresh"
echo " 3) Exit"
echo ""
read -p "Choose an option [1/2/3]: " choice
case $choice in
1)
echo ""
echo "Keeping existing data..."
build_and_start
;;
2)
echo ""
read -p "Are you sure you want to DELETE ALL DATA? Type 'yes' to confirm: " confirm
if [ "$confirm" = "yes" ]; then
remove_volumes
build_and_start
else
echo "Aborted."
exit 1
fi
;;
3)
echo "Exiting."
exit 0
;;
*)
echo "Invalid option. Exiting."
exit 1
;;
esac
else
echo "No existing volumes found. Starting fresh setup..."
build_and_start
fi

View file

@ -19,24 +19,13 @@ import time
BASE_URL = "http://localhost:8000" BASE_URL = "http://localhost:8000"
TEST_USER = f"test_user_{int(datetime.now().timestamp())}" TEST_USER = f"test_user_{int(datetime.now().timestamp())}"
# API Key for authentication - set via environment or use default test key
import os
API_KEY = os.environ.get("MEM0_API_KEY", "test-api-key")
AUTH_HEADERS = {"X-API-Key": API_KEY}
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Mem0 Integration Tests - Real API Testing (Zero Mocking)", description="Mem0 Integration Tests - Real API Testing (Zero Mocking)",
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="Show detailed output and API responses",
) )
parser.add_argument("--verbose", "-v", action="store_true",
help="Show detailed output and API responses")
args = parser.parse_args() args = parser.parse_args()
verbose = args.verbose verbose = args.verbose
@ -50,9 +39,6 @@ def main():
# Test sequence - order matters for data dependencies # Test sequence - order matters for data dependencies
tests = [ tests = [
test_health_check, test_health_check,
test_auth_required_endpoints,
test_ownership_verification,
test_request_size_limit,
test_empty_search_protection, test_empty_search_protection,
test_add_memories_with_hierarchy, test_add_memories_with_hierarchy,
test_search_memories_basic, test_search_memories_basic,
@ -65,7 +51,7 @@ def main():
test_graph_relationships, test_graph_relationships,
test_delete_specific_memory, test_delete_specific_memory,
test_delete_all_user_memories, test_delete_all_user_memories,
test_cleanup_verification, test_cleanup_verification
] ]
results = [] results = []
@ -96,7 +82,6 @@ def main():
print("❌ Some tests failed! Check the output above.") print("❌ Some tests failed! Check the output above.")
sys.exit(1) sys.exit(1)
def run_test(name, test_func, verbose): def run_test(name, test_func, verbose):
"""Run a single test with error handling""" """Run a single test with error handling"""
try: try:
@ -117,7 +102,6 @@ def run_test(name, test_func, verbose):
print(f"{name}: {e}") print(f"{name}: {e}")
return False return False
def log_response(response, verbose, context=""): def log_response(response, verbose, context=""):
"""Log API response details if verbose""" """Log API response details if verbose"""
if verbose: if verbose:
@ -127,30 +111,22 @@ def log_response(response, verbose, context=""):
if isinstance(data, dict) and len(data) < 5: if isinstance(data, dict) and len(data) < 5:
print(f" {context} Response: {data}") print(f" {context} Response: {data}")
else: else:
print( print(f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}")
f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}"
)
except: except:
print(f" {context} Response: {response.text[:100]}...") print(f" {context} Response: {response.text[:100]}...")
# ================== TEST FUNCTIONS ================== # ================== TEST FUNCTIONS ==================
def test_health_check(verbose): def test_health_check(verbose):
"""Test service health endpoint""" """Test service health endpoint"""
response = requests.get( response = requests.get(f"{BASE_URL}/health", timeout=10)
f"{BASE_URL}/health", timeout=10
) # Health doesn't require auth
log_response(response, verbose, "Health") log_response(response, verbose, "Health")
assert response.status_code == 200, f"Expected 200, got {response.status_code}" assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json() data = response.json()
assert "status" in data, "Health response missing 'status' field" assert "status" in data, "Health response missing 'status' field"
assert data["status"] in ["healthy", "degraded"], ( assert data["status"] in ["healthy", "degraded"], f"Invalid status: {data['status']}"
f"Invalid status: {data['status']}"
)
# Check individual services # Check individual services
assert "services" in data, "Health response missing 'services' field" assert "services" in data, "Health response missing 'services' field"
@ -160,19 +136,18 @@ def test_health_check(verbose):
for service, status in data["services"].items(): for service, status in data["services"].items():
print(f" {service}: {status}") print(f" {service}: {status}")
def test_empty_search_protection(verbose): def test_empty_search_protection(verbose):
"""Test empty query protection (should not return 500 error)""" """Test empty query protection (should not return 500 error)"""
payload = {"query": "", "user_id": TEST_USER, "limit": 5} payload = {
"query": "",
"user_id": TEST_USER,
"limit": 5
}
response = requests.post( response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=10)
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Empty Search") log_response(response, verbose, "Empty Search")
assert response.status_code == 200, ( assert response.status_code == 200, f"Empty query failed with {response.status_code}"
f"Empty query failed with {response.status_code}"
)
data = response.json() data = response.json()
assert data["memories"] == [], "Empty query should return empty memories list" assert data["memories"] == [], "Empty query should return empty memories list"
@ -183,39 +158,25 @@ def test_empty_search_protection(verbose):
print(f" Empty search note: {data['note']}") print(f" Empty search note: {data['note']}")
print(f" Total count: {data.get('total_count', 0)}") print(f" Total count: {data.get('total_count', 0)}")
def test_add_memories_with_hierarchy(verbose): def test_add_memories_with_hierarchy(verbose):
"""Test adding memories with multi-level hierarchy support""" """Test adding memories with multi-level hierarchy support"""
payload = { payload = {
"messages": [ "messages": [
{ {"role": "user", "content": "I work at TechCorp as a Senior Software Engineer"},
"role": "user", {"role": "user", "content": "My colleague Sarah from Marketing team helped with Q3 presentation"},
"content": "I work at TechCorp as a Senior Software Engineer", {"role": "user", "content": "Meeting with John the Product Manager tomorrow about new feature development"}
},
{
"role": "user",
"content": "My colleague Sarah from Marketing team helped with Q3 presentation",
},
{
"role": "user",
"content": "Meeting with John the Product Manager tomorrow about new feature development",
},
], ],
"user_id": TEST_USER, "user_id": TEST_USER,
"agent_id": "test_agent", "agent_id": "test_agent",
"run_id": "test_run_001", "run_id": "test_run_001",
"session_id": "test_session_001", "session_id": "test_session_001",
"metadata": {"test": "integration", "scenario": "work_context"}, "metadata": {"test": "integration", "scenario": "work_context"}
} }
response = requests.post( response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60)
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60
)
log_response(response, verbose, "Add Memories") log_response(response, verbose, "Add Memories")
assert response.status_code == 200, ( assert response.status_code == 200, f"Add memories failed with {response.status_code}"
f"Add memories failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "added_memories" in data, "Response missing 'added_memories'" assert "added_memories" in data, "Response missing 'added_memories'"
@ -230,26 +191,23 @@ def test_add_memories_with_hierarchy(verbose):
relations = first_memory["relations"] relations = first_memory["relations"]
if "added_entities" in relations and relations["added_entities"]: if "added_entities" in relations and relations["added_entities"]:
if verbose: if verbose:
print( print(f" Graph extracted: {len(relations['added_entities'])} relationships")
f" Graph extracted: {len(relations['added_entities'])} relationships"
)
print(f" Sample relations: {relations['added_entities'][:3]}") print(f" Sample relations: {relations['added_entities'][:3]}")
if verbose: if verbose:
print(f" Added {len(memories)} memory blocks") print(f" Added {len(memories)} memory blocks")
print( print(f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001")
f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001"
)
def test_search_memories_basic(verbose): def test_search_memories_basic(verbose):
"""Test basic memory search functionality""" """Test basic memory search functionality"""
# Test meaningful search # Test meaningful search
payload = {"query": "TechCorp", "user_id": TEST_USER, "limit": 10} payload = {
"query": "TechCorp",
"user_id": TEST_USER,
"limit": 10
}
response = requests.post( response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15)
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Search") log_response(response, verbose, "Search")
assert response.status_code == 200, f"Search failed with {response.status_code}" assert response.status_code == 200, f"Search failed with {response.status_code}"
@ -274,7 +232,6 @@ def test_search_memories_basic(verbose):
print(f" Found {data['total_count']} memories") print(f" Found {data['total_count']} memories")
print(f" First memory: {memory['memory'][:50]}...") print(f" First memory: {memory['memory'][:50]}...")
def test_search_memories_hierarchy_filters(verbose): def test_search_memories_hierarchy_filters(verbose):
"""Test multi-level hierarchy filtering in search""" """Test multi-level hierarchy filtering in search"""
# Test with hierarchy filters # Test with hierarchy filters
@ -284,17 +241,13 @@ def test_search_memories_hierarchy_filters(verbose):
"agent_id": "test_agent", "agent_id": "test_agent",
"run_id": "test_run_001", "run_id": "test_run_001",
"session_id": "test_session_001", "session_id": "test_session_001",
"limit": 10, "limit": 10
} }
response = requests.post( response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15)
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Hierarchy Search") log_response(response, verbose, "Hierarchy Search")
assert response.status_code == 200, ( assert response.status_code == 200, f"Hierarchy search failed with {response.status_code}"
f"Hierarchy search failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "memories" in data, "Hierarchy search response missing 'memories'" assert "memories" in data, "Hierarchy search response missing 'memories'"
@ -304,10 +257,7 @@ def test_search_memories_hierarchy_filters(verbose):
if verbose: if verbose:
print(f" Found {len(data['memories'])} memories with hierarchy filters") print(f" Found {len(data['memories'])} memories with hierarchy filters")
print( print(f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001")
f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001"
)
def test_get_user_memories_with_hierarchy(verbose): def test_get_user_memories_with_hierarchy(verbose):
"""Test retrieving user memories with hierarchy filtering""" """Test retrieving user memories with hierarchy filtering"""
@ -316,20 +266,13 @@ def test_get_user_memories_with_hierarchy(verbose):
"limit": 20, "limit": 20,
"agent_id": "test_agent", "agent_id": "test_agent",
"run_id": "test_run_001", "run_id": "test_run_001",
"session_id": "test_session_001", "session_id": "test_session_001"
} }
response = requests.get( response = requests.get(f"{BASE_URL}/memories/{TEST_USER}", params=params, timeout=15)
f"{BASE_URL}/memories/{TEST_USER}",
params=params,
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Get User Memories with Hierarchy") log_response(response, verbose, "Get User Memories with Hierarchy")
assert response.status_code == 200, ( assert response.status_code == 200, f"Get user memories with hierarchy failed with {response.status_code}"
f"Get user memories with hierarchy failed with {response.status_code}"
)
memories = response.json() memories = response.json()
assert isinstance(memories, list), "User memories should return a list" assert isinstance(memories, list), "User memories should return a list"
@ -347,13 +290,10 @@ def test_get_user_memories_with_hierarchy(verbose):
if verbose: if verbose:
print(" No memories found with hierarchy filters (may be expected)") print(" No memories found with hierarchy filters (may be expected)")
def test_memory_history(verbose): def test_memory_history(verbose):
"""Test memory history endpoint""" """Test memory history endpoint"""
# First get a memory to check history for # First get a memory to check history for
response = requests.get( response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10)
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for history test" assert response.status_code == 200, "Failed to get memory for history test"
memories = response.json() memories = response.json()
@ -365,38 +305,27 @@ def test_memory_history(verbose):
memory_id = memories[0]["id"] memory_id = memories[0]["id"]
# Test memory history endpoint # Test memory history endpoint
response = requests.get( response = requests.get(f"{BASE_URL}/memories/{memory_id}/history", timeout=15)
f"{BASE_URL}/memories/{memory_id}/history?user_id={TEST_USER}",
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Memory History") log_response(response, verbose, "Memory History")
assert response.status_code == 200, ( assert response.status_code == 200, f"Memory history failed with {response.status_code}"
f"Memory history failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "memory_id" in data, "History response missing 'memory_id'" assert "memory_id" in data, "History response missing 'memory_id'"
assert "history" in data, "History response missing 'history'" assert "history" in data, "History response missing 'history'"
assert "message" in data, "History response missing success message" assert "message" in data, "History response missing success message"
assert data["memory_id"] == memory_id, ( assert data["memory_id"] == memory_id, f"Wrong memory_id in response: {data['memory_id']}"
f"Wrong memory_id in response: {data['memory_id']}"
)
if verbose: if verbose:
print(f" Retrieved history for memory {memory_id}") print(f" Retrieved history for memory {memory_id}")
print( print(f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}")
f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}"
)
def test_update_memory(verbose): def test_update_memory(verbose):
"""Test updating a specific memory""" """Test updating a specific memory"""
# First get a memory to update # First get a memory to update
response = requests.get( response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10)
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for update test" assert response.status_code == 200, "Failed to get memory for update test"
memories = response.json() memories = response.json()
@ -408,13 +337,10 @@ def test_update_memory(verbose):
# Update the memory # Update the memory
payload = { payload = {
"memory_id": memory_id, "memory_id": memory_id,
"user_id": TEST_USER, "content": f"UPDATED: {original_content}"
"content": f"UPDATED: {original_content}",
} }
response = requests.put( response = requests.put(f"{BASE_URL}/memories", json=payload, timeout=10)
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Update") log_response(response, verbose, "Update")
assert response.status_code == 200, f"Update failed with {response.status_code}" assert response.status_code == 200, f"Update failed with {response.status_code}"
@ -426,15 +352,15 @@ def test_update_memory(verbose):
print(f" Updated memory {memory_id}") print(f" Updated memory {memory_id}")
print(f" Original: {original_content[:30]}...") print(f" Original: {original_content[:30]}...")
def test_chat_with_memory(verbose): def test_chat_with_memory(verbose):
"""Test memory-enhanced chat functionality""" """Test memory-enhanced chat functionality"""
payload = {"message": "What company do I work for?", "user_id": TEST_USER} payload = {
"message": "What company do I work for?",
"user_id": TEST_USER
}
try: try:
response = requests.post( response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=90)
f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=90
)
log_response(response, verbose, "Chat") log_response(response, verbose, "Chat")
assert response.status_code == 200, f"Chat failed with {response.status_code}" assert response.status_code == 200, f"Chat failed with {response.status_code}"
@ -457,15 +383,12 @@ def test_chat_with_memory(verbose):
print(" Chat endpoint timed out (LLM API may be slow)") print(" Chat endpoint timed out (LLM API may be slow)")
# Still test that the endpoint exists and accepts requests # Still test that the endpoint exists and accepts requests
try: try:
response = requests.post( response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=5)
f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=5
)
except requests.exceptions.ReadTimeout: except requests.exceptions.ReadTimeout:
# This is expected - endpoint exists but processing is slow # This is expected - endpoint exists but processing is slow
if verbose: if verbose:
print(" Chat endpoint confirmed active (processing timeout expected)") print(" Chat endpoint confirmed active (processing timeout expected)")
def test_graph_relationships_creation(verbose): def test_graph_relationships_creation(verbose):
"""Test graph relationships creation with entity-rich memories""" """Test graph relationships creation with entity-rich memories"""
# Create a separate test user for graph relationship testing # Create a separate test user for graph relationship testing
@ -474,67 +397,41 @@ def test_graph_relationships_creation(verbose):
# Add memories with clear entity relationships # Add memories with clear entity relationships
payload = { payload = {
"messages": [ "messages": [
{ {"role": "user", "content": "John Smith works at Microsoft as a Senior Software Engineer"},
"role": "user", {"role": "user", "content": "John Smith is friends with Sarah Johnson who works at Google"},
"content": "John Smith works at Microsoft as a Senior Software Engineer", {"role": "user", "content": "Sarah Johnson lives in Seattle and loves hiking"},
},
{
"role": "user",
"content": "John Smith is friends with Sarah Johnson who works at Google",
},
{
"role": "user",
"content": "Sarah Johnson lives in Seattle and loves hiking",
},
{"role": "user", "content": "Microsoft is located in Redmond, Washington"}, {"role": "user", "content": "Microsoft is located in Redmond, Washington"},
{ {"role": "user", "content": "John Smith and Sarah Johnson both graduated from Stanford University"}
"role": "user",
"content": "John Smith and Sarah Johnson both graduated from Stanford University",
},
], ],
"user_id": graph_test_user, "user_id": graph_test_user,
"metadata": {"test": "graph_relationships", "scenario": "entity_creation"}, "metadata": {"test": "graph_relationships", "scenario": "entity_creation"}
} }
response = requests.post( response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60)
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60
)
log_response(response, verbose, "Add Graph Memories") log_response(response, verbose, "Add Graph Memories")
assert response.status_code == 200, ( assert response.status_code == 200, f"Add graph memories failed with {response.status_code}"
f"Add graph memories failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "added_memories" in data, "Response missing 'added_memories'" assert "added_memories" in data, "Response missing 'added_memories'"
if verbose: if verbose:
print( print(f" Added {len(data['added_memories'])} memories for graph relationship testing")
f" Added {len(data['added_memories'])} memories for graph relationship testing"
)
# Wait a moment for graph processing (Mem0 graph extraction can be async) # Wait a moment for graph processing (Mem0 graph extraction can be async)
time.sleep(2) time.sleep(2)
# Test graph relationships endpoint # Test graph relationships endpoint
response = requests.get( response = requests.get(f"{BASE_URL}/graph/relationships/{graph_test_user}", timeout=15)
f"{BASE_URL}/graph/relationships/{graph_test_user}",
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Graph Relationships") log_response(response, verbose, "Graph Relationships")
assert response.status_code == 200, ( assert response.status_code == 200, f"Graph relationships failed with {response.status_code}"
f"Graph relationships failed with {response.status_code}"
)
graph_data = response.json() graph_data = response.json()
assert "relationships" in graph_data, "Graph response missing 'relationships'" assert "relationships" in graph_data, "Graph response missing 'relationships'"
assert "entities" in graph_data, "Graph response missing 'entities'" assert "entities" in graph_data, "Graph response missing 'entities'"
assert "user_id" in graph_data, "Graph response missing 'user_id'" assert "user_id" in graph_data, "Graph response missing 'user_id'"
assert graph_data["user_id"] == graph_test_user, ( assert graph_data["user_id"] == graph_test_user, f"Wrong user_id in graph: {graph_data['user_id']}"
f"Wrong user_id in graph: {graph_data['user_id']}"
)
relationships = graph_data["relationships"] relationships = graph_data["relationships"]
entities = graph_data["entities"] entities = graph_data["entities"]
@ -554,24 +451,16 @@ def test_graph_relationships_creation(verbose):
# Print sample entities if they exist # Print sample entities if they exist
if entities: if entities:
print( print(f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}")
f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}"
)
# Verify relationship structure (if relationships exist) # Verify relationship structure (if relationships exist)
for rel in relationships: for rel in relationships:
assert "source" in rel or "from" in rel, ( assert "source" in rel or "from" in rel, f"Relationship missing source/from: {rel}"
f"Relationship missing source/from: {rel}"
)
assert "target" in rel or "to" in rel, f"Relationship missing target/to: {rel}" assert "target" in rel or "to" in rel, f"Relationship missing target/to: {rel}"
assert "relationship" in rel or "type" in rel, ( assert "relationship" in rel or "type" in rel, f"Relationship missing type: {rel}"
f"Relationship missing type: {rel}"
)
# Clean up graph test user memories # Clean up graph test user memories
cleanup_response = requests.delete( cleanup_response = requests.delete(f"{BASE_URL}/memories/user/{graph_test_user}", timeout=15)
f"{BASE_URL}/memories/user/{graph_test_user}", headers=AUTH_HEADERS, timeout=15
)
assert cleanup_response.status_code == 200, "Failed to cleanup graph test memories" assert cleanup_response.status_code == 200, "Failed to cleanup graph test memories"
if verbose: if verbose:
@ -580,17 +469,12 @@ def test_graph_relationships_creation(verbose):
# Note: We expect some relationships even if graph extraction is basic # Note: We expect some relationships even if graph extraction is basic
# The test passes if the endpoint works and returns proper structure # The test passes if the endpoint works and returns proper structure
def test_graph_relationships(verbose): def test_graph_relationships(verbose):
"""Test graph relationships endpoint""" """Test graph relationships endpoint"""
response = requests.get( response = requests.get(f"{BASE_URL}/graph/relationships/{TEST_USER}", timeout=15)
f"{BASE_URL}/graph/relationships/{TEST_USER}", headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Graph") log_response(response, verbose, "Graph")
assert response.status_code == 200, ( assert response.status_code == 200, f"Graph endpoint failed with {response.status_code}"
f"Graph endpoint failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "relationships" in data, "Graph response missing 'relationships'" assert "relationships" in data, "Graph response missing 'relationships'"
@ -602,13 +486,10 @@ def test_graph_relationships(verbose):
print(f" Relationships: {len(data['relationships'])}") print(f" Relationships: {len(data['relationships'])}")
print(f" Entities: {len(data['entities'])}") print(f" Entities: {len(data['entities'])}")
def test_delete_specific_memory(verbose): def test_delete_specific_memory(verbose):
"""Test deleting a specific memory""" """Test deleting a specific memory"""
# Get a memory to delete # Get a memory to delete
response = requests.get( response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10)
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for deletion test" assert response.status_code == 200, "Failed to get memory for deletion test"
memories = response.json() memories = response.json()
@ -617,9 +498,7 @@ def test_delete_specific_memory(verbose):
memory_id = memories[0]["id"] memory_id = memories[0]["id"]
# Delete the memory # Delete the memory
response = requests.delete( response = requests.delete(f"{BASE_URL}/memories/{memory_id}", timeout=10)
f"{BASE_URL}/memories/{memory_id}", headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Delete") log_response(response, verbose, "Delete")
assert response.status_code == 200, f"Delete failed with {response.status_code}" assert response.status_code == 200, f"Delete failed with {response.status_code}"
@ -630,12 +509,9 @@ def test_delete_specific_memory(verbose):
if verbose: if verbose:
print(f" Deleted memory {memory_id}") print(f" Deleted memory {memory_id}")
def test_delete_all_user_memories(verbose): def test_delete_all_user_memories(verbose):
"""Test deleting all memories for a user""" """Test deleting all memories for a user"""
response = requests.delete( response = requests.delete(f"{BASE_URL}/memories/user/{TEST_USER}", timeout=15)
f"{BASE_URL}/memories/user/{TEST_USER}", headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Delete All") log_response(response, verbose, "Delete All")
assert response.status_code == 200, f"Delete all failed with {response.status_code}" assert response.status_code == 200, f"Delete all failed with {response.status_code}"
@ -646,17 +522,12 @@ def test_delete_all_user_memories(verbose):
if verbose: if verbose:
print(f"Deleted all memories for {TEST_USER}") print(f"Deleted all memories for {TEST_USER}")
def test_cleanup_verification(verbose): def test_cleanup_verification(verbose):
"""Verify cleanup was successful""" """Verify cleanup was successful"""
response = requests.get( response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=10", timeout=10)
f"{BASE_URL}/memories/{TEST_USER}?limit=10", headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Cleanup Check") log_response(response, verbose, "Cleanup Check")
assert response.status_code == 200, ( assert response.status_code == 200, f"Cleanup verification failed with {response.status_code}"
f"Cleanup verification failed with {response.status_code}"
)
memories = response.json() memories = response.json()
assert isinstance(memories, list), "Should return list even if empty" assert isinstance(memories, list), "Should return list even if empty"
@ -668,79 +539,5 @@ def test_cleanup_verification(verbose):
if verbose: if verbose:
print(" Cleanup successful - no memories remain") print(" Cleanup successful - no memories remain")
# ================== SECURITY TEST FUNCTIONS ==================
def test_auth_required_endpoints(verbose):
"""Test that protected endpoints require authentication"""
endpoints_requiring_auth = [
("GET", f"{BASE_URL}/memories/{TEST_USER}"),
("POST", f"{BASE_URL}/memories/search"),
("GET", f"{BASE_URL}/stats"),
("GET", f"{BASE_URL}/models"),
("GET", f"{BASE_URL}/users"),
]
for method, url in endpoints_requiring_auth:
if method == "GET":
response = requests.get(url, timeout=5)
else:
response = requests.post(
url, json={"query": "test", "user_id": TEST_USER}, timeout=5
)
assert response.status_code in [401, 403], (
f"{method} {url} should require auth, got {response.status_code}"
)
if verbose:
print(f" {method} {url}: {response.status_code} (auth required)")
def test_ownership_verification(verbose):
"""Test that users can only access their own data"""
other_user = "other_user_not_me"
response = requests.get(
f"{BASE_URL}/memories/{other_user}", headers=AUTH_HEADERS, timeout=5
)
assert response.status_code in [403, 404], (
f"Accessing other user's memories should be denied, got {response.status_code}"
)
if verbose:
print(f" Ownership check passed: {response.status_code}")
def test_request_size_limit(verbose):
"""Test request size limit enforcement (10MB max)"""
large_payload = {
"messages": [{"role": "user", "content": "x" * (11 * 1024 * 1024)}],
"user_id": TEST_USER,
}
try:
response = requests.post(
f"{BASE_URL}/memories",
json=large_payload,
headers={**AUTH_HEADERS, "Content-Length": str(11 * 1024 * 1024)},
timeout=5,
)
assert response.status_code == 413, (
f"Large request should return 413, got {response.status_code}"
)
if verbose:
print(f" Request size limit enforced: {response.status_code}")
except requests.exceptions.RequestException as e:
if verbose:
print(
f" Request size limit test: connection issue (expected for large payload)"
)
if __name__ == "__main__": if __name__ == "__main__":
main() main()