Compare commits

...

2 commits

Author SHA1 Message Date
a228780146 production improvements: configurable embeddings, v1.1, O(1) ownership, retries
- Make Ollama URL configurable via OLLAMA_BASE_URL env var
- Add version: v1.1 to Mem0 config (required for latest features)
- Make embedding model and dimensions configurable
- Fix ownership check: O(1) lookup instead of fetching 10k records
- Add tenacity retry logic for database operations
2026-01-15 23:01:18 +05:30
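The embedder settings called out above are read from the environment by the Settings class in config.py (diff below). A minimal sketch of how the new variables would be picked up, assuming pydantic-settings resolves the aliases as shown in the diff (the required OpenAI/Cohere keys must also be present in the environment or .env for Settings to load):

import os

# Illustrative values only - OLLAMA_BASE_URL, EMBEDDING_MODEL and EMBEDDING_DIMS
# are the aliases introduced in this commit.
os.environ["OLLAMA_BASE_URL"] = "http://localhost:11434"
os.environ["EMBEDDING_MODEL"] = "qwen3-embedding:4b-q8_0"
os.environ["EMBEDDING_DIMS"] = "2560"

from config import Settings

settings = Settings()
print(settings.ollama_base_url)  # http://localhost:11434
print(settings.embedding_model)  # qwen3-embedding:4b-q8_0
print(settings.embedding_dims)   # 2560 (parsed to int)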
50edce2d3c security hardening: add auth, rate limiting, fix info disclosure
- Add auth to /models and /users endpoints
- Add rate limiting to all endpoints (10-120/min based on operation type)
- Fix 11 info disclosure issues (detail=str(e) -> generic message)
- Fix 2 silent except blocks with proper logging
- Fix 7 raise e -> raise for proper exception chaining
- Fix health check to not expose exception details
- Update tests with X-API-Key headers and security tests
2026-01-15 22:41:24 +05:30
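As a rough client-side illustration of the auth and rate-limit changes above (a sketch, not part of the diff: it assumes a locally running instance, the requests library, and an API key configured in API_KEYS; slowapi answers over-limit calls with HTTP 429):

import requests

BASE_URL = "http://localhost:8000"
HEADERS = {"X-API-Key": "api_key_123"}  # hypothetical key mapped to a user in API_KEYS

# /models is now protected: without a key the auth dependency rejects the call.
print(requests.get(f"{BASE_URL}/models").status_code)  # expected 401/403

# Authenticated calls succeed until the per-key limit (120/min for /models) is hit.
for i in range(125):
    r = requests.get(f"{BASE_URL}/models", headers=HEADERS)
    if r.status_code == 429:
        print(f"rate limited after {i + 1} requests")
        break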
6 changed files with 1099 additions and 483 deletions

View file: config.py

@@ -11,39 +11,87 @@ class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    model_config = SettingsConfigDict(
        env_file=".env", case_sensitive=False, extra="ignore"
    )

    # API Configuration
    # Accept both OPENAI_API_KEY (from docker-compose) and OPENAI_COMPAT_API_KEY (from direct .env)
    openai_api_key: str = Field(
        validation_alias=AliasChoices(
            "OPENAI_API_KEY", "OPENAI_COMPAT_API_KEY", "openai_api_key"
        )
    )
    openai_base_url: str = Field(
        validation_alias=AliasChoices(
            "OPENAI_BASE_URL", "OPENAI_COMPAT_BASE_URL", "openai_base_url"
        )
    )
    cohere_api_key: str = Field(
        validation_alias=AliasChoices("COHERE_API_KEY", "cohere_api_key")
    )

    # Database Configuration
    qdrant_host: str = Field(
        default="localhost", validation_alias=AliasChoices("QDRANT_HOST", "qdrant_host")
    )
    qdrant_port: int = Field(
        default=6333, validation_alias=AliasChoices("QDRANT_PORT", "qdrant_port")
    )
    qdrant_collection_name: str = Field(
        default="mem0",
        validation_alias=AliasChoices(
            "QDRANT_COLLECTION_NAME", "qdrant_collection_name"
        ),
    )

    # Neo4j Configuration
    neo4j_uri: str = Field(
        default="bolt://localhost:7687",
        validation_alias=AliasChoices("NEO4J_URI", "neo4j_uri"),
    )
    neo4j_username: str = Field(
        default="neo4j",
        validation_alias=AliasChoices("NEO4J_USERNAME", "neo4j_username"),
    )
    neo4j_password: str = Field(
        default="mem0_neo4j_password",
        validation_alias=AliasChoices("NEO4J_PASSWORD", "neo4j_password"),
    )

    # Application Configuration
    log_level: str = Field(
        default="INFO", validation_alias=AliasChoices("LOG_LEVEL", "log_level")
    )
    cors_origins: str = Field(
        default="http://localhost:3000",
        validation_alias=AliasChoices("CORS_ORIGINS", "cors_origins"),
    )

    # Model Configuration - Ultra-minimal (single model)
    default_model: str = Field(
        default="claude-sonnet-4",
        validation_alias=AliasChoices("DEFAULT_MODEL", "default_model"),
    )

    # Embedder Configuration
    ollama_base_url: str = Field(
        default="http://host.docker.internal:11434",
        validation_alias=AliasChoices("OLLAMA_BASE_URL", "ollama_base_url"),
    )
    embedding_model: str = Field(
        default="qwen3-embedding:4b-q8_0",
        validation_alias=AliasChoices("EMBEDDING_MODEL", "embedding_model"),
    )
    embedding_dims: int = Field(
        default=2560, validation_alias=AliasChoices("EMBEDDING_DIMS", "embedding_dims")
    )

    # Authentication Configuration
    # Format: JSON string mapping API keys to user IDs
    # Example: {"api_key_123": "alice", "api_key_456": "bob"}
    api_keys: str = Field(
        default="{}", validation_alias=AliasChoices("API_KEYS", "api_keys")
    )

    @property
    def cors_origins_list(self) -> List[str]:
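The api_keys field above holds a JSON string; a minimal sketch of turning it into a key-to-user mapping (the real parsing lives in the auth module, which is not part of this diff):

import json

from config import settings

# With API_KEYS='{"api_key_123": "alice", "api_key_456": "bob"}' in the environment:
key_to_user = json.loads(settings.api_keys)
print(key_to_user.get("api_key_123"))  # "alice"
print(key_to_user.get("unknown_key"))  # None -> request should be rejected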

View file: main.py

@@ -1,22 +1,46 @@
"""Main FastAPI application for Mem0 Interface POC."""

import json
import logging
import time
from datetime import datetime
from typing import List, Dict, Any, Optional
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import structlog
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

from config import settings


# Rate limiter - keys on the API key when present, otherwise falls back to the client IP
def get_rate_limit_key(request: Request) -> str:
    """Get rate limit key - prefer API key if available, otherwise IP."""
    api_key = request.headers.get("x-api-key", "")
    if api_key:
        return f"apikey:{api_key[:16]}"  # Use first 16 chars of API key
    return get_remote_address(request)


limiter = Limiter(key_func=get_rate_limit_key)

from models import (
    ChatRequest,
    MemoryAddRequest,
    MemoryAddResponse,
    MemorySearchRequest,
    MemorySearchResponse,
    MemoryUpdateRequest,
    MemoryItem,
    GraphResponse,
    HealthResponse,
    ErrorResponse,
    GlobalStatsResponse,
    UserStatsResponse,
)
from mem0_manager import mem0_manager
from auth import get_current_user, auth_service
@@ -32,7 +56,7 @@ structlog.configure(
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        structlog.processors.JSONRenderer(),
    ],
    context_class=dict,
    logger_factory=structlog.stdlib.LoggerFactory(),
@@ -62,6 +86,7 @@ async def lifespan(app: FastAPI):
    mcp_context = None
    try:
        from mcp_server import mcp_lifespan

        mcp_context = mcp_lifespan()
        await mcp_context.__aenter__()
    except ImportError:
@@ -86,19 +111,47 @@ app = FastAPI(
    title="Mem0 Interface POC",
    description="Minimal but fully functional Mem0 interface with PostgreSQL and Neo4j integration",
    version="1.0.0",
    lifespan=lifespan,
)

# Add rate limiter to app state and exception handler
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Add CORS middleware - Allow all origins (secured via API key auth)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins - secured via API key authentication
    allow_credentials=False,  # Must be False when allow_origins=["*"]
    allow_methods=["*"],
    allow_headers=["*"],
)

# Request size limit middleware - prevent DoS via large payloads
MAX_REQUEST_SIZE = 10 * 1024 * 1024  # 10MB limit


@app.middleware("http")
async def limit_request_size(request, call_next):
    """Reject requests that exceed the maximum allowed size."""
    content_length = request.headers.get("content-length")
    if content_length:
        try:
            if int(content_length) > MAX_REQUEST_SIZE:
                return JSONResponse(
                    status_code=413,
                    content={
                        "error": "Request payload too large",
                        "max_size_bytes": MAX_REQUEST_SIZE,
                        "max_size_mb": MAX_REQUEST_SIZE / (1024 * 1024),
                    },
                )
        except ValueError:
            pass  # Invalid content-length header, let it through for other validation
    return await call_next(request)


# Request logging middleware with monitoring
@app.middleware("http")
async def log_requests(request, call_next):
@@ -111,19 +164,19 @@
    # Extract user_id from request if available
    user_id = None
    if request.method == "POST":
        # Try to extract user_id from request body for POST requests
        try:
            body = await request.body()
            if body:
                data = json.loads(body)
                user_id = data.get("user_id")
        except json.JSONDecodeError:
            pass  # Non-JSON body, user_id extraction not possible
        except Exception as e:
            logger.debug("Could not extract user_id from request body", error=str(e))
    elif "user_id" in str(request.url.path):
        # Extract user_id from path for GET requests
        path_parts = request.url.path.split("/")
        if len(path_parts) > 2 and path_parts[-2] in ["memories", "stats"]:
            user_id = path_parts[-1]

    # Log start of request
@@ -132,7 +185,7 @@
        correlation_id=correlation_id,
        method=request.method,
        path=request.url.path,
        user_id=user_id,
    )

    response = await call_next(request)
@@ -153,7 +206,7 @@
            status_code=response.status_code,
            process_time_ms=round(process_time_ms, 2),
            user_id=user_id,
            slow_request=True,
        )
    elif response.status_code >= 400:
        logger.error(
@@ -164,7 +217,7 @@
            status_code=response.status_code,
            process_time_ms=round(process_time_ms, 2),
            user_id=user_id,
            slow_request=False,
        )
    else:
        logger.info(
@@ -175,7 +228,7 @@
            status_code=response.status_code,
            process_time_ms=round(process_time_ms, 2),
            user_id=user_id,
            slow_request=False,
        )

    return response
@@ -184,11 +237,23 @@
# Exception handlers
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Global exception handler - logs details but returns generic message."""
    # Log full exception details for debugging (internal only)
    logger.error(
        "Unhandled exception",
        exc_info=True,
        path=request.url.path,
        method=request.method,
        error_type=type(exc).__name__,
        error_message=str(exc),
    )
    # Return generic error to client - don't expose internal details
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "message": "An unexpected error occurred",
        },
    )
@@ -198,50 +263,59 @@ async def health_check():
    """Check the health of all services."""
    try:
        services = await mem0_manager.health_check()
        overall_status = (
            "healthy"
            if all("healthy" in status for status in services.values())
            else "degraded"
        )

        return HealthResponse(
            status=overall_status,
            services=services,
            timestamp=datetime.utcnow().isoformat(),
        )
    except Exception as e:
        logger.error(f"Health check failed: {e}", exc_info=True)
        return HealthResponse(
            status="unhealthy",
            services={"error": "Health check failed - see logs for details"},
            timestamp=datetime.utcnow().isoformat(),
        )


# Core chat endpoint with memory enhancement
@app.post("/chat")
@limiter.limit("30/minute")  # Chat is expensive - limit to 30/min
async def chat_with_memory(
    request: Request,
    chat_request: ChatRequest,
    authenticated_user: str = Depends(get_current_user),
):
    """Ultra-minimal chat endpoint - pure Mem0 + custom endpoint."""
    try:
        # Verify user can only access their own data
        if authenticated_user != chat_request.user_id:
            raise HTTPException(
                status_code=403,
                detail=f"Access denied: You can only chat as yourself (authenticated as '{authenticated_user}')",
            )

        logger.info(f"Processing chat request for user: {chat_request.user_id}")

        # Convert ChatMessage objects to dict format if context provided
        context_dict = None
        if chat_request.context:
            context_dict = [
                {"role": msg.role, "content": msg.content}
                for msg in chat_request.context
            ]

        result = await mem0_manager.chat_with_memory(
            message=chat_request.message,
            user_id=chat_request.user_id,
            agent_id=chat_request.agent_id,
            run_id=chat_request.run_id,
            context=context_dict,
        )

        return result
@@ -250,32 +324,37 @@
        raise
    except Exception as e:
        logger.error(f"Error in chat endpoint: {e}")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )


# Memory management endpoints - pure Mem0 passthroughs
@app.post("/memories")
@limiter.limit("60/minute")  # Memory operations - 60/min
async def add_memories(
    request: Request,
    memory_request: MemoryAddRequest,
    authenticated_user: str = Depends(get_current_user),
):
    """Add memories - pure Mem0 passthrough."""
    try:
        # Verify user can only add to their own memories
        if authenticated_user != memory_request.user_id:
            raise HTTPException(
                status_code=403,
                detail=f"Access denied: You can only add memories for yourself (authenticated as '{authenticated_user}')",
            )

        logger.info(f"Adding memories for user: {memory_request.user_id}")

        result = await mem0_manager.add_memories(
            messages=memory_request.messages,
            user_id=memory_request.user_id,
            agent_id=memory_request.agent_id,
            run_id=memory_request.run_id,
            metadata=memory_request.metadata,
        )

        return result
@@ -284,33 +363,40 @@
        raise
    except Exception as e:
        logger.error(f"Error adding memories: {e}")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )


@app.post("/memories/search")
@limiter.limit("120/minute")  # Search is lighter - 120/min
async def search_memories(
    request: Request,
    search_request: MemorySearchRequest,
    authenticated_user: str = Depends(get_current_user),
):
    """Search memories - pure Mem0 passthrough."""
    try:
        # Verify user can only search their own memories
        if authenticated_user != search_request.user_id:
            raise HTTPException(
                status_code=403,
                detail=f"Access denied: You can only search your own memories (authenticated as '{authenticated_user}')",
            )

        logger.info(
            f"Searching memories for user: {search_request.user_id}, query: {search_request.query}"
        )

        result = await mem0_manager.search_memories(
            query=search_request.query,
            user_id=search_request.user_id,
            limit=search_request.limit,
            threshold=search_request.threshold or 0.2,
            filters=search_request.filters,
            agent_id=search_request.agent_id,
            run_id=search_request.run_id,
        )

        return result
@@ -319,16 +405,21 @@
        raise
    except Exception as e:
        logger.error(f"Error searching memories: {e}")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )


@app.get("/memories/{user_id}")
@limiter.limit("120/minute")
async def get_user_memories(
    request: Request,
    user_id: str,
    authenticated_user: str = Depends(get_current_user),
    limit: int = 10,
    agent_id: Optional[str] = None,
    run_id: Optional[str] = None,
):
    """Get all memories for a user with hierarchy filtering - pure Mem0 passthrough."""
    try:
@@ -336,16 +427,13 @@
        if authenticated_user != user_id:
            raise HTTPException(
                status_code=403,
                detail=f"Access denied: You can only retrieve your own memories (authenticated as '{authenticated_user}')",
            )

        logger.info(f"Retrieving memories for user: {user_id}")

        memories = await mem0_manager.get_user_memories(
            user_id=user_id, limit=limit, agent_id=agent_id, run_id=run_id
        )

        return memories
@@ -354,28 +442,44 @@
        raise
    except Exception as e:
        logger.error(f"Error retrieving user memories: {e}")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )


@app.put("/memories")
@limiter.limit("60/minute")
async def update_memory(
    request: Request,
    update_request: MemoryUpdateRequest,
    authenticated_user: str = Depends(get_current_user),
):
    """Update memory - verifies ownership before update."""
    try:
        # Verify user owns the memory being updated
        if authenticated_user != update_request.user_id:
            raise HTTPException(
                status_code=403,
                detail=f"Access denied: You can only update your own memories (authenticated as '{authenticated_user}')",
            )

        # Verify memory ownership with O(1) lookup instead of fetching all memories
        if not await mem0_manager.verify_memory_ownership(
            update_request.memory_id, authenticated_user
        ):
            raise HTTPException(
                status_code=404,
                detail=f"Memory '{update_request.memory_id}' not found or access denied",
            )

        logger.info(
            f"Updating memory: {update_request.memory_id}", user_id=authenticated_user
        )

        result = await mem0_manager.update_memory(
            memory_id=update_request.memory_id,
            content=update_request.content,
        )

        return result
@@ -384,25 +488,31 @@
        raise
    except Exception as e:
        logger.error(f"Error updating memory: {e}")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )


@app.delete("/memories/{memory_id}")
@limiter.limit("60/minute")
async def delete_memory(
    request: Request,
    memory_id: str,
    authenticated_user: str = Depends(get_current_user),
):
    """Delete a specific memory - verifies ownership before deletion."""
    try:
        # Verify memory ownership with O(1) lookup instead of fetching all memories
        if not await mem0_manager.verify_memory_ownership(
            memory_id, authenticated_user
        ):
            raise HTTPException(
                status_code=404,
                detail=f"Memory '{memory_id}' not found or access denied",
            )

        logger.info(f"Deleting memory: {memory_id}", user_id=authenticated_user)

        result = await mem0_manager.delete_memory(memory_id=memory_id)
@@ -412,13 +522,16 @@
        raise
    except Exception as e:
        logger.error(f"Error deleting memory: {e}")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )


@app.delete("/memories/user/{user_id}")
@limiter.limit("10/minute")  # Dangerous bulk delete - heavily rate limited
async def delete_user_memories(
    request: Request, user_id: str, authenticated_user: str = Depends(get_current_user)
):
    """Delete all memories for a specific user."""
    try:
@@ -426,7 +539,7 @@
        if authenticated_user != user_id:
            raise HTTPException(
                status_code=403,
                detail=f"Access denied: You can only delete your own memories (authenticated as '{authenticated_user}')",
            )

        logger.info(f"Deleting all memories for user: {user_id}")
@@ -439,14 +552,17 @@
        raise
    except Exception as e:
        logger.error(f"Error deleting user memories: {e}")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )


# Graph relationships endpoint - pure Mem0 passthrough
@app.get("/graph/relationships/{user_id}")
@limiter.limit("60/minute")
async def get_graph_relationships(
    request: Request, user_id: str, authenticated_user: str = Depends(get_current_user)
):
    """Get graph relationships - pure Mem0 passthrough."""
    try:
@@ -454,11 +570,13 @@
        if authenticated_user != user_id:
            raise HTTPException(
                status_code=403,
                detail=f"Access denied: You can only view your own relationships (authenticated as '{authenticated_user}')",
            )

        logger.info(f"Retrieving graph relationships for user: {user_id}")

        result = await mem0_manager.get_graph_relationships(
            user_id=user_id, agent_id=None, run_id=None, limit=10000
        )

        return result
@@ -466,68 +584,101 @@
        raise
    except Exception as e:
        logger.error(f"Error retrieving graph relationships: {e}")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )


# Memory history endpoint - new feature
@app.get("/memories/{memory_id}/history")
@limiter.limit("120/minute")
async def get_memory_history(
    request: Request,
    memory_id: str,
    user_id: str,  # Required query param to verify ownership
    authenticated_user: str = Depends(get_current_user),
):
    """Get memory change history - pure Mem0 passthrough."""
    try:
        # Verify user can only access their own memory history
        if authenticated_user != user_id:
            raise HTTPException(
                status_code=403,
                detail=f"Access denied: You can only view your own memory history (authenticated as '{authenticated_user}')",
            )

        # Verify memory ownership with O(1) lookup instead of fetching all memories
        if not await mem0_manager.verify_memory_ownership(memory_id, user_id):
            raise HTTPException(
                status_code=404,
                detail=f"Memory '{memory_id}' not found or access denied",
            )

        logger.info(f"Retrieving history for memory: {memory_id}", user_id=user_id)

        result = await mem0_manager.get_memory_history(memory_id=memory_id)
        return result
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error retrieving memory history: {e}")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )


# Statistics and monitoring endpoints
@app.get("/stats", response_model=GlobalStatsResponse)
@limiter.limit("60/minute")
async def get_global_stats(
    request: Request, authenticated_user: str = Depends(get_current_user)
):
    """Get global application statistics - requires authentication."""
    try:
        from monitoring import stats

        basic_stats = stats.get_global_stats()

        try:
            sample_result = await mem0_manager.search_memories(
                query="*", user_id="__stats_check__", limit=1
            )
            total_memories = basic_stats["total_memories"]
        except Exception:
            total_memories = 0

        return GlobalStatsResponse(
            total_memories=total_memories,
            total_users=basic_stats["total_users"],
            api_calls_today=basic_stats["api_calls_today"],
            avg_response_time_ms=basic_stats["avg_response_time_ms"],
            memory_operations={
                "add": basic_stats["memory_operations"]["add"],
                "search": basic_stats["memory_operations"]["search"],
                "update": basic_stats["memory_operations"]["update"],
                "delete": basic_stats["memory_operations"]["delete"],
            },
            uptime_seconds=basic_stats["uptime_seconds"],
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting global stats: {e}")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )


@app.get("/stats/{user_id}", response_model=UserStatsResponse)
@limiter.limit("120/minute")
async def get_user_stats(
    request: Request, user_id: str, authenticated_user: str = Depends(get_current_user)
):
    """Get user-specific statistics."""
    try:
@@ -535,7 +686,7 @@
        if authenticated_user != user_id:
            raise HTTPException(
                status_code=403,
                detail=f"Access denied: You can only view your own statistics (authenticated as '{authenticated_user}')",
            )

        from monitoring import stats
@@ -545,59 +696,75 @@
        # Get actual memory count for this user
        try:
            user_memories = await mem0_manager.get_user_memories(
                user_id=user_id, limit=10000
            )
            memory_count = len(user_memories)
        except Exception as e:
            logger.warning(f"Failed to get memory count for user {user_id}: {e}")
            memory_count = 0

        # Get relationship count for this user
        try:
            graph_data = await mem0_manager.get_graph_relationships(
                user_id=user_id, agent_id=None, run_id=None
            )
            relationship_count = len(graph_data.get("relationships", []))
        except Exception as e:
            logger.warning(f"Failed to get relationship count for user {user_id}: {e}")
            relationship_count = 0

        return UserStatsResponse(
            user_id=user_id,
            memory_count=memory_count,
            relationship_count=relationship_count,
            last_activity=basic_stats["last_activity"],
            api_calls_today=basic_stats["api_calls_today"],
            avg_response_time_ms=basic_stats["avg_response_time_ms"],
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting user stats for {user_id}: {e}")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )


# Utility endpoints
@app.get("/models")
@limiter.limit("120/minute")
async def get_available_models(
    request: Request, authenticated_user: str = Depends(get_current_user)
):
    """Get current model configuration - requires authentication."""
    return {
        "current_model": settings.default_model,
        "endpoint": settings.openai_base_url,
        "note": "Using single model with pure Mem0 intelligence",
    }


@app.get("/users")
@limiter.limit("60/minute")
async def get_active_users(
    request: Request, authenticated_user: str = Depends(get_current_user)
):
    """Get list of users with memories (simplified implementation) - requires authentication."""
    # This would typically query the database for users with memories
    # For now, return a placeholder
    return {
        "message": "This endpoint would return users with stored memories",
        "note": "Implementation depends on direct database access or Mem0 user enumeration capabilities",
    }


# Mount MCP server at /mcp endpoint
try:
    from mcp_server import create_mcp_app

    mcp_app = create_mcp_app()
    app.mount("/mcp", mcp_app)
    logger.info("MCP server mounted at /mcp")
@@ -609,11 +776,12 @@ except Exception as e:
if __name__ == "__main__":
    import uvicorn

    print("Starting UVicorn server...")
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        log_level=settings.log_level.lower(),
        reload=True,
    )
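From a client's perspective, the reworked delete endpoint above no longer takes a user_id query parameter; ownership is checked server-side via verify_memory_ownership(). A sketch with a hypothetical memory ID and API key (assumes a running instance and the requests library):

import requests

BASE_URL = "http://localhost:8000"
HEADERS = {"X-API-Key": "api_key_123"}  # hypothetical key mapped to user "alice"

memory_id = "11111111-2222-3333-4444-555555555555"  # hypothetical memory ID

# Ownership is resolved from the API key; missing or foreign memories return 404
# instead of revealing whether the ID exists.
r = requests.delete(f"{BASE_URL}/memories/{memory_id}", headers=HEADERS)
print(r.status_code)  # 200 for a memory owned by "alice", 404 otherwise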

View file: mem0_manager.py

@@ -5,24 +5,47 @@ from typing import Dict, List, Optional, Any
from datetime import datetime

from mem0 import Memory
from openai import OpenAI
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
    before_sleep_log,
)

from config import settings
from monitoring import timed

logger = logging.getLogger(__name__)

# Retry decorator for database operations (Qdrant, Neo4j)
db_retry = retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    retry=retry_if_exception_type((ConnectionError, TimeoutError, OSError)),
    before_sleep=before_sleep_log(logger, logging.WARNING),
    reraise=True,
)

# Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility
from mem0.llms.openai import OpenAILLM

_original_generate_response = OpenAILLM.generate_response


def patched_generate_response(
    self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs
):
    # Remove 'store' parameter as LiteLLM doesn't support it
    if hasattr(self.config, "store"):
        self.config.store = None
    # Remove 'top_p' to avoid conflict with temperature for Claude models
    if hasattr(self.config, "top_p"):
        self.config.top_p = None
    return _original_generate_response(
        self, messages, response_format, tools, tool_choice, **kwargs
    )


OpenAILLM.generate_response = patched_generate_response
logger.info("Applied LiteLLM compatibility patch: disabled 'store' parameter")
@@ -36,8 +59,12 @@ class Mem0Manager:
    def __init__(self):
        # Custom endpoint configuration with graph memory enabled
        logger.info(
            "Initializing ultra-minimal Mem0Manager with custom endpoint with settings:",
            settings,
        )
        config = {
            "version": "v1.1",
            "enable_graph": True,
            "llm": {
                "provider": "openai",
@@ -46,17 +73,16 @@
                    "api_key": settings.openai_api_key,
                    "openai_base_url": settings.openai_base_url,
                    "temperature": 0.1,
                    "top_p": None,
                },
            },
            "embedder": {
                "provider": "ollama",
                "config": {
                    "model": settings.embedding_model,
                    "ollama_base_url": settings.ollama_base_url,
                    "embedding_dims": settings.embedding_dims,
                },
            },
            "vector_store": {
                "provider": "qdrant",
@@ -64,38 +90,39 @@
                    "collection_name": settings.qdrant_collection_name,
                    "host": settings.qdrant_host,
                    "port": settings.qdrant_port,
                    "embedding_model_dims": settings.embedding_dims,
                    "on_disk": True,
                },
            },
            "graph_store": {
                "provider": "neo4j",
                "config": {
                    "url": settings.neo4j_uri,
                    "username": settings.neo4j_username,
                    "password": settings.neo4j_password,
                },
            },
            "reranker": {
                "provider": "cohere",
                "config": {
                    "api_key": settings.cohere_api_key,
                    "model": "rerank-english-v3.0",
                    "top_n": 10,
                },
            },
        }

        self.memory = Memory.from_config(config)
        self.openai_client = OpenAI(
            api_key=settings.openai_api_key,
            base_url=settings.openai_base_url,
            timeout=60.0,  # 60 second timeout for LLM calls
            max_retries=2,  # Retry failed requests up to 2 times
        )
        logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")

    # Pure passthrough methods - no custom logic
    @db_retry
    @timed("add_memories")
    async def add_memories(
        self,
@@ -103,14 +130,14 @@
        user_id: Optional[str] = "default",
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Add memories - simplified native Mem0 pattern (10 lines vs 45)."""
        try:
            # Convert ChatMessage objects to dict if needed
            formatted_messages = []
            for msg in messages:
                if hasattr(msg, "dict"):
                    formatted_messages.append(msg.dict())
                else:
                    formatted_messages.append(msg)
@@ -123,26 +150,35 @@
                "timestamp": datetime.now().isoformat(),
                "source": "chat_conversation",
                "message_count": len(formatted_messages),
                "auto_generated": True,
            }

            # Merge user metadata with auto metadata (user metadata takes precedence)
            enhanced_metadata = {**auto_metadata, **combined_metadata}

            # Direct Mem0 add with enhanced metadata
            result = self.memory.add(
                formatted_messages,
                user_id=user_id,
                agent_id=agent_id,
                run_id=run_id,
                metadata=enhanced_metadata,
            )

            return {
                "added_memories": result if isinstance(result, list) else [result],
                "message": "Memories added successfully",
                "hierarchy": {
                    "user_id": user_id,
                    "agent_id": agent_id,
                    "run_id": run_id,
                },
            }
        except Exception as e:
            logger.error(f"Error adding memories: {e}")
            raise

    @db_retry
    @timed("search_memories")
    async def search_memories(
        self,
@@ -155,37 +191,79 @@
        # rerank: bool = False,
        # filter_memories: bool = False,
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Search memories - native Mem0 pattern"""
        try:
            # Minimal empty query protection for API compatibility
            if not query or query.strip() == "":
                return {
                    "memories": [],
                    "total_count": 0,
                    "query": query,
                    "note": "Empty query provided, no results returned. Use a specific query to search memories.",
                }

            # Direct Mem0 search - trust native handling
            result = self.memory.search(
                query=query,
                user_id=user_id,
                agent_id=agent_id,
                run_id=run_id,
                limit=limit,
                threshold=threshold,
                filters=filters,
            )
            return {
                "memories": result.get("results", []),
                "total_count": len(result.get("results", [])),
                "query": query,
            }
        except Exception as e:
            logger.error(f"Error searching memories: {e}")
            raise

    @db_retry
    async def get_user_memories(
        self,
        user_id: str,
        limit: int = 10,
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Get all memories for a user - native Mem0 pattern."""
        try:
            # Direct Mem0 get_all call - trust native parameter handling
            result = self.memory.get_all(
                user_id=user_id,
                limit=limit,
                agent_id=agent_id,
                run_id=run_id,
                filters=filters,
            )
            return result.get("results", [])
        except Exception as e:
            logger.error(f"Error getting user memories: {e}")
            raise

    @db_retry
    async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
        """Get a single memory by ID. Returns None if not found."""
        try:
            result = self.memory.get(memory_id=memory_id)
            return result
        except Exception as e:
            logger.debug(f"Memory {memory_id} not found or error: {e}")
            return None

    async def verify_memory_ownership(self, memory_id: str, user_id: str) -> bool:
        """Check if a memory belongs to a user. O(1) instead of O(n)."""
        memory = await self.get_memory(memory_id)
        if memory is None:
            return False
        return memory.get("user_id") == user_id

    @db_retry
    @timed("update_memory")
    async def update_memory(
        self,
@@ -194,15 +272,13 @@
    ) -> Dict[str, Any]:
        """Update memory - pure Mem0 passthrough."""
        try:
            result = self.memory.update(memory_id=memory_id, data=content)
            return {"message": "Memory updated successfully", "result": result}
        except Exception as e:
            logger.error(f"Error updating memory: {e}")
            raise

    @db_retry
    @timed("delete_memory")
    async def delete_memory(self, memory_id: str) -> Dict[str, Any]:
        """Delete memory - pure Mem0 passthrough."""
@@ -211,7 +287,7 @@
            return {"message": "Memory deleted successfully"}
        except Exception as e:
            logger.error(f"Error deleting memory: {e}")
            raise

    async def delete_user_memories(self, user_id: Optional[str]) -> Dict[str, Any]:
        """Delete all user memories - pure Mem0 passthrough."""
@@ -220,7 +296,7 @@
            return {"message": "All user memories deleted successfully"}
        except Exception as e:
            logger.error(f"Error deleting user memories: {e}")
            raise

    async def get_memory_history(self, memory_id: str) -> Dict[str, Any]:
        """Get memory change history - pure Mem0 passthrough."""
@@ -229,22 +305,24 @@
            return {
                "memory_id": memory_id,
                "history": history,
                "message": "Memory history retrieved successfully",
            }
        except Exception as e:
            logger.error(f"Error getting memory history: {e}")
            raise

    async def get_graph_relationships(
        self,
        user_id: Optional[str],
        agent_id: Optional[str],
        run_id: Optional[str],
        limit: int = 50,
    ) -> Dict[str, Any]:
        """Get graph relationships - using correct Mem0 get_all() method."""
        try:
            # Use get_all() to retrieve memories with graph relationships
            result = self.memory.get_all(
                user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit
            )

            # Extract relationships from Mem0's response structure
@@ -272,7 +350,7 @@
                "agent_id": agent_id,
                "run_id": run_id,
                "total_memories": len(result.get("results", [])),
                "total_relationships": len(relationships),
            }

        except Exception as e:
@@ -286,7 +364,7 @@
                "run_id": run_id,
                "total_memories": 0,
                "total_relationships": 0,
                "error": str(e),
            }

    @timed("chat_with_memory")
@ -304,49 +382,70 @@ class Mem0Manager:
try: try:
total_start_time = time.time() total_start_time = time.time()
print(f"\n🚀 Starting chat request for user: {user_id}") logger.info("Starting chat request", user_id=user_id)
# Stage 1: Memory Search
search_start_time = time.time() search_start_time = time.time()
search_result = self.memory.search(query=message, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=10, threshold=0.3) search_result = self.memory.search(
query=message,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=10,
threshold=0.3,
)
relevant_memories = search_result.get("results", []) relevant_memories = search_result.get("results", [])
memories_str = "\n".join(f"- {entry['memory']}" for entry in relevant_memories) memories_str = "\n".join(
f"- {entry['memory']}" for entry in relevant_memories
)
search_time = time.time() - search_start_time search_time = time.time() - search_start_time
print(f"🔍 Memory search took: {search_time:.2f}s (found {len(relevant_memories)} memories)") logger.debug(
"Memory search completed",
search_time_s=round(search_time, 2),
memories_found=len(relevant_memories),
)
# Stage 2: Prepare LLM messages
prep_start_time = time.time() prep_start_time = time.time()
system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}" system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}"
messages = [{"role": "system", "content": system_prompt}] messages = [{"role": "system", "content": system_prompt}]
# Add conversation context if provided (last 50 messages)
if context: if context:
messages.extend(context) messages.extend(context)
print(f"📝 Added {len(context)} context messages") logger.debug("Added context messages", context_count=len(context))
# Add current user message
messages.append({"role": "user", "content": message}) messages.append({"role": "user", "content": message})
prep_time = time.time() - prep_start_time prep_time = time.time() - prep_start_time
print(f"📋 Message preparation took: {prep_time:.3f}s")
# Stage 3: LLM Call
llm_start_time = time.time() llm_start_time = time.time()
response = self.openai_client.chat.completions.create(model=settings.default_model, messages=messages) response = self.openai_client.chat.completions.create(
model=settings.default_model, messages=messages
)
assistant_response = response.choices[0].message.content assistant_response = response.choices[0].message.content
llm_time = time.time() - llm_start_time llm_time = time.time() - llm_start_time
print(f"🤖 LLM call took: {llm_time:.2f}s (model: {settings.default_model})") logger.debug(
"LLM call completed",
llm_time_s=round(llm_time, 2),
model=settings.default_model,
)
# Stage 4: Memory Add
add_start_time = time.time() add_start_time = time.time()
memory_messages = [{"role": "user", "content": message}, {"role": "assistant", "content": assistant_response}] memory_messages = [
{"role": "user", "content": message},
{"role": "assistant", "content": assistant_response},
]
self.memory.add(memory_messages, user_id=user_id) self.memory.add(memory_messages, user_id=user_id)
add_time = time.time() - add_start_time add_time = time.time() - add_start_time
print(f"💾 Memory add took: {add_time:.2f}s")
# Total timing summary
total_time = time.time() - total_start_time total_time = time.time() - total_start_time
print(f"⏱️ TOTAL: {total_time:.2f}s | Search: {search_time:.2f}s | LLM: {llm_time:.2f}s | Add: {add_time:.2f}s | Prep: {prep_time:.3f}s") logger.info(
print(f"📊 Breakdown: Search {(search_time/total_time)*100:.1f}% | LLM {(llm_time/total_time)*100:.1f}% | Add {(add_time/total_time)*100:.1f}%\n") "Chat request completed",
user_id=user_id,
total_time_s=round(total_time, 2),
search_time_s=round(search_time, 2),
llm_time_s=round(llm_time, 2),
add_time_s=round(add_time, 2),
memories_used=len(relevant_memories),
model=settings.default_model,
)
return { return {
"response": assistant_response, "response": assistant_response,
@ -356,17 +455,22 @@ class Mem0Manager:
"total": round(total_time, 2), "total": round(total_time, 2),
"search": round(search_time, 2), "search": round(search_time, 2),
"llm": round(llm_time, 2), "llm": round(llm_time, 2),
"add": round(add_time, 2) "add": round(add_time, 2),
} },
} }
except Exception as e: except Exception as e:
logger.error(f"Error in chat_with_memory: {e}") logger.error(
"Error in chat_with_memory",
error=str(e),
user_id=user_id,
exc_info=True,
)
return { return {
"error": str(e), "error": str(e),
"response": "I apologize, but I encountered an error processing your request.", "response": "I apologize, but I encountered an error processing your request.",
"memories_used": 0, "memories_used": 0,
"model_used": None "model_used": None,
} }
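Note: the keyword-argument logging calls introduced above (e.g. logger.info("Chat request completed", user_id=..., total_time_s=...)) suggest a structured logger such as structlog rather than the plain stdlib logger. A minimal setup sketch, assuming structlog is the backend (the service's real configuration may differ):

import logging
import structlog

logging.basicConfig(level=logging.INFO)
structlog.configure(
    processors=[
        structlog.processors.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.JSONRenderer(),
    ],
    wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
)

logger = structlog.get_logger(__name__)
logger.info("Chat request completed", user_id="demo_user", total_time_s=1.42)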
async def health_check(self) -> Dict[str, str]: async def health_check(self) -> Dict[str, str]:
View file
@ -2,53 +2,115 @@
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import re
# Constants for input validation
MAX_MESSAGE_LENGTH = 50000 # ~12k tokens max per message
MAX_QUERY_LENGTH = 10000 # ~2.5k tokens max per query
MAX_USER_ID_LENGTH = 100 # Reasonable user ID length
MAX_MEMORY_ID_LENGTH = 100 # Memory IDs are typically UUIDs
MAX_CONTEXT_MESSAGES = 100 # Max conversation context messages
USER_ID_PATTERN = r"^[a-zA-Z0-9_\-\.@]+$" # Alphanumeric with common separators
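For reference, USER_ID_PATTERN only admits alphanumerics plus the separators _, -, . and @; a quick illustration of what it accepts and rejects:

import re

USER_ID_PATTERN = r"^[a-zA-Z0-9_\-\.@]+$"

assert re.fullmatch(USER_ID_PATTERN, "alice.smith@example.com")   # accepted
assert re.fullmatch(USER_ID_PATTERN, "user_123-prod")             # accepted
assert re.fullmatch(USER_ID_PATTERN, "alice smith") is None       # whitespace rejected
assert re.fullmatch(USER_ID_PATTERN, "../../etc/passwd") is None  # slashes rejected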
# Request Models # Request Models
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
"""Chat message structure.""" """Chat message structure."""
role: str = Field(..., description="Message role (user, assistant, system)")
content: str = Field(..., description="Message content") role: str = Field(
..., max_length=20, description="Message role (user, assistant, system)"
)
content: str = Field(
..., max_length=MAX_MESSAGE_LENGTH, description="Message content"
)
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
"""Ultra-minimal chat request.""" """Ultra-minimal chat request."""
message: str = Field(..., description="User message")
user_id: Optional[str] = Field("default", description="User identifier") message: str = Field(..., max_length=MAX_MESSAGE_LENGTH, description="User message")
agent_id: Optional[str] = Field(None, description="Agent identifier") user_id: Optional[str] = Field(
run_id: Optional[str] = Field(None, description="Run identifier") "default",
context: Optional[List[ChatMessage]] = Field(None, description="Previous conversation context") max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier (alphanumeric, _, -, ., @)",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
context: Optional[List[ChatMessage]] = Field(
None,
max_length=MAX_CONTEXT_MESSAGES,
description="Previous conversation context (max 100 messages)",
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata") metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
class MemoryAddRequest(BaseModel): class MemoryAddRequest(BaseModel):
"""Request to add memories with hierarchy support - open-source compatible.""" """Request to add memories with hierarchy support - open-source compatible."""
messages: List[ChatMessage] = Field(..., description="Messages to process")
user_id: Optional[str] = Field("default", description="User identifier") messages: List[ChatMessage] = Field(
agent_id: Optional[str] = Field(None, description="Agent identifier") ...,
run_id: Optional[str] = Field(None, description="Run identifier") max_length=MAX_CONTEXT_MESSAGES,
description="Messages to process (max 100 messages)",
)
user_id: Optional[str] = Field(
"default",
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata") metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
class MemorySearchRequest(BaseModel): class MemorySearchRequest(BaseModel):
"""Request to search memories with hierarchy filtering.""" """Request to search memories with hierarchy filtering."""
query: str = Field(..., description="Search query")
user_id: Optional[str] = Field("default", description="User identifier")
agent_id: Optional[str] = Field(None, description="Agent identifier")
run_id: Optional[str] = Field(None, description="Run identifier")
limit: int = Field(5, description="Maximum number of results")
threshold: Optional[float] = Field(None, description="Minimum relevance score")
filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters")
# Hierarchy filters (open-source compatible) query: str = Field(..., max_length=MAX_QUERY_LENGTH, description="Search query")
agent_id: Optional[str] = Field(None, description="Filter by agent identifier") user_id: Optional[str] = Field(
run_id: Optional[str] = Field(None, description="Filter by run identifier") "default",
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier",
)
agent_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
)
run_id: Optional[str] = Field(
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
)
limit: int = Field(5, ge=1, le=100, description="Maximum number of results (1-100)")
threshold: Optional[float] = Field(
None, ge=0.0, le=1.0, description="Minimum relevance score (0-1)"
)
filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters")
class MemoryUpdateRequest(BaseModel): class MemoryUpdateRequest(BaseModel):
"""Request to update a memory.""" """Request to update a memory."""
memory_id: str = Field(..., description="Memory ID to update")
content: str = Field(..., description="New memory content") memory_id: str = Field(
..., max_length=MAX_MEMORY_ID_LENGTH, description="Memory ID to update"
)
user_id: str = Field(
...,
max_length=MAX_USER_ID_LENGTH,
pattern=USER_ID_PATTERN,
description="User identifier for ownership verification",
)
content: str = Field(
..., max_length=MAX_MESSAGE_LENGTH, description="New memory content"
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Updated metadata") metadata: Optional[Dict[str, Any]] = Field(None, description="Updated metadata")
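With these constraints, malformed or out-of-range input is rejected at the model boundary rather than deep in the memory layer. A small sketch of the resulting behaviour, assuming the models above are importable and Pydantic v2 error types:

from pydantic import ValidationError

try:
    ChatRequest(message="hello", user_id="bad user!")  # space and '!' violate USER_ID_PATTERN
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # e.g. "string_pattern_mismatch"

try:
    MemorySearchRequest(query="TechCorp", limit=0)  # below the ge=1 bound
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # e.g. "greater_than_equal"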
@ -57,19 +119,23 @@ class MemoryUpdateRequest(BaseModel):
class MemoryItem(BaseModel): class MemoryItem(BaseModel):
"""Individual memory item.""" """Individual memory item."""
id: str = Field(..., description="Memory unique identifier") id: str = Field(..., description="Memory unique identifier")
memory: str = Field(..., description="Memory content") memory: str = Field(..., description="Memory content")
user_id: Optional[str] = Field(None, description="Associated user ID") user_id: Optional[str] = Field(None, description="Associated user ID")
agent_id: Optional[str] = Field(None, description="Associated agent ID") agent_id: Optional[str] = Field(None, description="Associated agent ID")
run_id: Optional[str] = Field(None, description="Associated run ID") run_id: Optional[str] = Field(None, description="Associated run ID")
metadata: Optional[Dict[str, Any]] = Field(None, description="Memory metadata") metadata: Optional[Dict[str, Any]] = Field(None, description="Memory metadata")
score: Optional[float] = Field(None, description="Relevance score (for search results)") score: Optional[float] = Field(
None, description="Relevance score (for search results)"
)
created_at: Optional[str] = Field(None, description="Creation timestamp") created_at: Optional[str] = Field(None, description="Creation timestamp")
updated_at: Optional[str] = Field(None, description="Last update timestamp") updated_at: Optional[str] = Field(None, description="Last update timestamp")
class MemorySearchResponse(BaseModel): class MemorySearchResponse(BaseModel):
"""Memory search results - pure Mem0 structure.""" """Memory search results - pure Mem0 structure."""
memories: List[MemoryItem] = Field(..., description="Found memories") memories: List[MemoryItem] = Field(..., description="Found memories")
total_count: int = Field(..., description="Total number of memories found") total_count: int = Field(..., description="Total number of memories found")
query: str = Field(..., description="Original search query") query: str = Field(..., description="Original search query")
@ -77,27 +143,37 @@ class MemorySearchResponse(BaseModel):
class MemoryAddResponse(BaseModel): class MemoryAddResponse(BaseModel):
"""Response from adding memories - pure Mem0 structure.""" """Response from adding memories - pure Mem0 structure."""
added_memories: List[Dict[str, Any]] = Field(..., description="Memories that were added")
added_memories: List[Dict[str, Any]] = Field(
..., description="Memories that were added"
)
message: str = Field(..., description="Success message") message: str = Field(..., description="Success message")
class GraphRelationship(BaseModel): class GraphRelationship(BaseModel):
"""Graph relationship structure.""" """Graph relationship structure."""
source: str = Field(..., description="Source entity") source: str = Field(..., description="Source entity")
relationship: str = Field(..., description="Relationship type") relationship: str = Field(..., description="Relationship type")
target: str = Field(..., description="Target entity") target: str = Field(..., description="Target entity")
properties: Optional[Dict[str, Any]] = Field(None, description="Relationship properties") properties: Optional[Dict[str, Any]] = Field(
None, description="Relationship properties"
)
class GraphResponse(BaseModel): class GraphResponse(BaseModel):
"""Graph relationships - pure Mem0 structure.""" """Graph relationships - pure Mem0 structure."""
relationships: List[GraphRelationship] = Field(..., description="Found relationships")
relationships: List[GraphRelationship] = Field(
..., description="Found relationships"
)
entities: List[str] = Field(..., description="Unique entities") entities: List[str] = Field(..., description="Unique entities")
user_id: str = Field(..., description="User identifier") user_id: str = Field(..., description="User identifier")
class HealthResponse(BaseModel): class HealthResponse(BaseModel):
"""Health check response.""" """Health check response."""
status: str = Field(..., description="Service status") status: str = Field(..., description="Service status")
services: Dict[str, str] = Field(..., description="Individual service statuses") services: Dict[str, str] = Field(..., description="Individual service statuses")
timestamp: str = Field(..., description="Health check timestamp") timestamp: str = Field(..., description="Health check timestamp")
@ -105,6 +181,7 @@ class HealthResponse(BaseModel):
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
"""Error response structure.""" """Error response structure."""
error: str = Field(..., description="Error message") error: str = Field(..., description="Error message")
detail: Optional[str] = Field(None, description="Detailed error information") detail: Optional[str] = Field(None, description="Detailed error information")
status_code: int = Field(..., description="HTTP status code") status_code: int = Field(..., description="HTTP status code")
@ -112,8 +189,10 @@ class ErrorResponse(BaseModel):
# Statistics and Monitoring Models # Statistics and Monitoring Models
class MemoryOperationStats(BaseModel): class MemoryOperationStats(BaseModel):
"""Memory operation statistics.""" """Memory operation statistics."""
add: int = Field(..., description="Number of add operations") add: int = Field(..., description="Number of add operations")
search: int = Field(..., description="Number of search operations") search: int = Field(..., description="Number of search operations")
update: int = Field(..., description="Number of update operations") update: int = Field(..., description="Number of update operations")
@ -122,19 +201,29 @@ class MemoryOperationStats(BaseModel):
class GlobalStatsResponse(BaseModel): class GlobalStatsResponse(BaseModel):
"""Global application statistics.""" """Global application statistics."""
total_memories: int = Field(..., description="Total memories across all users") total_memories: int = Field(..., description="Total memories across all users")
total_users: int = Field(..., description="Total number of users") total_users: int = Field(..., description="Total number of users")
api_calls_today: int = Field(..., description="Total API calls today") api_calls_today: int = Field(..., description="Total API calls today")
avg_response_time_ms: float = Field(..., description="Average response time in milliseconds") avg_response_time_ms: float = Field(
memory_operations: MemoryOperationStats = Field(..., description="Memory operation breakdown") ..., description="Average response time in milliseconds"
)
memory_operations: MemoryOperationStats = Field(
..., description="Memory operation breakdown"
)
uptime_seconds: float = Field(..., description="Application uptime in seconds") uptime_seconds: float = Field(..., description="Application uptime in seconds")
class UserStatsResponse(BaseModel): class UserStatsResponse(BaseModel):
"""User-specific statistics.""" """User-specific statistics."""
user_id: str = Field(..., description="User identifier") user_id: str = Field(..., description="User identifier")
memory_count: int = Field(..., description="Number of memories for this user") memory_count: int = Field(..., description="Number of memories for this user")
relationship_count: int = Field(..., description="Number of graph relationships for this user") relationship_count: int = Field(
..., description="Number of graph relationships for this user"
)
last_activity: Optional[str] = Field(None, description="Last activity timestamp") last_activity: Optional[str] = Field(None, description="Last activity timestamp")
api_calls_today: int = Field(..., description="API calls made by this user today") api_calls_today: int = Field(..., description="API calls made by this user today")
avg_response_time_ms: float = Field(..., description="Average response time for this user's requests") avg_response_time_ms: float = Field(
..., description="Average response time for this user's requests"
)
View file
@ -19,6 +19,7 @@ ollama
# Utilities # Utilities
pydantic pydantic
pydantic-settings pydantic-settings
tenacity
python-dotenv python-dotenv
httpx httpx
aiofiles aiofiles
@ -32,5 +33,8 @@ python-json-logger
python-jose[cryptography] python-jose[cryptography]
passlib[bcrypt] passlib[bcrypt]
# Rate Limiting
slowapi
# MCP Server # MCP Server
mcp[server]>=1.0.0 mcp[server]>=1.0.0
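The new tenacity and slowapi dependencies cover transient-failure retries and per-client rate limiting respectively. A minimal sketch of how the two typically wire into a FastAPI app (illustrative only; the route, limit string, and retry policy here are assumptions, not the service's actual values):

from fastapi import FastAPI, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from tenacity import retry, stop_after_attempt, wait_exponential

app = FastAPI()
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=0.5, min=1, max=10))
def fetch_memories_with_retry(user_id: str) -> list:
    # Placeholder for a vector-store / graph-store call that can fail transiently.
    return []

@app.get("/memories/{user_id}")
@limiter.limit("60/minute")  # hypothetical limit; real endpoints use operation-specific rates
async def get_memories(request: Request, user_id: str):
    return fetch_memories_with_retry(user_id)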
View file
@ -19,13 +19,24 @@ import time
BASE_URL = "http://localhost:8000" BASE_URL = "http://localhost:8000"
TEST_USER = f"test_user_{int(datetime.now().timestamp())}" TEST_USER = f"test_user_{int(datetime.now().timestamp())}"
# API Key for authentication - set via environment or use default test key
import os
API_KEY = os.environ.get("MEM0_API_KEY", "test-api-key")
AUTH_HEADERS = {"X-API-Key": API_KEY}
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Mem0 Integration Tests - Real API Testing (Zero Mocking)", description="Mem0 Integration Tests - Real API Testing (Zero Mocking)",
formatter_class=argparse.RawDescriptionHelpFormatter formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="Show detailed output and API responses",
) )
parser.add_argument("--verbose", "-v", action="store_true",
help="Show detailed output and API responses")
args = parser.parse_args() args = parser.parse_args()
verbose = args.verbose verbose = args.verbose
@ -39,6 +50,9 @@ def main():
# Test sequence - order matters for data dependencies # Test sequence - order matters for data dependencies
tests = [ tests = [
test_health_check, test_health_check,
test_auth_required_endpoints,
test_ownership_verification,
test_request_size_limit,
test_empty_search_protection, test_empty_search_protection,
test_add_memories_with_hierarchy, test_add_memories_with_hierarchy,
test_search_memories_basic, test_search_memories_basic,
@ -51,7 +65,7 @@ def main():
test_graph_relationships, test_graph_relationships,
test_delete_specific_memory, test_delete_specific_memory,
test_delete_all_user_memories, test_delete_all_user_memories,
test_cleanup_verification test_cleanup_verification,
] ]
results = [] results = []
@ -82,6 +96,7 @@ def main():
print("❌ Some tests failed! Check the output above.") print("❌ Some tests failed! Check the output above.")
sys.exit(1) sys.exit(1)
def run_test(name, test_func, verbose): def run_test(name, test_func, verbose):
"""Run a single test with error handling""" """Run a single test with error handling"""
try: try:
@ -102,6 +117,7 @@ def run_test(name, test_func, verbose):
print(f"{name}: {e}") print(f"{name}: {e}")
return False return False
def log_response(response, verbose, context=""): def log_response(response, verbose, context=""):
"""Log API response details if verbose""" """Log API response details if verbose"""
if verbose: if verbose:
@ -111,22 +127,30 @@ def log_response(response, verbose, context=""):
if isinstance(data, dict) and len(data) < 5: if isinstance(data, dict) and len(data) < 5:
print(f" {context} Response: {data}") print(f" {context} Response: {data}")
else: else:
print(f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}") print(
f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}"
)
except: except:
print(f" {context} Response: {response.text[:100]}...") print(f" {context} Response: {response.text[:100]}...")
# ================== TEST FUNCTIONS ================== # ================== TEST FUNCTIONS ==================
def test_health_check(verbose): def test_health_check(verbose):
"""Test service health endpoint""" """Test service health endpoint"""
response = requests.get(f"{BASE_URL}/health", timeout=10) response = requests.get(
f"{BASE_URL}/health", timeout=10
) # Health doesn't require auth
log_response(response, verbose, "Health") log_response(response, verbose, "Health")
assert response.status_code == 200, f"Expected 200, got {response.status_code}" assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json() data = response.json()
assert "status" in data, "Health response missing 'status' field" assert "status" in data, "Health response missing 'status' field"
assert data["status"] in ["healthy", "degraded"], f"Invalid status: {data['status']}" assert data["status"] in ["healthy", "degraded"], (
f"Invalid status: {data['status']}"
)
# Check individual services # Check individual services
assert "services" in data, "Health response missing 'services' field" assert "services" in data, "Health response missing 'services' field"
@ -136,18 +160,19 @@ def test_health_check(verbose):
for service, status in data["services"].items(): for service, status in data["services"].items():
print(f" {service}: {status}") print(f" {service}: {status}")
def test_empty_search_protection(verbose): def test_empty_search_protection(verbose):
"""Test empty query protection (should not return 500 error)""" """Test empty query protection (should not return 500 error)"""
payload = { payload = {"query": "", "user_id": TEST_USER, "limit": 5}
"query": "",
"user_id": TEST_USER,
"limit": 5
}
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=10) response = requests.post(
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Empty Search") log_response(response, verbose, "Empty Search")
assert response.status_code == 200, f"Empty query failed with {response.status_code}" assert response.status_code == 200, (
f"Empty query failed with {response.status_code}"
)
data = response.json() data = response.json()
assert data["memories"] == [], "Empty query should return empty memories list" assert data["memories"] == [], "Empty query should return empty memories list"
@ -158,25 +183,39 @@ def test_empty_search_protection(verbose):
print(f" Empty search note: {data['note']}") print(f" Empty search note: {data['note']}")
print(f" Total count: {data.get('total_count', 0)}") print(f" Total count: {data.get('total_count', 0)}")
def test_add_memories_with_hierarchy(verbose): def test_add_memories_with_hierarchy(verbose):
"""Test adding memories with multi-level hierarchy support""" """Test adding memories with multi-level hierarchy support"""
payload = { payload = {
"messages": [ "messages": [
{"role": "user", "content": "I work at TechCorp as a Senior Software Engineer"}, {
{"role": "user", "content": "My colleague Sarah from Marketing team helped with Q3 presentation"}, "role": "user",
{"role": "user", "content": "Meeting with John the Product Manager tomorrow about new feature development"} "content": "I work at TechCorp as a Senior Software Engineer",
},
{
"role": "user",
"content": "My colleague Sarah from Marketing team helped with Q3 presentation",
},
{
"role": "user",
"content": "Meeting with John the Product Manager tomorrow about new feature development",
},
], ],
"user_id": TEST_USER, "user_id": TEST_USER,
"agent_id": "test_agent", "agent_id": "test_agent",
"run_id": "test_run_001", "run_id": "test_run_001",
"session_id": "test_session_001", "session_id": "test_session_001",
"metadata": {"test": "integration", "scenario": "work_context"} "metadata": {"test": "integration", "scenario": "work_context"},
} }
response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60) response = requests.post(
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60
)
log_response(response, verbose, "Add Memories") log_response(response, verbose, "Add Memories")
assert response.status_code == 200, f"Add memories failed with {response.status_code}" assert response.status_code == 200, (
f"Add memories failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "added_memories" in data, "Response missing 'added_memories'" assert "added_memories" in data, "Response missing 'added_memories'"
@ -191,23 +230,26 @@ def test_add_memories_with_hierarchy(verbose):
relations = first_memory["relations"] relations = first_memory["relations"]
if "added_entities" in relations and relations["added_entities"]: if "added_entities" in relations and relations["added_entities"]:
if verbose: if verbose:
print(f" Graph extracted: {len(relations['added_entities'])} relationships") print(
f" Graph extracted: {len(relations['added_entities'])} relationships"
)
print(f" Sample relations: {relations['added_entities'][:3]}") print(f" Sample relations: {relations['added_entities'][:3]}")
if verbose: if verbose:
print(f" Added {len(memories)} memory blocks") print(f" Added {len(memories)} memory blocks")
print(f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001") print(
f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001"
)
def test_search_memories_basic(verbose): def test_search_memories_basic(verbose):
"""Test basic memory search functionality""" """Test basic memory search functionality"""
# Test meaningful search # Test meaningful search
payload = { payload = {"query": "TechCorp", "user_id": TEST_USER, "limit": 10}
"query": "TechCorp",
"user_id": TEST_USER,
"limit": 10
}
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15) response = requests.post(
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Search") log_response(response, verbose, "Search")
assert response.status_code == 200, f"Search failed with {response.status_code}" assert response.status_code == 200, f"Search failed with {response.status_code}"
@ -232,6 +274,7 @@ def test_search_memories_basic(verbose):
print(f" Found {data['total_count']} memories") print(f" Found {data['total_count']} memories")
print(f" First memory: {memory['memory'][:50]}...") print(f" First memory: {memory['memory'][:50]}...")
def test_search_memories_hierarchy_filters(verbose): def test_search_memories_hierarchy_filters(verbose):
"""Test multi-level hierarchy filtering in search""" """Test multi-level hierarchy filtering in search"""
# Test with hierarchy filters # Test with hierarchy filters
@ -241,13 +284,17 @@ def test_search_memories_hierarchy_filters(verbose):
"agent_id": "test_agent", "agent_id": "test_agent",
"run_id": "test_run_001", "run_id": "test_run_001",
"session_id": "test_session_001", "session_id": "test_session_001",
"limit": 10 "limit": 10,
} }
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15) response = requests.post(
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Hierarchy Search") log_response(response, verbose, "Hierarchy Search")
assert response.status_code == 200, f"Hierarchy search failed with {response.status_code}" assert response.status_code == 200, (
f"Hierarchy search failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "memories" in data, "Hierarchy search response missing 'memories'" assert "memories" in data, "Hierarchy search response missing 'memories'"
@ -257,7 +304,10 @@ def test_search_memories_hierarchy_filters(verbose):
if verbose: if verbose:
print(f" Found {len(data['memories'])} memories with hierarchy filters") print(f" Found {len(data['memories'])} memories with hierarchy filters")
print(f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001") print(
f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001"
)
def test_get_user_memories_with_hierarchy(verbose): def test_get_user_memories_with_hierarchy(verbose):
"""Test retrieving user memories with hierarchy filtering""" """Test retrieving user memories with hierarchy filtering"""
@ -266,13 +316,20 @@ def test_get_user_memories_with_hierarchy(verbose):
"limit": 20, "limit": 20,
"agent_id": "test_agent", "agent_id": "test_agent",
"run_id": "test_run_001", "run_id": "test_run_001",
"session_id": "test_session_001" "session_id": "test_session_001",
} }
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}", params=params, timeout=15) response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}",
params=params,
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Get User Memories with Hierarchy") log_response(response, verbose, "Get User Memories with Hierarchy")
assert response.status_code == 200, f"Get user memories with hierarchy failed with {response.status_code}" assert response.status_code == 200, (
f"Get user memories with hierarchy failed with {response.status_code}"
)
memories = response.json() memories = response.json()
assert isinstance(memories, list), "User memories should return a list" assert isinstance(memories, list), "User memories should return a list"
@ -290,10 +347,13 @@ def test_get_user_memories_with_hierarchy(verbose):
if verbose: if verbose:
print(" No memories found with hierarchy filters (may be expected)") print(" No memories found with hierarchy filters (may be expected)")
def test_memory_history(verbose): def test_memory_history(verbose):
"""Test memory history endpoint""" """Test memory history endpoint"""
# First get a memory to check history for # First get a memory to check history for
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10) response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for history test" assert response.status_code == 200, "Failed to get memory for history test"
memories = response.json() memories = response.json()
@ -305,27 +365,38 @@ def test_memory_history(verbose):
memory_id = memories[0]["id"] memory_id = memories[0]["id"]
# Test memory history endpoint # Test memory history endpoint
response = requests.get(f"{BASE_URL}/memories/{memory_id}/history", timeout=15) response = requests.get(
f"{BASE_URL}/memories/{memory_id}/history?user_id={TEST_USER}",
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Memory History") log_response(response, verbose, "Memory History")
assert response.status_code == 200, f"Memory history failed with {response.status_code}" assert response.status_code == 200, (
f"Memory history failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "memory_id" in data, "History response missing 'memory_id'" assert "memory_id" in data, "History response missing 'memory_id'"
assert "history" in data, "History response missing 'history'" assert "history" in data, "History response missing 'history'"
assert "message" in data, "History response missing success message" assert "message" in data, "History response missing success message"
assert data["memory_id"] == memory_id, f"Wrong memory_id in response: {data['memory_id']}" assert data["memory_id"] == memory_id, (
f"Wrong memory_id in response: {data['memory_id']}"
)
if verbose: if verbose:
print(f" Retrieved history for memory {memory_id}") print(f" Retrieved history for memory {memory_id}")
print(f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}") print(
f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}"
)
def test_update_memory(verbose): def test_update_memory(verbose):
"""Test updating a specific memory""" """Test updating a specific memory"""
# First get a memory to update # First get a memory to update
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10) response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for update test" assert response.status_code == 200, "Failed to get memory for update test"
memories = response.json() memories = response.json()
@ -337,10 +408,13 @@ def test_update_memory(verbose):
# Update the memory # Update the memory
payload = { payload = {
"memory_id": memory_id, "memory_id": memory_id,
"content": f"UPDATED: {original_content}" "user_id": TEST_USER,
"content": f"UPDATED: {original_content}",
} }
response = requests.put(f"{BASE_URL}/memories", json=payload, timeout=10) response = requests.put(
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Update") log_response(response, verbose, "Update")
assert response.status_code == 200, f"Update failed with {response.status_code}" assert response.status_code == 200, f"Update failed with {response.status_code}"
@ -352,15 +426,15 @@ def test_update_memory(verbose):
print(f" Updated memory {memory_id}") print(f" Updated memory {memory_id}")
print(f" Original: {original_content[:30]}...") print(f" Original: {original_content[:30]}...")
def test_chat_with_memory(verbose): def test_chat_with_memory(verbose):
"""Test memory-enhanced chat functionality""" """Test memory-enhanced chat functionality"""
payload = { payload = {"message": "What company do I work for?", "user_id": TEST_USER}
"message": "What company do I work for?",
"user_id": TEST_USER
}
try: try:
response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=90) response = requests.post(
f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=90
)
log_response(response, verbose, "Chat") log_response(response, verbose, "Chat")
assert response.status_code == 200, f"Chat failed with {response.status_code}" assert response.status_code == 200, f"Chat failed with {response.status_code}"
@ -383,12 +457,15 @@ def test_chat_with_memory(verbose):
print(" Chat endpoint timed out (LLM API may be slow)") print(" Chat endpoint timed out (LLM API may be slow)")
# Still test that the endpoint exists and accepts requests # Still test that the endpoint exists and accepts requests
try: try:
response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=5) response = requests.post(
f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=5
)
except requests.exceptions.ReadTimeout: except requests.exceptions.ReadTimeout:
# This is expected - endpoint exists but processing is slow # This is expected - endpoint exists but processing is slow
if verbose: if verbose:
print(" Chat endpoint confirmed active (processing timeout expected)") print(" Chat endpoint confirmed active (processing timeout expected)")
def test_graph_relationships_creation(verbose): def test_graph_relationships_creation(verbose):
"""Test graph relationships creation with entity-rich memories""" """Test graph relationships creation with entity-rich memories"""
# Create a separate test user for graph relationship testing # Create a separate test user for graph relationship testing
@ -397,41 +474,67 @@ def test_graph_relationships_creation(verbose):
# Add memories with clear entity relationships # Add memories with clear entity relationships
payload = { payload = {
"messages": [ "messages": [
{"role": "user", "content": "John Smith works at Microsoft as a Senior Software Engineer"}, {
{"role": "user", "content": "John Smith is friends with Sarah Johnson who works at Google"}, "role": "user",
{"role": "user", "content": "Sarah Johnson lives in Seattle and loves hiking"}, "content": "John Smith works at Microsoft as a Senior Software Engineer",
},
{
"role": "user",
"content": "John Smith is friends with Sarah Johnson who works at Google",
},
{
"role": "user",
"content": "Sarah Johnson lives in Seattle and loves hiking",
},
{"role": "user", "content": "Microsoft is located in Redmond, Washington"}, {"role": "user", "content": "Microsoft is located in Redmond, Washington"},
{"role": "user", "content": "John Smith and Sarah Johnson both graduated from Stanford University"} {
"role": "user",
"content": "John Smith and Sarah Johnson both graduated from Stanford University",
},
], ],
"user_id": graph_test_user, "user_id": graph_test_user,
"metadata": {"test": "graph_relationships", "scenario": "entity_creation"} "metadata": {"test": "graph_relationships", "scenario": "entity_creation"},
} }
response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60) response = requests.post(
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60
)
log_response(response, verbose, "Add Graph Memories") log_response(response, verbose, "Add Graph Memories")
assert response.status_code == 200, f"Add graph memories failed with {response.status_code}" assert response.status_code == 200, (
f"Add graph memories failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "added_memories" in data, "Response missing 'added_memories'" assert "added_memories" in data, "Response missing 'added_memories'"
if verbose: if verbose:
print(f" Added {len(data['added_memories'])} memories for graph relationship testing") print(
f" Added {len(data['added_memories'])} memories for graph relationship testing"
)
# Wait a moment for graph processing (Mem0 graph extraction can be async) # Wait a moment for graph processing (Mem0 graph extraction can be async)
time.sleep(2) time.sleep(2)
# Test graph relationships endpoint # Test graph relationships endpoint
response = requests.get(f"{BASE_URL}/graph/relationships/{graph_test_user}", timeout=15) response = requests.get(
f"{BASE_URL}/graph/relationships/{graph_test_user}",
headers=AUTH_HEADERS,
timeout=15,
)
log_response(response, verbose, "Graph Relationships") log_response(response, verbose, "Graph Relationships")
assert response.status_code == 200, f"Graph relationships failed with {response.status_code}" assert response.status_code == 200, (
f"Graph relationships failed with {response.status_code}"
)
graph_data = response.json() graph_data = response.json()
assert "relationships" in graph_data, "Graph response missing 'relationships'" assert "relationships" in graph_data, "Graph response missing 'relationships'"
assert "entities" in graph_data, "Graph response missing 'entities'" assert "entities" in graph_data, "Graph response missing 'entities'"
assert "user_id" in graph_data, "Graph response missing 'user_id'" assert "user_id" in graph_data, "Graph response missing 'user_id'"
assert graph_data["user_id"] == graph_test_user, f"Wrong user_id in graph: {graph_data['user_id']}" assert graph_data["user_id"] == graph_test_user, (
f"Wrong user_id in graph: {graph_data['user_id']}"
)
relationships = graph_data["relationships"] relationships = graph_data["relationships"]
entities = graph_data["entities"] entities = graph_data["entities"]
@ -447,20 +550,28 @@ def test_graph_relationships_creation(verbose):
source = rel.get("source", "unknown") source = rel.get("source", "unknown")
target = rel.get("target", "unknown") target = rel.get("target", "unknown")
relationship = rel.get("relationship", "unknown") relationship = rel.get("relationship", "unknown")
print(f" {i+1}. {source} --{relationship}--> {target}") print(f" {i + 1}. {source} --{relationship}--> {target}")
# Print sample entities if they exist # Print sample entities if they exist
if entities: if entities:
print(f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}") print(
f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}"
)
# Verify relationship structure (if relationships exist) # Verify relationship structure (if relationships exist)
for rel in relationships: for rel in relationships:
assert "source" in rel or "from" in rel, f"Relationship missing source/from: {rel}" assert "source" in rel or "from" in rel, (
f"Relationship missing source/from: {rel}"
)
assert "target" in rel or "to" in rel, f"Relationship missing target/to: {rel}" assert "target" in rel or "to" in rel, f"Relationship missing target/to: {rel}"
assert "relationship" in rel or "type" in rel, f"Relationship missing type: {rel}" assert "relationship" in rel or "type" in rel, (
f"Relationship missing type: {rel}"
)
# Clean up graph test user memories # Clean up graph test user memories
cleanup_response = requests.delete(f"{BASE_URL}/memories/user/{graph_test_user}", timeout=15) cleanup_response = requests.delete(
f"{BASE_URL}/memories/user/{graph_test_user}", headers=AUTH_HEADERS, timeout=15
)
assert cleanup_response.status_code == 200, "Failed to cleanup graph test memories" assert cleanup_response.status_code == 200, "Failed to cleanup graph test memories"
if verbose: if verbose:
@ -469,12 +580,17 @@ def test_graph_relationships_creation(verbose):
# Note: We expect some relationships even if graph extraction is basic # Note: We expect some relationships even if graph extraction is basic
# The test passes if the endpoint works and returns proper structure # The test passes if the endpoint works and returns proper structure
def test_graph_relationships(verbose): def test_graph_relationships(verbose):
"""Test graph relationships endpoint""" """Test graph relationships endpoint"""
response = requests.get(f"{BASE_URL}/graph/relationships/{TEST_USER}", timeout=15) response = requests.get(
f"{BASE_URL}/graph/relationships/{TEST_USER}", headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Graph") log_response(response, verbose, "Graph")
assert response.status_code == 200, f"Graph endpoint failed with {response.status_code}" assert response.status_code == 200, (
f"Graph endpoint failed with {response.status_code}"
)
data = response.json() data = response.json()
assert "relationships" in data, "Graph response missing 'relationships'" assert "relationships" in data, "Graph response missing 'relationships'"
@ -486,10 +602,13 @@ def test_graph_relationships(verbose):
print(f" Relationships: {len(data['relationships'])}") print(f" Relationships: {len(data['relationships'])}")
print(f" Entities: {len(data['entities'])}") print(f" Entities: {len(data['entities'])}")
def test_delete_specific_memory(verbose): def test_delete_specific_memory(verbose):
"""Test deleting a specific memory""" """Test deleting a specific memory"""
# Get a memory to delete # Get a memory to delete
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10) response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
)
assert response.status_code == 200, "Failed to get memory for deletion test" assert response.status_code == 200, "Failed to get memory for deletion test"
memories = response.json() memories = response.json()
@ -498,7 +617,9 @@ def test_delete_specific_memory(verbose):
memory_id = memories[0]["id"] memory_id = memories[0]["id"]
# Delete the memory # Delete the memory
response = requests.delete(f"{BASE_URL}/memories/{memory_id}", timeout=10) response = requests.delete(
f"{BASE_URL}/memories/{memory_id}", headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Delete") log_response(response, verbose, "Delete")
assert response.status_code == 200, f"Delete failed with {response.status_code}" assert response.status_code == 200, f"Delete failed with {response.status_code}"
@ -509,9 +630,12 @@ def test_delete_specific_memory(verbose):
if verbose: if verbose:
print(f" Deleted memory {memory_id}") print(f" Deleted memory {memory_id}")
def test_delete_all_user_memories(verbose): def test_delete_all_user_memories(verbose):
"""Test deleting all memories for a user""" """Test deleting all memories for a user"""
response = requests.delete(f"{BASE_URL}/memories/user/{TEST_USER}", timeout=15) response = requests.delete(
f"{BASE_URL}/memories/user/{TEST_USER}", headers=AUTH_HEADERS, timeout=15
)
log_response(response, verbose, "Delete All") log_response(response, verbose, "Delete All")
assert response.status_code == 200, f"Delete all failed with {response.status_code}" assert response.status_code == 200, f"Delete all failed with {response.status_code}"
@ -522,12 +646,17 @@ def test_delete_all_user_memories(verbose):
if verbose: if verbose:
print(f"Deleted all memories for {TEST_USER}") print(f"Deleted all memories for {TEST_USER}")
def test_cleanup_verification(verbose): def test_cleanup_verification(verbose):
"""Verify cleanup was successful""" """Verify cleanup was successful"""
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=10", timeout=10) response = requests.get(
f"{BASE_URL}/memories/{TEST_USER}?limit=10", headers=AUTH_HEADERS, timeout=10
)
log_response(response, verbose, "Cleanup Check") log_response(response, verbose, "Cleanup Check")
assert response.status_code == 200, f"Cleanup verification failed with {response.status_code}" assert response.status_code == 200, (
f"Cleanup verification failed with {response.status_code}"
)
memories = response.json() memories = response.json()
assert isinstance(memories, list), "Should return list even if empty" assert isinstance(memories, list), "Should return list even if empty"
@ -539,5 +668,79 @@ def test_cleanup_verification(verbose):
if verbose: if verbose:
print(" Cleanup successful - no memories remain") print(" Cleanup successful - no memories remain")
# ================== SECURITY TEST FUNCTIONS ==================
def test_auth_required_endpoints(verbose):
"""Test that protected endpoints require authentication"""
endpoints_requiring_auth = [
("GET", f"{BASE_URL}/memories/{TEST_USER}"),
("POST", f"{BASE_URL}/memories/search"),
("GET", f"{BASE_URL}/stats"),
("GET", f"{BASE_URL}/models"),
("GET", f"{BASE_URL}/users"),
]
for method, url in endpoints_requiring_auth:
if method == "GET":
response = requests.get(url, timeout=5)
else:
response = requests.post(
url, json={"query": "test", "user_id": TEST_USER}, timeout=5
)
assert response.status_code in [401, 403], (
f"{method} {url} should require auth, got {response.status_code}"
)
if verbose:
print(f" {method} {url}: {response.status_code} (auth required)")
def test_ownership_verification(verbose):
"""Test that users can only access their own data"""
other_user = "other_user_not_me"
response = requests.get(
f"{BASE_URL}/memories/{other_user}", headers=AUTH_HEADERS, timeout=5
)
assert response.status_code in [403, 404], (
f"Accessing other user's memories should be denied, got {response.status_code}"
)
if verbose:
print(f" Ownership check passed: {response.status_code}")
def test_request_size_limit(verbose):
"""Test request size limit enforcement (10MB max)"""
large_payload = {
"messages": [{"role": "user", "content": "x" * (11 * 1024 * 1024)}],
"user_id": TEST_USER,
}
try:
response = requests.post(
f"{BASE_URL}/memories",
json=large_payload,
headers={**AUTH_HEADERS, "Content-Length": str(11 * 1024 * 1024)},
timeout=5,
)
assert response.status_code == 413, (
f"Large request should return 413, got {response.status_code}"
)
if verbose:
print(f" Request size limit enforced: {response.status_code}")
except requests.exceptions.RequestException as e:
if verbose:
print(
f" Request size limit test: connection issue (expected for large payload)"
)
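The middleware enforcing the 10MB cap is likewise outside this diff; one common way to return 413 early is a Content-Length check in a Starlette middleware (sketch only; a real implementation may also guard streamed bodies without a declared length):

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse

MAX_REQUEST_BYTES = 10 * 1024 * 1024  # 10MB

class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Reject requests whose declared Content-Length exceeds the cap."""

    async def dispatch(self, request: Request, call_next):
        content_length = request.headers.get("content-length")
        if content_length and int(content_length) > MAX_REQUEST_BYTES:
            return JSONResponse(status_code=413, content={"error": "Request too large"})
        return await call_next(request)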
if __name__ == "__main__": if __name__ == "__main__":
main() main()