Compare commits

...

10 commits

Author SHA1 Message Date
82accabc73 fix: properly log settings as structlog kwargs instead of dropping them 2026-01-16 00:40:53 +05:30
6f9b545c15 fix: remove invalid positional arg from structlog call 2026-01-16 00:38:41 +05:30
5bcecf4649 fix: use structlog instead of logging in mem0_manager for kwargs support 2026-01-16 00:34:51 +05:30
638a591dc5 add setup script with volume reset option, fix default embedding dims 2026-01-16 00:29:29 +05:30
a190527076 fix: use npm_network for NPM proxy, expose instead of ports 2026-01-16 00:01:41 +05:30
9e86c30548 fix: pass OLLAMA_BASE_URL, EMBEDDING_MODEL, EMBEDDING_DIMS to container 2026-01-15 23:55:44 +05:30
2c1d73a1ec add OpenAI-compatible endpoint and improved login UI
- Add /v1/chat/completions and /chat/completions endpoints (OpenAI SDK compatible)
- Add streaming support with SSE for chat completions
- Add get_current_user_openai auth supporting Bearer token and X-API-Key
- Add OpenAI-compatible request/response models (OpenAIChatCompletionRequest, etc.)
- Cherry-pick improved login UI from cloud branch (styled login screen, logout button)
2026-01-15 23:29:08 +05:30
a228780146 production improvements: configurable embeddings, v1.1, O(1) ownership, retries
- Make Ollama URL configurable via OLLAMA_BASE_URL env var
- Add version: v1.1 to Mem0 config (required for latest features)
- Make embedding model and dimensions configurable
- Fix ownership check: O(1) lookup instead of fetching 10k records
- Add tenacity retry logic for database operations
2026-01-15 23:01:18 +05:30
50edce2d3c security hardening: add auth, rate limiting, fix info disclosure
- Add auth to /models and /users endpoints
- Add rate limiting to all endpoints (10-120/min based on operation type)
- Fix 11 info disclosure issues (detail=str(e) -> generic message)
- Fix 2 silent except blocks with proper logging
- Fix 7 raise e -> raise for proper exception chaining
- Fix health check to not expose exception details
- Update tests with X-API-Key headers and security tests
2026-01-15 22:41:24 +05:30
35c1bbec4e added MCP HTTP endpoint with auth
Exposes memory operations as MCP tools over /mcp endpoint:
- add_memory, search_memory, remove_memory, chat
- API key auth via x-api-key or Authorization header
- User isolation enforced via contextvars
2026-01-11 14:00:16 +05:30
11 changed files with 2088 additions and 522 deletions

View file

@ -1,7 +1,7 @@
"""Simple API key authentication for Mem0 Interface."""
from typing import Optional
from fastapi import HTTPException, Security, status
from fastapi import HTTPException, Security, status, Header
from fastapi.security import APIKeyHeader
import structlog
@ -19,7 +19,9 @@ class AuthService:
def __init__(self):
"""Initialize auth service with API key to user mapping."""
self.api_key_to_user = settings.api_key_mapping
logger.info(f"Auth service initialized with {len(self.api_key_to_user)} API keys")
logger.info(
f"Auth service initialized with {len(self.api_key_to_user)} API keys"
)
def verify_api_key(self, api_key: str) -> str:
"""
@ -37,8 +39,7 @@ class AuthService:
if api_key not in self.api_key_to_user:
logger.warning(f"Invalid API key attempted: {api_key[:10]}...")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key"
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
)
user_id = self.api_key_to_user[api_key]
@ -68,7 +69,7 @@ class AuthService:
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied: You can only access your own memories"
detail=f"Access denied: You can only access your own memories",
)
return authenticated_user_id
@ -91,9 +92,46 @@ async def get_current_user(api_key: str = Security(api_key_header)) -> str:
return auth_service.verify_api_key(api_key)
async def get_current_user_openai(
authorization: Optional[str] = Header(None),
x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> str:
"""
FastAPI dependency for OpenAI-compatible authentication.
Supports both Authorization: Bearer and X-API-Key headers.
Args:
authorization: Authorization header (Bearer token)
x_api_key: X-API-Key header
Returns:
str: Authenticated user_id
Raises:
HTTPException: If no valid API key is provided
"""
api_key = None
# Try Bearer token first (OpenAI standard)
if authorization and authorization.startswith("Bearer "):
api_key = authorization[7:] # Remove "Bearer " prefix
logger.debug("Extracted API key from Authorization Bearer token")
# Fall back to X-API-Key header
elif x_api_key:
api_key = x_api_key
logger.debug("Extracted API key from X-API-Key header")
else:
logger.warning("No API key provided in Authorization or X-API-Key headers")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing API key. Provide either 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header",
)
return auth_service.verify_api_key(api_key)
async def verify_user_access(
api_key: str = Security(api_key_header),
user_id: Optional[str] = None
api_key: str = Security(api_key_header), user_id: Optional[str] = None
) -> str:
"""
FastAPI dependency to verify user can access the requested user_id.
@ -114,7 +152,7 @@ async def verify_user_access(
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied: You can only access your own memories"
detail="Access denied: You can only access your own memories",
)
return authenticated_user_id

View file

@ -11,39 +11,87 @@ class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
model_config = SettingsConfigDict(
env_file=".env",
case_sensitive=False,
extra='ignore'
env_file=".env", case_sensitive=False, extra="ignore"
)
# API Configuration
# Accept both OPENAI_API_KEY (from docker-compose) and OPENAI_COMPAT_API_KEY (from direct .env)
openai_api_key: str = Field(validation_alias=AliasChoices('OPENAI_API_KEY', 'OPENAI_COMPAT_API_KEY', 'openai_api_key'))
openai_base_url: str = Field(validation_alias=AliasChoices('OPENAI_BASE_URL', 'OPENAI_COMPAT_BASE_URL', 'openai_base_url'))
cohere_api_key: str = Field(validation_alias=AliasChoices('COHERE_API_KEY', 'cohere_api_key'))
openai_api_key: str = Field(
validation_alias=AliasChoices(
"OPENAI_API_KEY", "OPENAI_COMPAT_API_KEY", "openai_api_key"
)
)
openai_base_url: str = Field(
validation_alias=AliasChoices(
"OPENAI_BASE_URL", "OPENAI_COMPAT_BASE_URL", "openai_base_url"
)
)
cohere_api_key: str = Field(
validation_alias=AliasChoices("COHERE_API_KEY", "cohere_api_key")
)
# Database Configuration
qdrant_host: str = Field(default="localhost", validation_alias=AliasChoices('QDRANT_HOST', 'qdrant_host'))
qdrant_port: int = Field(default=6333, validation_alias=AliasChoices('QDRANT_PORT', 'qdrant_port'))
qdrant_collection_name: str = Field(default="mem0", validation_alias=AliasChoices('QDRANT_COLLECTION_NAME', 'qdrant_collection_name'))
qdrant_host: str = Field(
default="localhost", validation_alias=AliasChoices("QDRANT_HOST", "qdrant_host")
)
qdrant_port: int = Field(
default=6333, validation_alias=AliasChoices("QDRANT_PORT", "qdrant_port")
)
qdrant_collection_name: str = Field(
default="mem0",
validation_alias=AliasChoices(
"QDRANT_COLLECTION_NAME", "qdrant_collection_name"
),
)
# Neo4j Configuration
neo4j_uri: str = Field(default="bolt://localhost:7687", validation_alias=AliasChoices('NEO4J_URI', 'neo4j_uri'))
neo4j_username: str = Field(default="neo4j", validation_alias=AliasChoices('NEO4J_USERNAME', 'neo4j_username'))
neo4j_password: str = Field(default="mem0_neo4j_password", validation_alias=AliasChoices('NEO4J_PASSWORD', 'neo4j_password'))
neo4j_uri: str = Field(
default="bolt://localhost:7687",
validation_alias=AliasChoices("NEO4J_URI", "neo4j_uri"),
)
neo4j_username: str = Field(
default="neo4j",
validation_alias=AliasChoices("NEO4J_USERNAME", "neo4j_username"),
)
neo4j_password: str = Field(
default="mem0_neo4j_password",
validation_alias=AliasChoices("NEO4J_PASSWORD", "neo4j_password"),
)
# Application Configuration
log_level: str = Field(default="INFO", validation_alias=AliasChoices('LOG_LEVEL', 'log_level'))
cors_origins: str = Field(default="http://localhost:3000", validation_alias=AliasChoices('CORS_ORIGINS', 'cors_origins'))
log_level: str = Field(
default="INFO", validation_alias=AliasChoices("LOG_LEVEL", "log_level")
)
cors_origins: str = Field(
default="http://localhost:3000",
validation_alias=AliasChoices("CORS_ORIGINS", "cors_origins"),
)
# Model Configuration - Ultra-minimal (single model)
default_model: str = Field(default="claude-sonnet-4", validation_alias=AliasChoices('DEFAULT_MODEL', 'default_model'))
default_model: str = Field(
default="claude-sonnet-4",
validation_alias=AliasChoices("DEFAULT_MODEL", "default_model"),
)
# Embedder Configuration
ollama_base_url: str = Field(
default="http://host.docker.internal:11434",
validation_alias=AliasChoices("OLLAMA_BASE_URL", "ollama_base_url"),
)
embedding_model: str = Field(
default="qwen3-embedding:4b-q8_0",
validation_alias=AliasChoices("EMBEDDING_MODEL", "embedding_model"),
)
embedding_dims: int = Field(
default=2560, validation_alias=AliasChoices("EMBEDDING_DIMS", "embedding_dims")
)
# Authentication Configuration
# Format: JSON string mapping API keys to user IDs
# Example: {"api_key_123": "alice", "api_key_456": "bob"}
api_keys: str = Field(default="{}", validation_alias=AliasChoices('API_KEYS', 'api_keys'))
api_keys: str = Field(
default="{}", validation_alias=AliasChoices("API_KEYS", "api_keys")
)
@property
def cors_origins_list(self) -> List[str]:

View file

@ -1,25 +1,55 @@
"""Main FastAPI application for Mem0 Interface POC."""
import json
import logging
import time
from datetime import datetime
from typing import List, Dict, Any, Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse
import structlog
import asyncio
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from config import settings
# Rate limiter - uses IP address as key, falls back to API key for authenticated requests
def get_rate_limit_key(request: Request) -> str:
"""Get rate limit key - prefer API key if available, otherwise IP."""
api_key = request.headers.get("x-api-key", "")
if api_key:
return f"apikey:{api_key[:16]}" # Use first 16 chars of API key
return get_remote_address(request)
limiter = Limiter(key_func=get_rate_limit_key)
from models import (
ChatRequest, MemoryAddRequest, MemoryAddResponse,
MemorySearchRequest, MemorySearchResponse, MemoryUpdateRequest,
MemoryItem, GraphResponse, HealthResponse, ErrorResponse,
GlobalStatsResponse, UserStatsResponse
ChatRequest,
MemoryAddRequest,
MemoryAddResponse,
MemorySearchRequest,
MemorySearchResponse,
MemoryUpdateRequest,
MemoryItem,
GraphResponse,
HealthResponse,
ErrorResponse,
GlobalStatsResponse,
UserStatsResponse,
OpenAIChatCompletionRequest,
OpenAIChatCompletionResponse,
OpenAIChoice,
OpenAIChoiceMessage,
OpenAIUsage,
)
from mem0_manager import mem0_manager
from auth import get_current_user, auth_service
from auth import get_current_user, get_current_user_openai, auth_service
# Configure structured logging
structlog.configure(
@ -32,7 +62,7 @@ structlog.configure(
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
structlog.processors.JSONRenderer()
structlog.processors.JSONRenderer(),
],
context_class=dict,
logger_factory=structlog.stdlib.LoggerFactory(),
@ -58,9 +88,27 @@ async def lifespan(app: FastAPI):
else:
logger.info("All services are healthy")
# Start MCP session manager if available
mcp_context = None
try:
from mcp_server import mcp_lifespan
mcp_context = mcp_lifespan()
await mcp_context.__aenter__()
except ImportError:
logger.warning("MCP server not available")
except Exception as e:
logger.error(f"Failed to start MCP session manager: {e}")
yield
# Shutdown
if mcp_context:
try:
await mcp_context.__aexit__(None, None, None)
except Exception as e:
logger.error(f"Error stopping MCP session manager: {e}")
logger.info("Shutting down Mem0 Interface POC")
@ -69,19 +117,47 @@ app = FastAPI(
title="Mem0 Interface POC",
description="Minimal but fully functional Mem0 interface with PostgreSQL and Neo4j integration",
version="1.0.0",
lifespan=lifespan
lifespan=lifespan,
)
# Add CORS middleware - Allow all origins for development
# Add rate limiter to app state and exception handler
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Add CORS middleware - Allow all origins (secured via API key auth)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allow all origins for development
allow_origins=["*"], # Allow all origins - secured via API key authentication
allow_credentials=False, # Must be False when allow_origins=["*"]
allow_methods=["*"],
allow_headers=["*"],
)
# Request size limit middleware - prevent DoS via large payloads
MAX_REQUEST_SIZE = 10 * 1024 * 1024 # 10MB limit
@app.middleware("http")
async def limit_request_size(request, call_next):
"""Reject requests that exceed the maximum allowed size."""
content_length = request.headers.get("content-length")
if content_length:
try:
if int(content_length) > MAX_REQUEST_SIZE:
return JSONResponse(
status_code=413,
content={
"error": "Request payload too large",
"max_size_bytes": MAX_REQUEST_SIZE,
"max_size_mb": MAX_REQUEST_SIZE / (1024 * 1024),
},
)
except ValueError:
pass # Invalid content-length header, let it through for other validation
return await call_next(request)
# Request logging middleware with monitoring
@app.middleware("http")
async def log_requests(request, call_next):
@ -94,19 +170,19 @@ async def log_requests(request, call_next):
# Extract user_id from request if available
user_id = None
if request.method == "POST":
# Try to extract user_id from request body for POST requests
try:
body = await request.body()
if body:
import json
data = json.loads(body)
user_id = data.get('user_id')
except:
pass
user_id = data.get("user_id")
except json.JSONDecodeError:
pass # Non-JSON body, user_id extraction not possible
except Exception as e:
logger.debug("Could not extract user_id from request body", error=str(e))
elif "user_id" in str(request.url.path):
# Extract user_id from path for GET requests
path_parts = request.url.path.split('/')
if len(path_parts) > 2 and path_parts[-2] in ['memories', 'stats']:
path_parts = request.url.path.split("/")
if len(path_parts) > 2 and path_parts[-2] in ["memories", "stats"]:
user_id = path_parts[-1]
# Log start of request
@ -115,7 +191,7 @@ async def log_requests(request, call_next):
correlation_id=correlation_id,
method=request.method,
path=request.url.path,
user_id=user_id
user_id=user_id,
)
response = await call_next(request)
@ -136,7 +212,7 @@ async def log_requests(request, call_next):
status_code=response.status_code,
process_time_ms=round(process_time_ms, 2),
user_id=user_id,
slow_request=True
slow_request=True,
)
elif response.status_code >= 400:
logger.error(
@ -147,7 +223,7 @@ async def log_requests(request, call_next):
status_code=response.status_code,
process_time_ms=round(process_time_ms, 2),
user_id=user_id,
slow_request=False
slow_request=False,
)
else:
logger.info(
@ -158,7 +234,7 @@ async def log_requests(request, call_next):
status_code=response.status_code,
process_time_ms=round(process_time_ms, 2),
user_id=user_id,
slow_request=False
slow_request=False,
)
return response
@ -167,11 +243,23 @@ async def log_requests(request, call_next):
# Exception handlers
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
"""Global exception handler."""
logger.error(f"Unhandled exception: {exc}", exc_info=True)
"""Global exception handler - logs details but returns generic message."""
# Log full exception details for debugging (internal only)
logger.error(
"Unhandled exception",
exc_info=True,
path=request.url.path,
method=request.method,
error_type=type(exc).__name__,
error_message=str(exc),
)
# Return generic error to client - don't expose internal details
return JSONResponse(
status_code=500,
content={"error": "Internal server error", "detail": str(exc)}
content={
"error": "Internal server error",
"message": "An unexpected error occurred",
},
)
@ -181,50 +269,59 @@ async def health_check():
"""Check the health of all services."""
try:
services = await mem0_manager.health_check()
overall_status = "healthy" if all("healthy" in status for status in services.values()) else "degraded"
overall_status = (
"healthy"
if all("healthy" in status for status in services.values())
else "degraded"
)
return HealthResponse(
status=overall_status,
services=services,
timestamp=datetime.utcnow().isoformat()
timestamp=datetime.utcnow().isoformat(),
)
except Exception as e:
logger.error(f"Health check failed: {e}")
logger.error(f"Health check failed: {e}", exc_info=True)
return HealthResponse(
status="unhealthy",
services={"error": str(e)},
timestamp=datetime.utcnow().isoformat()
services={"error": "Health check failed - see logs for details"},
timestamp=datetime.utcnow().isoformat(),
)
# Core chat endpoint with memory enhancement
@app.post("/chat")
@limiter.limit("30/minute") # Chat is expensive - limit to 30/min
async def chat_with_memory(
request: ChatRequest,
authenticated_user: str = Depends(get_current_user)
request: Request,
chat_request: ChatRequest,
authenticated_user: str = Depends(get_current_user),
):
"""Ultra-minimal chat endpoint - pure Mem0 + custom endpoint."""
try:
# Verify user can only access their own data
if authenticated_user != request.user_id:
if authenticated_user != chat_request.user_id:
raise HTTPException(
status_code=403,
detail=f"Access denied: You can only chat as yourself (authenticated as '{authenticated_user}')"
detail=f"Access denied: You can only chat as yourself (authenticated as '{authenticated_user}')",
)
logger.info(f"Processing chat request for user: {request.user_id}")
logger.info(f"Processing chat request for user: {chat_request.user_id}")
# Convert ChatMessage objects to dict format if context provided
context_dict = None
if request.context:
context_dict = [{"role": msg.role, "content": msg.content} for msg in request.context]
if chat_request.context:
context_dict = [
{"role": msg.role, "content": msg.content}
for msg in chat_request.context
]
result = await mem0_manager.chat_with_memory(
message=request.message,
user_id=request.user_id,
agent_id=request.agent_id,
run_id=request.run_id,
context=context_dict
message=chat_request.message,
user_id=chat_request.user_id,
agent_id=chat_request.agent_id,
run_id=chat_request.run_id,
context=context_dict,
)
return result
@ -233,32 +330,173 @@ async def chat_with_memory(
raise
except Exception as e:
logger.error(f"Error in chat endpoint: {e}")
raise HTTPException(
status_code=500,
detail="An internal error occurred. Please try again later.",
)
async def stream_openai_response(
completion_id: str, model: str, content: str, created: int
):
"""Generate SSE stream for OpenAI-compatible streaming by chunking the response."""
import uuid
# First chunk with role
chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": ""},
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(chunk)}\n\n"
# Stream content in chunks (3 words at a time for smooth effect)
words = content.split()
chunk_size = 3
for i in range(0, len(words), chunk_size):
word_chunk = " ".join(words[i : i + chunk_size])
if i + chunk_size < len(words):
word_chunk += " "
chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{"index": 0, "delta": {"content": word_chunk}, "finish_reason": None}
],
}
yield f"data: {json.dumps(chunk)}\n\n"
await asyncio.sleep(0.05)
# Final chunk
chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
yield f"data: {json.dumps(chunk)}\n\n"
yield "data: [DONE]\n\n"
@app.post("/v1/chat/completions")
@app.post("/chat/completions")
@limiter.limit("30/minute")
async def openai_chat_completions(
request: Request,
completion_request: OpenAIChatCompletionRequest,
authenticated_user: str = Depends(get_current_user_openai),
):
"""OpenAI-compatible chat completions endpoint with mem0 memory integration."""
try:
import uuid
user_id = authenticated_user
logger.info(
f"OpenAI chat completion for user: {user_id} (streaming={completion_request.stream})"
)
# Extract last user message
user_messages = [
m for m in completion_request.messages if m.get("role") == "user"
]
if not user_messages:
raise HTTPException(
status_code=400,
detail="No user messages provided. Include at least one message with role='user'.",
)
last_message = user_messages[-1].get("content", "")
context = (
completion_request.messages[:-1]
if len(completion_request.messages) > 1
else None
)
# Call chat_with_memory
result = await mem0_manager.chat_with_memory(
message=last_message,
user_id=user_id,
context=context,
)
completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created_time = int(time.time())
assistant_content = result.get("response", "")
if completion_request.stream:
return StreamingResponse(
stream_openai_response(
completion_id=completion_id,
model=settings.default_model,
content=assistant_content,
created=created_time,
),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
else:
return OpenAIChatCompletionResponse(
id=completion_id,
object="chat.completion",
created=created_time,
model=settings.default_model,
choices=[
OpenAIChoice(
index=0,
message=OpenAIChoiceMessage(
role="assistant", content=assistant_content
),
finish_reason="stop",
)
],
usage=OpenAIUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in OpenAI chat completions: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# Memory management endpoints - pure Mem0 passthroughs
@app.post("/memories")
@limiter.limit("60/minute") # Memory operations - 60/min
async def add_memories(
request: MemoryAddRequest,
authenticated_user: str = Depends(get_current_user)
request: Request,
memory_request: MemoryAddRequest,
authenticated_user: str = Depends(get_current_user),
):
"""Add memories - pure Mem0 passthrough."""
try:
# Verify user can only add to their own memories
if authenticated_user != request.user_id:
if authenticated_user != memory_request.user_id:
raise HTTPException(
status_code=403,
detail=f"Access denied: You can only add memories for yourself (authenticated as '{authenticated_user}')"
detail=f"Access denied: You can only add memories for yourself (authenticated as '{authenticated_user}')",
)
logger.info(f"Adding memories for user: {request.user_id}")
logger.info(f"Adding memories for user: {memory_request.user_id}")
result = await mem0_manager.add_memories(
messages=request.messages,
user_id=request.user_id,
agent_id=request.agent_id,
run_id=request.run_id,
metadata=request.metadata
messages=memory_request.messages,
user_id=memory_request.user_id,
agent_id=memory_request.agent_id,
run_id=memory_request.run_id,
metadata=memory_request.metadata,
)
return result
@ -267,33 +505,40 @@ async def add_memories(
raise
except Exception as e:
logger.error(f"Error adding memories: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail="An internal error occurred. Please try again later.",
)
@app.post("/memories/search")
@limiter.limit("120/minute") # Search is lighter - 120/min
async def search_memories(
request: MemorySearchRequest,
authenticated_user: str = Depends(get_current_user)
request: Request,
search_request: MemorySearchRequest,
authenticated_user: str = Depends(get_current_user),
):
"""Search memories - pure Mem0 passthrough."""
try:
# Verify user can only search their own memories
if authenticated_user != request.user_id:
if authenticated_user != search_request.user_id:
raise HTTPException(
status_code=403,
detail=f"Access denied: You can only search your own memories (authenticated as '{authenticated_user}')"
detail=f"Access denied: You can only search your own memories (authenticated as '{authenticated_user}')",
)
logger.info(f"Searching memories for user: {request.user_id}, query: {request.query}")
logger.info(
f"Searching memories for user: {search_request.user_id}, query: {search_request.query}"
)
result = await mem0_manager.search_memories(
query=request.query,
user_id=request.user_id,
limit=request.limit,
threshold=request.threshold or 0.2,
filters=request.filters,
agent_id=request.agent_id,
run_id=request.run_id
query=search_request.query,
user_id=search_request.user_id,
limit=search_request.limit,
threshold=search_request.threshold or 0.2,
filters=search_request.filters,
agent_id=search_request.agent_id,
run_id=search_request.run_id,
)
return result
@ -302,16 +547,21 @@ async def search_memories(
raise
except Exception as e:
logger.error(f"Error searching memories: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail="An internal error occurred. Please try again later.",
)
@app.get("/memories/{user_id}")
@limiter.limit("120/minute")
async def get_user_memories(
request: Request,
user_id: str,
authenticated_user: str = Depends(get_current_user),
limit: int = 10,
agent_id: Optional[str] = None,
run_id: Optional[str] = None
run_id: Optional[str] = None,
):
"""Get all memories for a user with hierarchy filtering - pure Mem0 passthrough."""
try:
@ -319,16 +569,13 @@ async def get_user_memories(
if authenticated_user != user_id:
raise HTTPException(
status_code=403,
detail=f"Access denied: You can only retrieve your own memories (authenticated as '{authenticated_user}')"
detail=f"Access denied: You can only retrieve your own memories (authenticated as '{authenticated_user}')",
)
logger.info(f"Retrieving memories for user: {user_id}")
memories = await mem0_manager.get_user_memories(
user_id=user_id,
limit=limit,
agent_id=agent_id,
run_id=run_id
user_id=user_id, limit=limit, agent_id=agent_id, run_id=run_id
)
return memories
@ -337,28 +584,44 @@ async def get_user_memories(
raise
except Exception as e:
logger.error(f"Error retrieving user memories: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail="An internal error occurred. Please try again later.",
)
@app.put("/memories")
@limiter.limit("60/minute")
async def update_memory(
request: MemoryUpdateRequest,
authenticated_user: str = Depends(get_current_user)
request: Request,
update_request: MemoryUpdateRequest,
authenticated_user: str = Depends(get_current_user),
):
"""Update memory - pure Mem0 passthrough."""
"""Update memory - verifies ownership before update."""
try:
# Verify user owns the memory being updated
if authenticated_user != request.user_id:
if authenticated_user != update_request.user_id:
raise HTTPException(
status_code=403,
detail=f"Access denied: You can only update your own memories (authenticated as '{authenticated_user}')"
detail=f"Access denied: You can only update your own memories (authenticated as '{authenticated_user}')",
)
logger.info(f"Updating memory: {request.memory_id}")
# Verify memory ownership with O(1) lookup instead of fetching all memories
if not await mem0_manager.verify_memory_ownership(
update_request.memory_id, authenticated_user
):
raise HTTPException(
status_code=404,
detail=f"Memory '{update_request.memory_id}' not found or access denied",
)
logger.info(
f"Updating memory: {update_request.memory_id}", user_id=authenticated_user
)
result = await mem0_manager.update_memory(
memory_id=request.memory_id,
content=request.content,
memory_id=update_request.memory_id,
content=update_request.content,
)
return result
@ -367,25 +630,31 @@ async def update_memory(
raise
except Exception as e:
logger.error(f"Error updating memory: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail="An internal error occurred. Please try again later.",
)
@app.delete("/memories/{memory_id}")
@limiter.limit("60/minute")
async def delete_memory(
request: Request,
memory_id: str,
user_id: str, # Add user_id as query parameter for verification
authenticated_user: str = Depends(get_current_user)
authenticated_user: str = Depends(get_current_user),
):
"""Delete a specific memory."""
"""Delete a specific memory - verifies ownership before deletion."""
try:
# Verify user owns the memory being deleted
if authenticated_user != user_id:
# Verify memory ownership with O(1) lookup instead of fetching all memories
if not await mem0_manager.verify_memory_ownership(
memory_id, authenticated_user
):
raise HTTPException(
status_code=403,
detail=f"Access denied: You can only delete your own memories (authenticated as '{authenticated_user}')"
status_code=404,
detail=f"Memory '{memory_id}' not found or access denied",
)
logger.info(f"Deleting memory: {memory_id}")
logger.info(f"Deleting memory: {memory_id}", user_id=authenticated_user)
result = await mem0_manager.delete_memory(memory_id=memory_id)
@ -395,13 +664,16 @@ async def delete_memory(
raise
except Exception as e:
logger.error(f"Error deleting memory: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail="An internal error occurred. Please try again later.",
)
@app.delete("/memories/user/{user_id}")
@limiter.limit("10/minute") # Dangerous bulk delete - heavily rate limited
async def delete_user_memories(
user_id: str,
authenticated_user: str = Depends(get_current_user)
request: Request, user_id: str, authenticated_user: str = Depends(get_current_user)
):
"""Delete all memories for a specific user."""
try:
@ -409,7 +681,7 @@ async def delete_user_memories(
if authenticated_user != user_id:
raise HTTPException(
status_code=403,
detail=f"Access denied: You can only delete your own memories (authenticated as '{authenticated_user}')"
detail=f"Access denied: You can only delete your own memories (authenticated as '{authenticated_user}')",
)
logger.info(f"Deleting all memories for user: {user_id}")
@ -422,14 +694,17 @@ async def delete_user_memories(
raise
except Exception as e:
logger.error(f"Error deleting user memories: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail="An internal error occurred. Please try again later.",
)
# Graph relationships endpoint - pure Mem0 passthrough
@app.get("/graph/relationships/{user_id}")
@limiter.limit("60/minute")
async def get_graph_relationships(
user_id: str,
authenticated_user: str = Depends(get_current_user)
request: Request, user_id: str, authenticated_user: str = Depends(get_current_user)
):
"""Get graph relationships - pure Mem0 passthrough."""
try:
@ -437,11 +712,13 @@ async def get_graph_relationships(
if authenticated_user != user_id:
raise HTTPException(
status_code=403,
detail=f"Access denied: You can only view your own relationships (authenticated as '{authenticated_user}')"
detail=f"Access denied: You can only view your own relationships (authenticated as '{authenticated_user}')",
)
logger.info(f"Retrieving graph relationships for user: {user_id}")
result = await mem0_manager.get_graph_relationships(user_id=user_id, agent_id=None, run_id=None, limit=10000)
result = await mem0_manager.get_graph_relationships(
user_id=user_id, agent_id=None, run_id=None, limit=10000
)
return result
@ -449,68 +726,101 @@ async def get_graph_relationships(
raise
except Exception as e:
logger.error(f"Error retrieving graph relationships: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail="An internal error occurred. Please try again later.",
)
# Memory history endpoint - new feature
@app.get("/memories/{memory_id}/history")
async def get_memory_history(memory_id: str):
@limiter.limit("120/minute")
async def get_memory_history(
request: Request,
memory_id: str,
user_id: str, # Required query param to verify ownership
authenticated_user: str = Depends(get_current_user),
):
"""Get memory change history - pure Mem0 passthrough."""
try:
logger.info(f"Retrieving history for memory: {memory_id}")
# Verify user can only access their own memory history
if authenticated_user != user_id:
raise HTTPException(
status_code=403,
detail=f"Access denied: You can only view your own memory history (authenticated as '{authenticated_user}')",
)
# Verify memory ownership with O(1) lookup instead of fetching all memories
if not await mem0_manager.verify_memory_ownership(memory_id, user_id):
raise HTTPException(
status_code=404,
detail=f"Memory '{memory_id}' not found or access denied",
)
logger.info(f"Retrieving history for memory: {memory_id}", user_id=user_id)
result = await mem0_manager.get_memory_history(memory_id=memory_id)
return result
except HTTPException:
raise
except Exception as e:
logger.error(f"Error retrieving memory history: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail="An internal error occurred. Please try again later.",
)
# Statistics and monitoring endpoints
@app.get("/stats", response_model=GlobalStatsResponse)
async def get_global_stats():
"""Get global application statistics."""
@limiter.limit("60/minute")
async def get_global_stats(
request: Request, authenticated_user: str = Depends(get_current_user)
):
"""Get global application statistics - requires authentication."""
try:
from monitoring import stats
# Get basic stats from monitoring
basic_stats = stats.get_global_stats()
# Get actual memory count from Mem0 (simplified approach)
try:
# This is a rough estimate - in production you might want a more efficient method
sample_result = await mem0_manager.search_memories(query="*", user_id="__stats_check__", limit=1)
# For now, we'll use the basic stats total_memories value
# You could implement a more accurate count by querying the database directly
total_memories = basic_stats['total_memories'] # Will be 0 for now
except:
sample_result = await mem0_manager.search_memories(
query="*", user_id="__stats_check__", limit=1
)
total_memories = basic_stats["total_memories"]
except Exception:
total_memories = 0
return GlobalStatsResponse(
total_memories=total_memories,
total_users=basic_stats['total_users'],
api_calls_today=basic_stats['api_calls_today'],
avg_response_time_ms=basic_stats['avg_response_time_ms'],
total_users=basic_stats["total_users"],
api_calls_today=basic_stats["api_calls_today"],
avg_response_time_ms=basic_stats["avg_response_time_ms"],
memory_operations={
"add": basic_stats['memory_operations']['add'],
"search": basic_stats['memory_operations']['search'],
"update": basic_stats['memory_operations']['update'],
"delete": basic_stats['memory_operations']['delete']
"add": basic_stats["memory_operations"]["add"],
"search": basic_stats["memory_operations"]["search"],
"update": basic_stats["memory_operations"]["update"],
"delete": basic_stats["memory_operations"]["delete"],
},
uptime_seconds=basic_stats['uptime_seconds']
uptime_seconds=basic_stats["uptime_seconds"],
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting global stats: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail="An internal error occurred. Please try again later.",
)
@app.get("/stats/{user_id}", response_model=UserStatsResponse)
@limiter.limit("120/minute")
async def get_user_stats(
user_id: str,
authenticated_user: str = Depends(get_current_user)
request: Request, user_id: str, authenticated_user: str = Depends(get_current_user)
):
"""Get user-specific statistics."""
try:
@ -518,7 +828,7 @@ async def get_user_stats(
if authenticated_user != user_id:
raise HTTPException(
status_code=403,
detail=f"Access denied: You can only view your own statistics (authenticated as '{authenticated_user}')"
detail=f"Access denied: You can only view your own statistics (authenticated as '{authenticated_user}')",
)
from monitoring import stats
@ -528,63 +838,92 @@ async def get_user_stats(
# Get actual memory count for this user
try:
user_memories = await mem0_manager.get_user_memories(user_id=user_id, limit=10000)
user_memories = await mem0_manager.get_user_memories(
user_id=user_id, limit=10000
)
memory_count = len(user_memories)
except:
except Exception as e:
logger.warning(f"Failed to get memory count for user {user_id}: {e}")
memory_count = 0
# Get relationship count for this user
try:
graph_data = await mem0_manager.get_graph_relationships(user_id=user_id, agent_id=None, run_id=None)
relationship_count = len(graph_data.get('relationships', []))
except:
graph_data = await mem0_manager.get_graph_relationships(
user_id=user_id, agent_id=None, run_id=None
)
relationship_count = len(graph_data.get("relationships", []))
except Exception as e:
logger.warning(f"Failed to get relationship count for user {user_id}: {e}")
relationship_count = 0
return UserStatsResponse(
user_id=user_id,
memory_count=memory_count,
relationship_count=relationship_count,
last_activity=basic_stats['last_activity'],
api_calls_today=basic_stats['api_calls_today'],
avg_response_time_ms=basic_stats['avg_response_time_ms']
last_activity=basic_stats["last_activity"],
api_calls_today=basic_stats["api_calls_today"],
avg_response_time_ms=basic_stats["avg_response_time_ms"],
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting user stats for {user_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail="An internal error occurred. Please try again later.",
)
# Utility endpoints
@app.get("/models")
async def get_available_models():
"""Get current model configuration."""
@limiter.limit("120/minute")
async def get_available_models(
request: Request, authenticated_user: str = Depends(get_current_user)
):
"""Get current model configuration - requires authentication."""
return {
"current_model": settings.default_model,
"endpoint": settings.openai_base_url,
"note": "Using single model with pure Mem0 intelligence"
"note": "Using single model with pure Mem0 intelligence",
}
@app.get("/users")
async def get_active_users():
"""Get list of users with memories (simplified implementation)."""
@limiter.limit("60/minute")
async def get_active_users(
request: Request, authenticated_user: str = Depends(get_current_user)
):
"""Get list of users with memories (simplified implementation) - requires authentication."""
# This would typically query the database for users with memories
# For now, return a placeholder
return {
"message": "This endpoint would return users with stored memories",
"note": "Implementation depends on direct database access or Mem0 user enumeration capabilities"
"note": "Implementation depends on direct database access or Mem0 user enumeration capabilities",
}
# Mount MCP server at /mcp endpoint
try:
from mcp_server import create_mcp_app
mcp_app = create_mcp_app()
app.mount("/mcp", mcp_app)
logger.info("MCP server mounted at /mcp")
except ImportError as e:
logger.warning(f"MCP server not available (missing dependencies): {e}")
except Exception as e:
logger.error(f"Failed to mount MCP server: {e}")
if __name__ == "__main__":
import uvicorn
print("Starting UVicorn server...")
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
log_level=settings.log_level.lower(),
reload=True
reload=True,
)

240
backend/mcp_server.py Normal file
View file

@ -0,0 +1,240 @@
"""MCP Server for Mem0 Memory Service.
Exposes memory operations as MCP tools over HTTP with API key authentication.
"""
import contextlib
import logging
from contextvars import ContextVar
from typing import Optional
from pydantic import Field
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse
from starlette.routing import Mount
from mcp.server.fastmcp import FastMCP
from config import settings
logger = logging.getLogger(__name__)
# Context variable for authenticated user_id
# Set by middleware, read by tools
current_user_id: ContextVar[str] = ContextVar("current_user_id", default="")
def get_authenticated_user() -> str:
"""Get the authenticated user_id from context.
Raises:
ValueError: If no authenticated user in context.
"""
user_id = current_user_id.get()
if not user_id:
raise ValueError("No authenticated user in context")
return user_id
class MCPAuthMiddleware(BaseHTTPMiddleware):
"""Middleware to authenticate MCP requests using API key."""
async def dispatch(self, request, call_next):
# Extract API key from headers
api_key = (
request.headers.get("x-api-key") or
request.headers.get("X-API-Key") or
request.headers.get("Authorization", "").replace("Bearer ", "")
)
if not api_key:
return JSONResponse(
{"error": "Missing API key. Provide x-api-key or Authorization header."},
status_code=401
)
# Map API key to user_id
user_id = settings.api_key_mapping.get(api_key)
if not user_id:
return JSONResponse(
{"error": "Invalid API key"},
status_code=401
)
# Store user_id in context for tools to access
token = current_user_id.set(user_id)
try:
response = await call_next(request)
return response
finally:
current_user_id.reset(token)
# Create FastMCP server
# streamable_http_path="/" since we mount at /mcp in main.py
mcp = FastMCP(
"Mem0 Memory Service",
stateless_http=True,
json_response=True,
streamable_http_path="/"
)
@mcp.tool()
async def add_memory(
content: str = Field(description="Content to add to memory"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
metadata: Optional[dict] = Field(default=None, description="Optional metadata to attach to the memory"),
) -> dict:
"""Add content to the user's memory.
The user_id is automatically determined from the API key authentication.
Use agent_id and run_id for multi-agent or session-based memory organization.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP add_memory: user={user_id}, agent={agent_id}, run={run_id}")
result = await mem0_manager.add_memories(
messages=[{"role": "user", "content": content}],
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
metadata=metadata
)
return result
@mcp.tool()
async def search_memory(
query: str = Field(description="Search query to find relevant memories"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier to filter memories"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier to filter memories"),
limit: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return"),
) -> dict:
"""Search the user's memories.
Returns memories most relevant to the query. The user_id is automatically
determined from API key authentication.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP search_memory: user={user_id}, query={query[:50]}..., limit={limit}")
result = await mem0_manager.search_memories(
query=query,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=limit
)
return result
@mcp.tool()
async def remove_memory(
memory_id: str = Field(description="The ID of the memory to remove"),
) -> dict:
"""Remove a specific memory by its ID.
Only memories belonging to the authenticated user can be deleted.
Verifies ownership before deletion.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP remove_memory: user={user_id}, memory_id={memory_id}")
# Verify ownership: get user's memories and check if memory_id exists
user_memories = await mem0_manager.get_user_memories(
user_id=user_id,
limit=10000 # Get all to check ownership
)
memory_ids = {m.get("id") for m in user_memories if m.get("id")}
if memory_id not in memory_ids:
raise ValueError(f"Memory '{memory_id}' not found or access denied")
result = await mem0_manager.delete_memory(memory_id=memory_id)
return result
@mcp.tool()
async def chat(
message: str = Field(description="The user's message to chat with"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
) -> str:
"""Chat with memory context.
Retrieves relevant memories based on the message, generates a response
using the configured LLM, and stores the conversation in memory.
The user_id is automatically determined from API key authentication.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP chat: user={user_id}, agent={agent_id}, message={message[:50]}...")
result = await mem0_manager.chat_with_memory(
message=message,
user_id=user_id,
agent_id=agent_id,
run_id=run_id
)
return result.get("response", "")
@contextlib.asynccontextmanager
async def mcp_lifespan():
"""Context manager for MCP session lifecycle.
Must be used in the main FastAPI lifespan since mounted app
lifespans don't run automatically.
"""
async with mcp.session_manager.run():
logger.info("MCP session manager started")
yield
logger.info("MCP session manager stopped")
def create_mcp_app() -> Starlette:
"""Create and configure the MCP Starlette application.
Returns a Starlette app with MCP endpoints, authentication middleware,
and CORS support.
Note: The MCP session manager must be started via mcp_lifespan() in
the main FastAPI lifespan, not here.
"""
# Get the StreamableHTTP app - it handles requests at "/" since we mount at /mcp
streamable_app = mcp.streamable_http_app()
# Create Starlette app with MCP routes
app = Starlette(
routes=[Mount("/", app=streamable_app)],
)
# Add authentication middleware
app.add_middleware(MCPAuthMiddleware)
# Add CORS middleware for browser clients
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False,
allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
allow_headers=["*"],
expose_headers=["Mcp-Session-Id"],
)
return app

View file

@ -5,24 +5,48 @@ from typing import Dict, List, Optional, Any
from datetime import datetime
from mem0 import Memory
from openai import OpenAI
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
before_sleep_log,
)
import structlog
from config import settings
from monitoring import timed
logger = logging.getLogger(__name__)
logger = structlog.get_logger(__name__)
# Retry decorator for database operations (Qdrant, Neo4j)
db_retry = retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((ConnectionError, TimeoutError, OSError)),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
# Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility
from mem0.llms.openai import OpenAILLM
_original_generate_response = OpenAILLM.generate_response
def patched_generate_response(self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs):
def patched_generate_response(
self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs
):
# Remove 'store' parameter as LiteLLM doesn't support it
if hasattr(self.config, 'store'):
if hasattr(self.config, "store"):
self.config.store = None
# Remove 'top_p' to avoid conflict with temperature for Claude models
if hasattr(self.config, 'top_p'):
if hasattr(self.config, "top_p"):
self.config.top_p = None
return _original_generate_response(self, messages, response_format, tools, tool_choice, **kwargs)
return _original_generate_response(
self, messages, response_format, tools, tool_choice, **kwargs
)
OpenAILLM.generate_response = patched_generate_response
logger.info("Applied LiteLLM compatibility patch: disabled 'store' parameter")
@ -36,8 +60,16 @@ class Mem0Manager:
def __init__(self):
# Custom endpoint configuration with graph memory enabled
logger.info("Initializing ultra-minimal Mem0Manager with custom endpoint with settings:", settings)
logger.info(
"Initializing Mem0Manager with custom endpoint",
model=settings.default_model,
embedding_model=settings.embedding_model,
embedding_dims=settings.embedding_dims,
qdrant_host=settings.qdrant_host,
neo4j_uri=settings.neo4j_uri,
)
config = {
"version": "v1.1",
"enable_graph": True,
"llm": {
"provider": "openai",
@ -46,17 +78,16 @@ class Mem0Manager:
"api_key": settings.openai_api_key,
"openai_base_url": settings.openai_base_url,
"temperature": 0.1,
"top_p": None # Don't use top_p with Claude models
}
"top_p": None,
},
},
"embedder": {
"provider": "ollama",
"config": {
"model": "qwen3-embedding:4b-q8_0",
# "api_key": settings.embedder_api_key,
"ollama_base_url": "http://172.17.0.1:11434",
"embedding_dims": 2560
}
"model": settings.embedding_model,
"ollama_base_url": settings.ollama_base_url,
"embedding_dims": settings.embedding_dims,
},
},
"vector_store": {
"provider": "qdrant",
@ -64,38 +95,39 @@ class Mem0Manager:
"collection_name": settings.qdrant_collection_name,
"host": settings.qdrant_host,
"port": settings.qdrant_port,
"embedding_model_dims": 2560,
"on_disk": True
}
"embedding_model_dims": settings.embedding_dims,
"on_disk": True,
},
},
"graph_store": {
"provider": "neo4j",
"config": {
"url": settings.neo4j_uri,
"username": settings.neo4j_username,
"password": settings.neo4j_password
}
"password": settings.neo4j_password,
},
},
"reranker": {
"provider": "cohere",
"config": {
"api_key": settings.cohere_api_key,
"model": "rerank-english-v3.0",
"top_n": 10
}
}
"top_n": 10,
},
},
}
self.memory = Memory.from_config(config)
self.openai_client = OpenAI(
api_key=settings.openai_api_key,
base_url=settings.openai_base_url
base_url=settings.openai_base_url,
timeout=60.0, # 60 second timeout for LLM calls
max_retries=2, # Retry failed requests up to 2 times
)
logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")
# Pure passthrough methods - no custom logic
@db_retry
@timed("add_memories")
async def add_memories(
self,
@ -103,14 +135,14 @@ class Mem0Manager:
user_id: Optional[str] = "default",
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Add memories - simplified native Mem0 pattern (10 lines vs 45)."""
try:
# Convert ChatMessage objects to dict if needed
formatted_messages = []
for msg in messages:
if hasattr(msg, 'dict'):
if hasattr(msg, "dict"):
formatted_messages.append(msg.dict())
else:
formatted_messages.append(msg)
@ -123,26 +155,35 @@ class Mem0Manager:
"timestamp": datetime.now().isoformat(),
"source": "chat_conversation",
"message_count": len(formatted_messages),
"auto_generated": True
"auto_generated": True,
}
# Merge user metadata with auto metadata (user metadata takes precedence)
enhanced_metadata = {**auto_metadata, **combined_metadata}
# Direct Mem0 add with enhanced metadata
result = self.memory.add(formatted_messages, user_id=user_id,
agent_id=agent_id, run_id=run_id,
metadata=enhanced_metadata)
result = self.memory.add(
formatted_messages,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
metadata=enhanced_metadata,
)
return {
"added_memories": result if isinstance(result, list) else [result],
"message": "Memories added successfully",
"hierarchy": {"user_id": user_id, "agent_id": agent_id, "run_id": run_id}
"hierarchy": {
"user_id": user_id,
"agent_id": agent_id,
"run_id": run_id,
},
}
except Exception as e:
logger.error(f"Error adding memories: {e}")
raise e
raise
@db_retry
@timed("search_memories")
async def search_memories(
self,
@ -155,37 +196,79 @@ class Mem0Manager:
# rerank: bool = False,
# filter_memories: bool = False,
agent_id: Optional[str] = None,
run_id: Optional[str] = None
run_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Search memories - native Mem0 pattern"""
try:
# Minimal empty query protection for API compatibility
if not query or query.strip() == "":
return {"memories": [], "total_count": 0, "query": query, "note": "Empty query provided, no results returned. Use a specific query to search memories."}
return {
"memories": [],
"total_count": 0,
"query": query,
"note": "Empty query provided, no results returned. Use a specific query to search memories.",
}
# Direct Mem0 search - trust native handling
result = self.memory.search(query=query, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit, threshold=threshold, filters=filters)
return {"memories": result.get("results", []), "total_count": len(result.get("results", [])), "query": query}
result = self.memory.search(
query=query,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=limit,
threshold=threshold,
filters=filters,
)
return {
"memories": result.get("results", []),
"total_count": len(result.get("results", [])),
"query": query,
}
except Exception as e:
logger.error(f"Error searching memories: {e}")
raise e
raise
@db_retry
async def get_user_memories(
self,
user_id: str,
limit: int = 10,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None
filters: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""Get all memories for a user - native Mem0 pattern."""
try:
# Direct Mem0 get_all call - trust native parameter handling
result = self.memory.get_all(user_id=user_id, limit=limit, agent_id=agent_id, run_id=run_id, filters=filters)
result = self.memory.get_all(
user_id=user_id,
limit=limit,
agent_id=agent_id,
run_id=run_id,
filters=filters,
)
return result.get("results", [])
except Exception as e:
logger.error(f"Error getting user memories: {e}")
raise e
raise
@db_retry
async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
"""Get a single memory by ID. Returns None if not found."""
try:
result = self.memory.get(memory_id=memory_id)
return result
except Exception as e:
logger.debug(f"Memory {memory_id} not found or error: {e}")
return None
async def verify_memory_ownership(self, memory_id: str, user_id: str) -> bool:
"""Check if a memory belongs to a user. O(1) instead of O(n)."""
memory = await self.get_memory(memory_id)
if memory is None:
return False
return memory.get("user_id") == user_id
@db_retry
@timed("update_memory")
async def update_memory(
self,
@ -194,15 +277,13 @@ class Mem0Manager:
) -> Dict[str, Any]:
"""Update memory - pure Mem0 passthrough."""
try:
result = self.memory.update(
memory_id=memory_id,
data=content
)
result = self.memory.update(memory_id=memory_id, data=content)
return {"message": "Memory updated successfully", "result": result}
except Exception as e:
logger.error(f"Error updating memory: {e}")
raise e
raise
@db_retry
@timed("delete_memory")
async def delete_memory(self, memory_id: str) -> Dict[str, Any]:
"""Delete memory - pure Mem0 passthrough."""
@ -211,7 +292,7 @@ class Mem0Manager:
return {"message": "Memory deleted successfully"}
except Exception as e:
logger.error(f"Error deleting memory: {e}")
raise e
raise
async def delete_user_memories(self, user_id: Optional[str]) -> Dict[str, Any]:
"""Delete all user memories - pure Mem0 passthrough."""
@ -220,7 +301,7 @@ class Mem0Manager:
return {"message": "All user memories deleted successfully"}
except Exception as e:
logger.error(f"Error deleting user memories: {e}")
raise e
raise
async def get_memory_history(self, memory_id: str) -> Dict[str, Any]:
"""Get memory change history - pure Mem0 passthrough."""
@ -229,22 +310,24 @@ class Mem0Manager:
return {
"memory_id": memory_id,
"history": history,
"message": "Memory history retrieved successfully"
"message": "Memory history retrieved successfully",
}
except Exception as e:
logger.error(f"Error getting memory history: {e}")
raise e
raise
async def get_graph_relationships(self, user_id: Optional[str], agent_id: Optional[str], run_id: Optional[str], limit: int = 50) -> Dict[str, Any]:
async def get_graph_relationships(
self,
user_id: Optional[str],
agent_id: Optional[str],
run_id: Optional[str],
limit: int = 50,
) -> Dict[str, Any]:
"""Get graph relationships - using correct Mem0 get_all() method."""
try:
# Use get_all() to retrieve memories with graph relationships
result = self.memory.get_all(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=limit
user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit
)
# Extract relationships from Mem0's response structure
@ -272,7 +355,7 @@ class Mem0Manager:
"agent_id": agent_id,
"run_id": run_id,
"total_memories": len(result.get("results", [])),
"total_relationships": len(relationships)
"total_relationships": len(relationships),
}
except Exception as e:
@ -286,7 +369,7 @@ class Mem0Manager:
"run_id": run_id,
"total_memories": 0,
"total_relationships": 0,
"error": str(e)
"error": str(e),
}
@timed("chat_with_memory")
@ -304,49 +387,70 @@ class Mem0Manager:
try:
total_start_time = time.time()
print(f"\n🚀 Starting chat request for user: {user_id}")
logger.info("Starting chat request", user_id=user_id)
# Stage 1: Memory Search
search_start_time = time.time()
search_result = self.memory.search(query=message, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=10, threshold=0.3)
search_result = self.memory.search(
query=message,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=10,
threshold=0.3,
)
relevant_memories = search_result.get("results", [])
memories_str = "\n".join(f"- {entry['memory']}" for entry in relevant_memories)
memories_str = "\n".join(
f"- {entry['memory']}" for entry in relevant_memories
)
search_time = time.time() - search_start_time
print(f"🔍 Memory search took: {search_time:.2f}s (found {len(relevant_memories)} memories)")
logger.debug(
"Memory search completed",
search_time_s=round(search_time, 2),
memories_found=len(relevant_memories),
)
# Stage 2: Prepare LLM messages
prep_start_time = time.time()
system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}"
messages = [{"role": "system", "content": system_prompt}]
# Add conversation context if provided (last 50 messages)
if context:
messages.extend(context)
print(f"📝 Added {len(context)} context messages")
logger.debug("Added context messages", context_count=len(context))
# Add current user message
messages.append({"role": "user", "content": message})
prep_time = time.time() - prep_start_time
print(f"📋 Message preparation took: {prep_time:.3f}s")
# Stage 3: LLM Call
llm_start_time = time.time()
response = self.openai_client.chat.completions.create(model=settings.default_model, messages=messages)
response = self.openai_client.chat.completions.create(
model=settings.default_model, messages=messages
)
assistant_response = response.choices[0].message.content
llm_time = time.time() - llm_start_time
print(f"🤖 LLM call took: {llm_time:.2f}s (model: {settings.default_model})")
logger.debug(
"LLM call completed",
llm_time_s=round(llm_time, 2),
model=settings.default_model,
)
# Stage 4: Memory Add
add_start_time = time.time()
memory_messages = [{"role": "user", "content": message}, {"role": "assistant", "content": assistant_response}]
memory_messages = [
{"role": "user", "content": message},
{"role": "assistant", "content": assistant_response},
]
self.memory.add(memory_messages, user_id=user_id)
add_time = time.time() - add_start_time
print(f"💾 Memory add took: {add_time:.2f}s")
# Total timing summary
total_time = time.time() - total_start_time
print(f"⏱️ TOTAL: {total_time:.2f}s | Search: {search_time:.2f}s | LLM: {llm_time:.2f}s | Add: {add_time:.2f}s | Prep: {prep_time:.3f}s")
print(f"📊 Breakdown: Search {(search_time/total_time)*100:.1f}% | LLM {(llm_time/total_time)*100:.1f}% | Add {(add_time/total_time)*100:.1f}%\n")
logger.info(
"Chat request completed",
user_id=user_id,
total_time_s=round(total_time, 2),
search_time_s=round(search_time, 2),
llm_time_s=round(llm_time, 2),
add_time_s=round(add_time, 2),
memories_used=len(relevant_memories),
model=settings.default_model,
)
return {
"response": assistant_response,
@ -356,17 +460,22 @@ class Mem0Manager:
"total": round(total_time, 2),
"search": round(search_time, 2),
"llm": round(llm_time, 2),
"add": round(add_time, 2)
}
"add": round(add_time, 2),
},
}
except Exception as e:
logger.error(f"Error in chat_with_memory: {e}")
logger.error(
"Error in chat_with_memory",
error=str(e),
user_id=user_id,
exc_info=True,
)
return {
"error": str(e),
"response": "I apologize, but I encountered an error processing your request.",
"memories_used": 0,
"model_used": None
"model_used": None,
}
async def health_check(self) -> Dict[str, str]:

View file

@ -2,53 +2,115 @@
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
import re
# Constants for input validation
MAX_MESSAGE_LENGTH = 50000 # ~12k tokens max per message
MAX_QUERY_LENGTH = 10000 # ~2.5k tokens max per query
MAX_USER_ID_LENGTH = 100 # Reasonable user ID length
MAX_MEMORY_ID_LENGTH = 100 # Memory IDs are typically UUIDs
MAX_CONTEXT_MESSAGES = 100 # Max conversation context messages
USER_ID_PATTERN = r"^[a-zA-Z0-9_\-\.@]+$" # Alphanumeric with common separators
# Request Models
class ChatMessage(BaseModel):
"""Chat message structure."""
role: str = Field(..., description="Message role (user, assistant, system)")
content: str = Field(..., description="Message content")
role: str = Field(
..., max_length=20, description="Message role (user, assistant, system)"
)
content: str = Field(
..., max_length=MAX_MESSAGE_LENGTH, description="Message content"
)
class ChatRequest(BaseModel):
"""Ultra-minimal chat request."""
message: str = Field(..., description="User message")
user_id: Optional[str] = Field("default", description="User identifier")
agent_id: Optional[str] = Field(None, description="Agent identifier")
run_id: Optional[str] = Field(None, description="Run identifier")
context: Optional[List[ChatMessage]] = Field(None, description="Previous conversation context")
message: str = Field(..., max_length=MAX_MESSAGE_LENGTH, description="User message")
user_id: Optional[str] = Field(
"default",
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier (alphanumeric, _, -, ., @)",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
context: Optional[List[ChatMessage]] = Field(
None,
max_length=MAX_CONTEXT_MESSAGES,
description="Previous conversation context (max 100 messages)",
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
class MemoryAddRequest(BaseModel):
"""Request to add memories with hierarchy support - open-source compatible."""
messages: List[ChatMessage] = Field(..., description="Messages to process")
user_id: Optional[str] = Field("default", description="User identifier")
agent_id: Optional[str] = Field(None, description="Agent identifier")
run_id: Optional[str] = Field(None, description="Run identifier")
messages: List[ChatMessage] = Field(
...,
max_length=MAX_CONTEXT_MESSAGES,
description="Messages to process (max 100 messages)",
)
user_id: Optional[str] = Field(
"default",
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
class MemorySearchRequest(BaseModel):
"""Request to search memories with hierarchy filtering."""
query: str = Field(..., description="Search query")
user_id: Optional[str] = Field("default", description="User identifier")
agent_id: Optional[str] = Field(None, description="Agent identifier")
run_id: Optional[str] = Field(None, description="Run identifier")
limit: int = Field(5, description="Maximum number of results")
threshold: Optional[float] = Field(None, description="Minimum relevance score")
filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters")
# Hierarchy filters (open-source compatible)
agent_id: Optional[str] = Field(None, description="Filter by agent identifier")
run_id: Optional[str] = Field(None, description="Filter by run identifier")
query: str = Field(..., max_length=MAX_QUERY_LENGTH, description="Search query")
user_id: Optional[str] = Field(
"default",
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
limit: int = Field(5, ge=1, le=100, description="Maximum number of results (1-100)")
threshold: Optional[float] = Field(
None, ge=0.0, le=1.0, description="Minimum relevance score (0-1)"
)
filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters")
class MemoryUpdateRequest(BaseModel):
"""Request to update a memory."""
memory_id: str = Field(..., description="Memory ID to update")
content: str = Field(..., description="New memory content")
memory_id: str = Field(
..., max_length=MAX_MEMORY_ID_LENGTH, description="Memory ID to update"
)
user_id: str = Field(
...,
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier for ownership verification",
)
content: str = Field(
..., max_length=MAX_MESSAGE_LENGTH, description="New memory content"
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Updated metadata")
@ -57,19 +119,23 @@ class MemoryUpdateRequest(BaseModel):
class MemoryItem(BaseModel):
"""Individual memory item."""
id: str = Field(..., description="Memory unique identifier")
memory: str = Field(..., description="Memory content")
user_id: Optional[str] = Field(None, description="Associated user ID")
agent_id: Optional[str] = Field(None, description="Associated agent ID")
run_id: Optional[str] = Field(None, description="Associated run ID")
metadata: Optional[Dict[str, Any]] = Field(None, description="Memory metadata")
score: Optional[float] = Field(None, description="Relevance score (for search results)")
score: Optional[float] = Field(
None, description="Relevance score (for search results)"
)
created_at: Optional[str] = Field(None, description="Creation timestamp")
updated_at: Optional[str] = Field(None, description="Last update timestamp")
class MemorySearchResponse(BaseModel):
"""Memory search results - pure Mem0 structure."""
memories: List[MemoryItem] = Field(..., description="Found memories")
total_count: int = Field(..., description="Total number of memories found")
query: str = Field(..., description="Original search query")
@ -77,27 +143,37 @@ class MemorySearchResponse(BaseModel):
class MemoryAddResponse(BaseModel):
"""Response from adding memories - pure Mem0 structure."""
added_memories: List[Dict[str, Any]] = Field(..., description="Memories that were added")
added_memories: List[Dict[str, Any]] = Field(
..., description="Memories that were added"
)
message: str = Field(..., description="Success message")
class GraphRelationship(BaseModel):
"""Graph relationship structure."""
source: str = Field(..., description="Source entity")
relationship: str = Field(..., description="Relationship type")
target: str = Field(..., description="Target entity")
properties: Optional[Dict[str, Any]] = Field(None, description="Relationship properties")
properties: Optional[Dict[str, Any]] = Field(
None, description="Relationship properties"
)
class GraphResponse(BaseModel):
"""Graph relationships - pure Mem0 structure."""
relationships: List[GraphRelationship] = Field(..., description="Found relationships")
relationships: List[GraphRelationship] = Field(
..., description="Found relationships"
)
entities: List[str] = Field(..., description="Unique entities")
user_id: str = Field(..., description="User identifier")
class HealthResponse(BaseModel):
"""Health check response."""
status: str = Field(..., description="Service status")
services: Dict[str, str] = Field(..., description="Individual service statuses")
timestamp: str = Field(..., description="Health check timestamp")
@ -105,6 +181,7 @@ class HealthResponse(BaseModel):
class ErrorResponse(BaseModel):
"""Error response structure."""
error: str = Field(..., description="Error message")
detail: Optional[str] = Field(None, description="Detailed error information")
status_code: int = Field(..., description="HTTP status code")
@ -112,8 +189,10 @@ class ErrorResponse(BaseModel):
# Statistics and Monitoring Models
class MemoryOperationStats(BaseModel):
"""Memory operation statistics."""
add: int = Field(..., description="Number of add operations")
search: int = Field(..., description="Number of search operations")
update: int = Field(..., description="Number of update operations")
@ -122,19 +201,111 @@ class MemoryOperationStats(BaseModel):
class GlobalStatsResponse(BaseModel):
"""Global application statistics."""
total_memories: int = Field(..., description="Total memories across all users")
total_users: int = Field(..., description="Total number of users")
api_calls_today: int = Field(..., description="Total API calls today")
avg_response_time_ms: float = Field(..., description="Average response time in milliseconds")
memory_operations: MemoryOperationStats = Field(..., description="Memory operation breakdown")
avg_response_time_ms: float = Field(
..., description="Average response time in milliseconds"
)
memory_operations: MemoryOperationStats = Field(
..., description="Memory operation breakdown"
)
uptime_seconds: float = Field(..., description="Application uptime in seconds")
class UserStatsResponse(BaseModel):
"""User-specific statistics."""
user_id: str = Field(..., description="User identifier")
memory_count: int = Field(..., description="Number of memories for this user")
relationship_count: int = Field(..., description="Number of graph relationships for this user")
relationship_count: int = Field(
..., description="Number of graph relationships for this user"
)
last_activity: Optional[str] = Field(None, description="Last activity timestamp")
api_calls_today: int = Field(..., description="API calls made by this user today")
avg_response_time_ms: float = Field(..., description="Average response time for this user's requests")
avg_response_time_ms: float = Field(
..., description="Average response time for this user's requests"
)
# OpenAI-Compatible API Models
class OpenAIMessage(BaseModel):
"""OpenAI message format."""
role: str = Field(..., description="Message role (system, user, assistant)")
content: str = Field(..., description="Message content")
class OpenAIChatCompletionRequest(BaseModel):
"""OpenAI chat completion request format."""
model: str = Field(..., description="Model to use (will use configured default)")
messages: List[Dict[str, str]] = Field(..., description="List of messages")
temperature: Optional[float] = Field(0.7, description="Sampling temperature")
max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
stream: Optional[bool] = Field(False, description="Whether to stream responses")
top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
n: Optional[int] = Field(1, description="Number of completions to generate")
stop: Optional[List[str]] = Field(None, description="Stop sequences")
presence_penalty: Optional[float] = Field(0, description="Presence penalty")
frequency_penalty: Optional[float] = Field(0, description="Frequency penalty")
user: Optional[str] = Field(
None, description="User identifier (ignored, uses API key)"
)
class OpenAIUsage(BaseModel):
"""Token usage information."""
prompt_tokens: int = Field(..., description="Tokens in the prompt")
completion_tokens: int = Field(..., description="Tokens in the completion")
total_tokens: int = Field(..., description="Total tokens used")
class OpenAIChoiceMessage(BaseModel):
"""Message in a choice."""
role: str = Field(..., description="Role of the message")
content: str = Field(..., description="Content of the message")
class OpenAIChoice(BaseModel):
"""Individual completion choice."""
index: int = Field(..., description="Choice index")
message: OpenAIChoiceMessage = Field(..., description="Message content")
finish_reason: str = Field(..., description="Reason for completion finish")
class OpenAIChatCompletionResponse(BaseModel):
"""OpenAI chat completion response format."""
id: str = Field(..., description="Unique completion ID")
object: str = Field(default="chat.completion", description="Object type")
created: int = Field(..., description="Unix timestamp of creation")
model: str = Field(..., description="Model used for completion")
choices: List[OpenAIChoice] = Field(..., description="List of completion choices")
usage: Optional[OpenAIUsage] = Field(None, description="Token usage information")
# Streaming-specific models
class OpenAIStreamDelta(BaseModel):
"""Delta content in a streaming chunk."""
role: Optional[str] = Field(None, description="Role (only in first chunk)")
content: Optional[str] = Field(None, description="Incremental content")
class OpenAIStreamChoice(BaseModel):
"""Individual streaming choice."""
index: int = Field(..., description="Choice index")
delta: OpenAIStreamDelta = Field(..., description="Delta content")
finish_reason: Optional[str] = Field(
None, description="Reason for completion finish"
)

View file

@ -19,6 +19,7 @@ ollama
# Utilities
pydantic
pydantic-settings
tenacity
python-dotenv
httpx
aiofiles
@ -31,3 +32,9 @@ python-json-logger
# CORS and Security
python-jose[cryptography]
passlib[bcrypt]
# Rate Limiting
slowapi
# MCP Server
mcp[server]>=1.0.0

View file

@ -5,6 +5,8 @@ services:
container_name: mem0-qdrant
expose:
- "6333"
networks:
- mem0_network
volumes:
- qdrant_data:/qdrant/storage
command: >
@ -32,6 +34,8 @@ services:
expose:
- "7474" # HTTP - Internal only
- "7687" # Bolt - Internal only
networks:
- mem0_network
volumes:
- neo4j_data:/data
- neo4j_logs:/logs
@ -65,8 +69,14 @@ services:
CORS_ORIGINS: ${CORS_ORIGINS:-http://localhost:3000}
DEFAULT_MODEL: ${DEFAULT_MODEL:-claude-sonnet-4}
API_KEYS: ${API_KEYS:-{}}
ports:
- "${BACKEND_PORT:-8000}:8000"
OLLAMA_BASE_URL: ${OLLAMA_BASE_URL:-http://host.docker.internal:11434}
EMBEDDING_MODEL: ${EMBEDDING_MODEL:-nomic-embed-text}
EMBEDDING_DIMS: ${EMBEDDING_DIMS:-2560}
expose:
- "8000"
networks:
- npm_network
- mem0_network
depends_on:
qdrant:
condition: service_healthy
@ -75,7 +85,8 @@ services:
restart: unless-stopped
volumes:
- ./backend:/app
command: ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
- ./frontend:/app/frontend
command: ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
volumes:
qdrant_data:
@ -85,5 +96,6 @@ volumes:
neo4j_plugins:
networks:
default:
name: mem0-network
mem0_network:
npm_network:
external: true

View file

@ -18,12 +18,106 @@
display: flex;
}
/* Login Screen */
.login-screen {
display: flex;
align-items: center;
justify-content: center;
width: 100%;
height: 100vh;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
.login-screen.hidden {
display: none;
}
.login-box {
background: white;
padding: 40px;
border-radius: 12px;
box-shadow: 0 10px 40px rgba(0, 0, 0, 0.2);
width: 100%;
max-width: 400px;
}
.login-box h1 {
margin-bottom: 10px;
color: #333;
font-size: 28px;
text-align: center;
}
.login-box p {
color: #666;
font-size: 14px;
text-align: center;
margin-bottom: 30px;
}
.login-box input {
width: 100%;
padding: 14px;
border: 2px solid #e0e0e0;
border-radius: 8px;
font-size: 14px;
margin-bottom: 20px;
outline: none;
transition: border-color 0.3s;
}
.login-box input:focus {
border-color: #667eea;
}
.login-box button {
width: 100%;
padding: 14px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 8px;
font-size: 16px;
font-weight: 600;
cursor: pointer;
transition: transform 0.2s, opacity 0.3s;
}
.login-box button:hover {
transform: translateY(-2px);
}
.login-box button:disabled {
opacity: 0.6;
cursor: not-allowed;
transform: none;
}
.login-error {
background: #ffe6e6;
border: 1px solid #ffcccc;
color: #cc0000;
padding: 12px;
border-radius: 8px;
margin-bottom: 20px;
font-size: 14px;
display: none;
}
.login-error.show {
display: block;
}
.container {
display: flex;
width: 100%;
height: 100vh;
}
.container.hidden {
display: none;
}
/* Chat Section */
.chat-section {
flex: 1;
@ -58,7 +152,12 @@
font-size: 14px;
}
.clear-chat-btn {
.header-buttons {
display: flex;
gap: 10px;
}
.clear-chat-btn, .logout-btn {
background: #f8f9fa;
color: #666;
border: 1px solid #e0e0e0;
@ -73,16 +172,27 @@
transition: all 0.2s ease;
}
.clear-chat-btn:hover {
.clear-chat-btn:hover, .logout-btn:hover {
background: #e9ecef;
border-color: #ced4da;
color: #495057;
}
.clear-chat-btn:active {
.clear-chat-btn:active, .logout-btn:active {
background: #dee2e6;
}
.logout-btn {
background: #fff3cd;
border-color: #ffc107;
color: #856404;
}
.logout-btn:hover {
background: #ffe69c;
border-color: #ffb300;
}
.chat-messages {
flex: 1;
overflow-y: auto;
@ -281,17 +391,42 @@
</style>
</head>
<body>
<div class="container">
<!-- Login Screen -->
<div class="login-screen" id="loginScreen">
<div class="login-box">
<h1>🧠 Mem0 Chat</h1>
<p>Enter your API key to access your memory-powered assistant</p>
<div class="login-error" id="loginError"></div>
<input
type="password"
id="apiKeyInput"
placeholder="Enter your API key (e.g., sk-xxxxx)"
autocomplete="off"
/>
<button id="loginButton">Connect</button>
</div>
</div>
<!-- Main Chat Interface (hidden initially) -->
<div class="container hidden" id="mainContainer">
<!-- Chat Section -->
<div class="chat-section">
<div class="chat-header">
<div class="chat-header-content">
<h1>What can I help you with?</h1>
<p>Chat with your memories - User: pratik</p>
<p>Chat with your memories - User: <span id="currentUser">...</span></p>
</div>
<div class="header-buttons">
<button class="clear-chat-btn" id="clearChatBtn" title="Clear chat history">
🗑️ Clear Chat
</button>
<button class="logout-btn" id="logoutBtn" title="Logout">
🚪 Logout
</button>
</div>
</div>
<div class="chat-messages" id="chatMessages">
@ -319,10 +454,20 @@
<script>
// Configuration
const API_BASE = 'http://localhost:8000';
const USER_ID = 'pratik';
const API_BASE = window.location.origin;
// State
let API_KEY = null;
let USER_ID = null;
// DOM Elements
const loginScreen = document.getElementById('loginScreen');
const mainContainer = document.getElementById('mainContainer');
const apiKeyInput = document.getElementById('apiKeyInput');
const loginButton = document.getElementById('loginButton');
const loginError = document.getElementById('loginError');
const logoutBtn = document.getElementById('logoutBtn');
const currentUser = document.getElementById('currentUser');
const chatMessages = document.getElementById('chatMessages');
const messageInput = document.getElementById('messageInput');
const sendButton = document.getElementById('sendButton');
@ -336,19 +481,143 @@
// Initialize
document.addEventListener('DOMContentLoaded', function() {
loadChatHistory();
loadMemories();
// Check if already logged in
const savedApiKey = localStorage.getItem('apiKey');
const savedUserId = localStorage.getItem('userId');
if (savedApiKey && savedUserId) {
// Auto-login with saved credentials
API_KEY = savedApiKey;
USER_ID = savedUserId;
showMainInterface();
}
// Event listeners
loginButton.addEventListener('click', handleLogin);
apiKeyInput.addEventListener('keydown', (e) => {
if (e.key === 'Enter') handleLogin();
});
logoutBtn.addEventListener('click', handleLogout);
sendButton.addEventListener('click', sendMessage);
messageInput.addEventListener('keydown', handleKeyDown);
messageInput.addEventListener('input', autoResizeTextarea);
refreshButton.addEventListener('click', loadMemories);
clearChatBtn.addEventListener('click', clearChatWithConfirmation);
});
// Handle login
async function handleLogin() {
const apiKey = apiKeyInput.value.trim();
if (!apiKey) {
showLoginError('Please enter an API key');
return;
}
loginButton.disabled = true;
loginButton.textContent = 'Verifying...';
hideLoginError();
try {
// Verify API key by calling /health with auth
const response = await fetch(`${API_BASE}/health`, {
headers: {
'X-API-Key': apiKey
}
});
if (!response.ok) {
throw new Error('Invalid API key');
}
// Get user_id by trying to call a test endpoint
// We'll use /models since it doesn't require auth parameters
const userResponse = await fetch(`${API_BASE}/v1/chat/completions`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-API-Key': apiKey
},
body: JSON.stringify({
model: 'gpt-4',
messages: [{ role: 'user', content: 'test' }],
stream: false
})
});
if (!userResponse.ok) {
throw new Error('Failed to verify user');
}
// API key is valid, extract user_id from auth_service mapping
// We'll store both and use a simple username extraction
API_KEY = apiKey;
// Try to extract username from API key (e.g., sk-alice -> alice)
if (apiKey.startsWith('sk-')) {
const parts = apiKey.substring(3).split('-');
USER_ID = parts[0]; // Get first part after sk-
} else {
USER_ID = 'user';
}
// Save to localStorage
localStorage.setItem('apiKey', API_KEY);
localStorage.setItem('userId', USER_ID);
// Show main interface
showMainInterface();
} catch (error) {
console.error('Login error:', error);
showLoginError('Invalid API key. Please check and try again.');
loginButton.disabled = false;
loginButton.textContent = 'Connect';
}
}
// Show main interface
function showMainInterface() {
loginScreen.classList.add('hidden');
mainContainer.classList.remove('hidden');
currentUser.textContent = USER_ID;
// Load data
loadChatHistory();
loadMemories();
// Initialize textarea height
autoResizeTextarea();
});
messageInput.focus();
}
// Handle logout
function handleLogout() {
if (confirm('Are you sure you want to logout?')) {
// Clear credentials
localStorage.removeItem('apiKey');
localStorage.removeItem('userId');
API_KEY = null;
USER_ID = null;
// Show login screen
mainContainer.classList.add('hidden');
loginScreen.classList.remove('hidden');
apiKeyInput.value = '';
hideLoginError();
}
}
// Show login error
function showLoginError(message) {
loginError.textContent = message;
loginError.classList.add('show');
}
// Hide login error
function hideLoginError() {
loginError.classList.remove('show');
}
// Load chat history from localStorage
function loadChatHistory() {
@ -419,6 +688,7 @@
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-API-Key': API_KEY
},
body: JSON.stringify({
message: message,
@ -460,7 +730,11 @@
// Load memories from backend
async function loadMemories() {
try{
const response = await fetch(`${API_BASE}/memories/${USER_ID}?limit=50`);
const response = await fetch(`${API_BASE}/memories/${USER_ID}?limit=50`, {
headers: {
'X-API-Key': API_KEY
}
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
@ -512,8 +786,11 @@
}
try {
const response = await fetch(`${API_BASE}/memories/${memoryId}`, {
method: 'DELETE'
const response = await fetch(`${API_BASE}/memories/${memoryId}?user_id=${USER_ID}`, {
method: 'DELETE',
headers: {
'X-API-Key': API_KEY
}
});
if (!response.ok) {

122
setup.sh Executable file
View file

@ -0,0 +1,122 @@
#!/bin/bash
set -e
COMPOSE_PROJECT_NAME="${COMPOSE_PROJECT_NAME:-mem0}"
VOLUMES=("qdrant_data" "neo4j_data" "neo4j_logs" "neo4j_import" "neo4j_plugins")
echo "=========================================="
echo " Mem0 Setup Script"
echo "=========================================="
echo ""
check_volumes_exist() {
local exists=false
for vol in "${VOLUMES[@]}"; do
full_name="${COMPOSE_PROJECT_NAME}_${vol}"
if docker volume ls -q | grep -q "^${full_name}$"; then
exists=true
break
fi
done
echo "$exists"
}
list_existing_volumes() {
echo "Existing volumes:"
for vol in "${VOLUMES[@]}"; do
full_name="${COMPOSE_PROJECT_NAME}_${vol}"
if docker volume ls -q | grep -q "^${full_name}$"; then
size=$(docker system df -v 2>/dev/null | grep "$full_name" | awk '{print $4}' || echo "unknown")
echo " - $full_name ($size)"
fi
done
}
remove_volumes() {
echo "Stopping containers..."
docker compose down 2>/dev/null || true
echo "Removing volumes..."
for vol in "${VOLUMES[@]}"; do
full_name="${COMPOSE_PROJECT_NAME}_${vol}"
if docker volume ls -q | grep -q "^${full_name}$"; then
echo " Removing $full_name..."
docker volume rm "$full_name" 2>/dev/null || true
fi
done
echo "Volumes removed."
}
build_and_start() {
echo ""
echo "Building containers..."
docker compose build
echo ""
echo "Starting services..."
docker compose up -d
echo ""
echo "Waiting for services to be healthy..."
sleep 5
echo ""
echo "Checking health..."
curl -s http://localhost:8000/health 2>/dev/null | jq . || echo "Health check not available yet. Services may still be starting."
echo ""
echo "=========================================="
echo " Setup Complete!"
echo "=========================================="
echo ""
echo "Services:"
echo " - Backend API: http://localhost:8000"
echo " - API Docs: http://localhost:8000/docs"
echo " - Health Check: http://localhost:8000/health"
echo ""
echo "Logs: docker compose logs -f backend"
}
if [ "$(check_volumes_exist)" = "true" ]; then
echo "Existing data volumes detected!"
echo ""
list_existing_volumes
echo ""
echo "Options:"
echo " 1) Keep existing data and start services"
echo " 2) Reset everything (DELETE ALL DATA) and start fresh"
echo " 3) Exit"
echo ""
read -p "Choose an option [1/2/3]: " choice
case $choice in
1)
echo ""
echo "Keeping existing data..."
build_and_start
;;
2)
echo ""
read -p "Are you sure you want to DELETE ALL DATA? Type 'yes' to confirm: " confirm
if [ "$confirm" = "yes" ]; then
remove_volumes
build_and_start
else
echo "Aborted."
exit 1
fi
;;
3)
echo "Exiting."
exit 0
;;
*)
echo "Invalid option. Exiting."
exit 1
;;
esac
else
echo "No existing volumes found. Starting fresh setup..."
build_and_start
fi

View file

@ -19,13 +19,24 @@ import time
BASE_URL = "http://localhost:8000"
TEST_USER = f"test_user_{int(datetime.now().timestamp())}"
# API Key for authentication - set via environment or use default test key
import os
API_KEY = os.environ.get("MEM0_API_KEY", "test-api-key")
AUTH_HEADERS = {"X-API-Key": API_KEY}
def main():
parser = argparse.ArgumentParser(
description="Mem0 Integration Tests - Real API Testing (Zero Mocking)",
formatter_class=argparse.RawDescriptionHelpFormatter
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="Show detailed output and API responses",
)
parser.add_argument("--verbose", "-v", action="store_true",
help="Show detailed output and API responses")
args = parser.parse_args()
verbose = args.verbose
@ -39,6 +50,9 @@ def main():
# Test sequence - order matters for data dependencies
tests = [
test_health_check,
test_auth_required_endpoints,
test_ownership_verification,
test_request_size_limit,
test_empty_search_protection,
test_add_memories_with_hierarchy,
test_search_memories_basic,
@ -51,7 +65,7 @@ def main():
test_graph_relationships,
test_delete_specific_memory,
test_delete_all_user_memories,
test_cleanup_verification
test_cleanup_verification,
]
results = []
@ -82,6 +96,7 @@ def main():
print("❌ Some tests failed! Check the output above.")
sys.exit(1)
def run_test(name, test_func, verbose):
"""Run a single test with error handling"""
try:
@ -102,6 +117,7 @@ def run_test(name, test_func, verbose):
print(f"{name}: {e}")
return False
def log_response(response, verbose, context=""):
"""Log API response details if verbose"""
if verbose:
@ -111,22 +127,30 @@ def log_response(response, verbose, context=""):
if isinstance(data, dict) and len(data) < 5:
print(f" {context} Response: {data}")
else:
print(f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}")
print(
f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}"
)
except:
print(f" {context} Response: {response.text[:100]}...")
# ================== TEST FUNCTIONS ==================
def test_health_check(verbose):
"""Test service health endpoint"""
response = requests.get(f"{BASE_URL}/health", timeout=10)
response = requests.get(
f"{BASE_URL}/health", timeout=10
) # Health doesn't require auth
log_response(response, verbose, "Health")
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert "status" in data, "Health response missing 'status' field"
assert data["status"] in ["healthy", "degraded"], f"Invalid status: {data['status']}"
assert data["status"] in ["healthy", "degraded"], (
f"Invalid status: {data['status']}"
)
# Check individual services
assert "services" in data, "Health response missing 'services' field"
@ -136,18 +160,19 @@ def test_health_check(verbose):
for service, status in data["services"].items():
print(f" {service}: {status}")
def test_empty_search_protection(verbose):
"""Test empty query protection (should not return 500 error)"""
payload = {
"query": "",
"user_id": TEST_USER,
"limit": 5
}
payload = {"query": "", "user_id": TEST_USER, "limit": 5}
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=10)
response = requests.post(
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Empty Search")
assert response.status_code == 200, f"Empty query failed with {response.status_code}"
assert response.status_code == 200, (
f"Empty query failed with {response.status_code}"
)
data = response.json()
assert data["memories"] == [], "Empty query should return empty memories list"
@ -158,25 +183,39 @@ def test_empty_search_protection(verbose):
print(f" Empty search note: {data['note']}")
print(f" Total count: {data.get('total_count', 0)}")
def test_add_memories_with_hierarchy(verbose):
"""Test adding memories with multi-level hierarchy support"""
payload = {
"messages": [
{"role": "user", "content": "I work at TechCorp as a Senior Software Engineer"},
{"role": "user", "content": "My colleague Sarah from Marketing team helped with Q3 presentation"},
{"role": "user", "content": "Meeting with John the Product Manager tomorrow about new feature development"}
{
"role": "user",
"content": "I work at TechCorp as a Senior Software Engineer",
},
{
"role": "user",
"content": "My colleague Sarah from Marketing team helped with Q3 presentation",
},
{
"role": "user",
"content": "Meeting with John the Product Manager tomorrow about new feature development",
},
],
"user_id": TEST_USER,
"agent_id": "test_agent",
"run_id": "test_run_001",
"session_id": "test_session_001",
"metadata": {"test": "integration", "scenario": "work_context"}
"metadata": {"test": "integration", "scenario": "work_context"},
}
response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60)
response = requests.post(
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60
)
log_response(response, verbose, "Add Memories")
assert response.status_code == 200, f"Add memories failed with {response.status_code}"
assert response.status_code == 200, (
f"Add memories failed with {response.status_code}"
)
data = response.json()
assert "added_memories" in data, "Response missing 'added_memories'"
@ -191,23 +230,26 @@ def test_add_memories_with_hierarchy(verbose):
relations = first_memory["relations"]
if "added_entities" in relations and relations["added_entities"]:
if verbose:
print(f" Graph extracted: {len(relations['added_entities'])} relationships")
print(
f" Graph extracted: {len(relations['added_entities'])} relationships"
)
print(f" Sample relations: {relations['added_entities'][:3]}")
if verbose:
print(f" Added {len(memories)} memory blocks")
print(f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001")
print(
f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001"
)
def test_search_memories_basic(verbose):
"""Test basic memory search functionality"""
# Test meaningful search
payload = {
"query": "TechCorp",
"user_id": TEST_USER,
"limit": 10
}
payload = {"query": "TechCorp", "user_id": TEST_USER, "limit": 10}
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15)
response = requests.post(
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Search")
assert response.status_code == 200, f"Search failed with {response.status_code}"
@ -232,6 +274,7 @@ def test_search_memories_basic(verbose):
print(f" Found {data['total_count']} memories")
print(f" First memory: {memory['memory'][:50]}...")
def test_search_memories_hierarchy_filters(verbose):
"""Test multi-level hierarchy filtering in search"""
# Test with hierarchy filters
@ -241,13 +284,17 @@ def test_search_memories_hierarchy_filters(verbose):
"agent_id": "test_agent",
"run_id": "test_run_001",
"session_id": "test_session_001",
"limit": 10
"limit": 10,
}
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15)
response = requests.post(
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Hierarchy Search")
assert response.status_code == 200, f"Hierarchy search failed with {response.status_code}"
assert response.status_code == 200, (
f"Hierarchy search failed with {response.status_code}"
)
data = response.json()
assert "memories" in data, "Hierarchy search response missing 'memories'"
@ -257,7 +304,10 @@ def test_search_memories_hierarchy_filters(verbose):
if verbose:
print(f" Found {len(data['memories'])} memories with hierarchy filters")
print(f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001")
print(
f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001"
)
def test_get_user_memories_with_hierarchy(verbose):
"""Test retrieving user memories with hierarchy filtering"""
@ -266,13 +316,20 @@ def test_get_user_memories_with_hierarchy(verbose):
"limit": 20,
"agent_id": "test_agent",
"run_id": "test_run_001",
"session_id": "test_session_001"
"session_id": "test_session_001",
}
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}", params=params, timeout=15)
response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}",
params=params,
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Get User Memories with Hierarchy")
assert response.status_code == 200, f"Get user memories with hierarchy failed with {response.status_code}"
assert response.status_code == 200, (
f"Get user memories with hierarchy failed with {response.status_code}"
)
memories = response.json()
assert isinstance(memories, list), "User memories should return a list"
@ -290,10 +347,13 @@ def test_get_user_memories_with_hierarchy(verbose):
if verbose:
print(" No memories found with hierarchy filters (may be expected)")
def test_memory_history(verbose):
"""Test memory history endpoint"""
# First get a memory to check history for
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10)
response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for history test"
memories = response.json()
@ -305,27 +365,38 @@ def test_memory_history(verbose):
memory_id = memories[0]["id"]
# Test memory history endpoint
response = requests.get(f"{BASE_URL}/memories/{memory_id}/history", timeout=15)
response = requests.get(
f"{BASE_URL}/memories/{memory_id}/history?user_id={TEST_USER}",
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Memory History")
assert response.status_code == 200, f"Memory history failed with {response.status_code}"
assert response.status_code == 200, (
f"Memory history failed with {response.status_code}"
)
data = response.json()
assert "memory_id" in data, "History response missing 'memory_id'"
assert "history" in data, "History response missing 'history'"
assert "message" in data, "History response missing success message"
assert data["memory_id"] == memory_id, f"Wrong memory_id in response: {data['memory_id']}"
assert data["memory_id"] == memory_id, (
f"Wrong memory_id in response: {data['memory_id']}"
)
if verbose:
print(f" Retrieved history for memory {memory_id}")
print(f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}")
print(
f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}"
)
def test_update_memory(verbose):
"""Test updating a specific memory"""
# First get a memory to update
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10)
response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for update test"
memories = response.json()
@ -337,10 +408,13 @@ def test_update_memory(verbose):
# Update the memory
payload = {
"memory_id": memory_id,
"content": f"UPDATED: {original_content}"
"user_id": TEST_USER,
"content": f"UPDATED: {original_content}",
}
response = requests.put(f"{BASE_URL}/memories", json=payload, timeout=10)
response = requests.put(
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Update")
assert response.status_code == 200, f"Update failed with {response.status_code}"
@ -352,15 +426,15 @@ def test_update_memory(verbose):
print(f" Updated memory {memory_id}")
print(f" Original: {original_content[:30]}...")
def test_chat_with_memory(verbose):
"""Test memory-enhanced chat functionality"""
payload = {
"message": "What company do I work for?",
"user_id": TEST_USER
}
payload = {"message": "What company do I work for?", "user_id": TEST_USER}
try:
response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=90)
response = requests.post(
f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=90
)
log_response(response, verbose, "Chat")
assert response.status_code == 200, f"Chat failed with {response.status_code}"
@ -383,12 +457,15 @@ def test_chat_with_memory(verbose):
print(" Chat endpoint timed out (LLM API may be slow)")
# Still test that the endpoint exists and accepts requests
try:
response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=5)
response = requests.post(
f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=5
)
except requests.exceptions.ReadTimeout:
# This is expected - endpoint exists but processing is slow
if verbose:
print(" Chat endpoint confirmed active (processing timeout expected)")
def test_graph_relationships_creation(verbose):
"""Test graph relationships creation with entity-rich memories"""
# Create a separate test user for graph relationship testing
@ -397,41 +474,67 @@ def test_graph_relationships_creation(verbose):
# Add memories with clear entity relationships
payload = {
"messages": [
{"role": "user", "content": "John Smith works at Microsoft as a Senior Software Engineer"},
{"role": "user", "content": "John Smith is friends with Sarah Johnson who works at Google"},
{"role": "user", "content": "Sarah Johnson lives in Seattle and loves hiking"},
{
"role": "user",
"content": "John Smith works at Microsoft as a Senior Software Engineer",
},
{
"role": "user",
"content": "John Smith is friends with Sarah Johnson who works at Google",
},
{
"role": "user",
"content": "Sarah Johnson lives in Seattle and loves hiking",
},
{"role": "user", "content": "Microsoft is located in Redmond, Washington"},
{"role": "user", "content": "John Smith and Sarah Johnson both graduated from Stanford University"}
{
"role": "user",
"content": "John Smith and Sarah Johnson both graduated from Stanford University",
},
],
"user_id": graph_test_user,
"metadata": {"test": "graph_relationships", "scenario": "entity_creation"}
"metadata": {"test": "graph_relationships", "scenario": "entity_creation"},
}
response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60)
response = requests.post(
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60
)
log_response(response, verbose, "Add Graph Memories")
assert response.status_code == 200, f"Add graph memories failed with {response.status_code}"
assert response.status_code == 200, (
f"Add graph memories failed with {response.status_code}"
)
data = response.json()
assert "added_memories" in data, "Response missing 'added_memories'"
if verbose:
print(f" Added {len(data['added_memories'])} memories for graph relationship testing")
print(
f" Added {len(data['added_memories'])} memories for graph relationship testing"
)
# Wait a moment for graph processing (Mem0 graph extraction can be async)
time.sleep(2)
# Test graph relationships endpoint
response = requests.get(f"{BASE_URL}/graph/relationships/{graph_test_user}", timeout=15)
response = requests.get(
f"{BASE_URL}/graph/relationships/{graph_test_user}",
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Graph Relationships")
assert response.status_code == 200, f"Graph relationships failed with {response.status_code}"
assert response.status_code == 200, (
f"Graph relationships failed with {response.status_code}"
)
graph_data = response.json()
assert "relationships" in graph_data, "Graph response missing 'relationships'"
assert "entities" in graph_data, "Graph response missing 'entities'"
assert "user_id" in graph_data, "Graph response missing 'user_id'"
assert graph_data["user_id"] == graph_test_user, f"Wrong user_id in graph: {graph_data['user_id']}"
assert graph_data["user_id"] == graph_test_user, (
f"Wrong user_id in graph: {graph_data['user_id']}"
)
relationships = graph_data["relationships"]
entities = graph_data["entities"]
@ -451,16 +554,24 @@ def test_graph_relationships_creation(verbose):
# Print sample entities if they exist
if entities:
print(f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}")
print(
f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}"
)
# Verify relationship structure (if relationships exist)
for rel in relationships:
assert "source" in rel or "from" in rel, f"Relationship missing source/from: {rel}"
assert "source" in rel or "from" in rel, (
f"Relationship missing source/from: {rel}"
)
assert "target" in rel or "to" in rel, f"Relationship missing target/to: {rel}"
assert "relationship" in rel or "type" in rel, f"Relationship missing type: {rel}"
assert "relationship" in rel or "type" in rel, (
f"Relationship missing type: {rel}"
)
# Clean up graph test user memories
cleanup_response = requests.delete(f"{BASE_URL}/memories/user/{graph_test_user}", timeout=15)
cleanup_response = requests.delete(
f"{BASE_URL}/memories/user/{graph_test_user}", headers=AUTH_HEADERS, timeout=15
)
assert cleanup_response.status_code == 200, "Failed to cleanup graph test memories"
if verbose:
@ -469,12 +580,17 @@ def test_graph_relationships_creation(verbose):
# Note: We expect some relationships even if graph extraction is basic
# The test passes if the endpoint works and returns proper structure
def test_graph_relationships(verbose):
"""Test graph relationships endpoint"""
response = requests.get(f"{BASE_URL}/graph/relationships/{TEST_USER}", timeout=15)
response = requests.get(
f"{BASE_URL}/graph/relationships/{TEST_USER}", headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Graph")
assert response.status_code == 200, f"Graph endpoint failed with {response.status_code}"
assert response.status_code == 200, (
f"Graph endpoint failed with {response.status_code}"
)
data = response.json()
assert "relationships" in data, "Graph response missing 'relationships'"
@ -486,10 +602,13 @@ def test_graph_relationships(verbose):
print(f" Relationships: {len(data['relationships'])}")
print(f" Entities: {len(data['entities'])}")
def test_delete_specific_memory(verbose):
"""Test deleting a specific memory"""
# Get a memory to delete
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10)
response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for deletion test"
memories = response.json()
@ -498,7 +617,9 @@ def test_delete_specific_memory(verbose):
memory_id = memories[0]["id"]
# Delete the memory
response = requests.delete(f"{BASE_URL}/memories/{memory_id}", timeout=10)
response = requests.delete(
f"{BASE_URL}/memories/{memory_id}", headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Delete")
assert response.status_code == 200, f"Delete failed with {response.status_code}"
@ -509,9 +630,12 @@ def test_delete_specific_memory(verbose):
if verbose:
print(f" Deleted memory {memory_id}")
def test_delete_all_user_memories(verbose):
"""Test deleting all memories for a user"""
response = requests.delete(f"{BASE_URL}/memories/user/{TEST_USER}", timeout=15)
response = requests.delete(
f"{BASE_URL}/memories/user/{TEST_USER}", headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Delete All")
assert response.status_code == 200, f"Delete all failed with {response.status_code}"
@ -522,12 +646,17 @@ def test_delete_all_user_memories(verbose):
if verbose:
print(f"Deleted all memories for {TEST_USER}")
def test_cleanup_verification(verbose):
"""Verify cleanup was successful"""
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=10", timeout=10)
response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=10", headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Cleanup Check")
assert response.status_code == 200, f"Cleanup verification failed with {response.status_code}"
assert response.status_code == 200, (
f"Cleanup verification failed with {response.status_code}"
)
memories = response.json()
assert isinstance(memories, list), "Should return list even if empty"
@ -539,5 +668,79 @@ def test_cleanup_verification(verbose):
if verbose:
print(" Cleanup successful - no memories remain")
# ================== SECURITY TEST FUNCTIONS ==================
def test_auth_required_endpoints(verbose):
"""Test that protected endpoints require authentication"""
endpoints_requiring_auth = [
("GET", f"{BASE_URL}/memories/{TEST_USER}"),
("POST", f"{BASE_URL}/memories/search"),
("GET", f"{BASE_URL}/stats"),
("GET", f"{BASE_URL}/models"),
("GET", f"{BASE_URL}/users"),
]
for method, url in endpoints_requiring_auth:
if method == "GET":
response = requests.get(url, timeout=5)
else:
response = requests.post(
url, json={"query": "test", "user_id": TEST_USER}, timeout=5
)
assert response.status_code in [401, 403], (
f"{method} {url} should require auth, got {response.status_code}"
)
if verbose:
print(f" {method} {url}: {response.status_code} (auth required)")
def test_ownership_verification(verbose):
"""Test that users can only access their own data"""
other_user = "other_user_not_me"
response = requests.get(
f"{BASE_URL}/memories/{other_user}", headers=AUTH_HEADERS, timeout=5
)
assert response.status_code in [403, 404], (
f"Accessing other user's memories should be denied, got {response.status_code}"
)
if verbose:
print(f" Ownership check passed: {response.status_code}")
def test_request_size_limit(verbose):
"""Test request size limit enforcement (10MB max)"""
large_payload = {
"messages": [{"role": "user", "content": "x" * (11 * 1024 * 1024)}],
"user_id": TEST_USER,
}
try:
response = requests.post(
f"{BASE_URL}/memories",
json=large_payload,
headers={**AUTH_HEADERS, "Content-Length": str(11 * 1024 * 1024)},
timeout=5,
)
assert response.status_code == 413, (
f"Large request should return 413, got {response.status_code}"
)
if verbose:
print(f" Request size limit enforced: {response.status_code}")
except requests.exceptions.RequestException as e:
if verbose:
print(
f" Request size limit test: connection issue (expected for large payload)"
)
if __name__ == "__main__":
main()