diff --git a/backend/main.py b/backend/main.py index 3a6d6db..5277981 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,22 +1,46 @@ """Main FastAPI application for Mem0 Interface POC.""" +import json import logging import time from datetime import datetime from typing import List, Dict, Any, Optional from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security +from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse import structlog +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded from config import settings + + +# Rate limiter - uses IP address as key, falls back to API key for authenticated requests +def get_rate_limit_key(request: Request) -> str: + """Get rate limit key - prefer API key if available, otherwise IP.""" + api_key = request.headers.get("x-api-key", "") + if api_key: + return f"apikey:{api_key[:16]}" # Use first 16 chars of API key + return get_remote_address(request) + + +limiter = Limiter(key_func=get_rate_limit_key) from models import ( - ChatRequest, MemoryAddRequest, MemoryAddResponse, - MemorySearchRequest, MemorySearchResponse, MemoryUpdateRequest, - MemoryItem, GraphResponse, HealthResponse, ErrorResponse, - GlobalStatsResponse, UserStatsResponse + ChatRequest, + MemoryAddRequest, + MemoryAddResponse, + MemorySearchRequest, + MemorySearchResponse, + MemoryUpdateRequest, + MemoryItem, + GraphResponse, + HealthResponse, + ErrorResponse, + GlobalStatsResponse, + UserStatsResponse, ) from mem0_manager import mem0_manager from auth import get_current_user, auth_service @@ -32,7 +56,7 @@ structlog.configure( structlog.processors.StackInfoRenderer(), structlog.processors.format_exc_info, structlog.processors.UnicodeDecoder(), - structlog.processors.JSONRenderer() + structlog.processors.JSONRenderer(), ], context_class=dict, logger_factory=structlog.stdlib.LoggerFactory(), @@ -62,6 +86,7 @@ async def lifespan(app: FastAPI): mcp_context = None try: from mcp_server import mcp_lifespan + mcp_context = mcp_lifespan() await mcp_context.__aenter__() except ImportError: @@ -86,63 +111,91 @@ app = FastAPI( title="Mem0 Interface POC", description="Minimal but fully functional Mem0 interface with PostgreSQL and Neo4j integration", version="1.0.0", - lifespan=lifespan + lifespan=lifespan, ) -# Add CORS middleware - Allow all origins for development +# Add rate limiter to app state and exception handler +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + +# Add CORS middleware - Allow all origins (secured via API key auth) app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Allow all origins for development + allow_origins=["*"], # Allow all origins - secured via API key authentication allow_credentials=False, # Must be False when allow_origins=["*"] allow_methods=["*"], allow_headers=["*"], ) +# Request size limit middleware - prevent DoS via large payloads +MAX_REQUEST_SIZE = 10 * 1024 * 1024 # 10MB limit + + +@app.middleware("http") +async def limit_request_size(request, call_next): + """Reject requests that exceed the maximum allowed size.""" + content_length = request.headers.get("content-length") + if content_length: + try: + if int(content_length) > MAX_REQUEST_SIZE: + return JSONResponse( + status_code=413, + content={ + "error": "Request payload too large", + "max_size_bytes": MAX_REQUEST_SIZE, + "max_size_mb": MAX_REQUEST_SIZE / (1024 * 1024), + }, + ) + except ValueError: + pass # Invalid content-length header, let it through for other validation + return await call_next(request) + + # Request logging middleware with monitoring @app.middleware("http") async def log_requests(request, call_next): """Log all HTTP requests with correlation ID and timing.""" from monitoring import generate_correlation_id, stats - + correlation_id = generate_correlation_id() start_time = time.time() - + # Extract user_id from request if available user_id = None if request.method == "POST": - # Try to extract user_id from request body for POST requests try: body = await request.body() if body: - import json data = json.loads(body) - user_id = data.get('user_id') - except: - pass + user_id = data.get("user_id") + except json.JSONDecodeError: + pass # Non-JSON body, user_id extraction not possible + except Exception as e: + logger.debug("Could not extract user_id from request body", error=str(e)) elif "user_id" in str(request.url.path): # Extract user_id from path for GET requests - path_parts = request.url.path.split('/') - if len(path_parts) > 2 and path_parts[-2] in ['memories', 'stats']: + path_parts = request.url.path.split("/") + if len(path_parts) > 2 and path_parts[-2] in ["memories", "stats"]: user_id = path_parts[-1] - + # Log start of request logger.info( "HTTP request started", correlation_id=correlation_id, method=request.method, path=request.url.path, - user_id=user_id + user_id=user_id, ) - + response = await call_next(request) - + process_time = time.time() - start_time process_time_ms = process_time * 1000 - + # Record statistics stats.record_api_call(user_id, process_time_ms) - + # Log completion with enhanced details if process_time_ms > 2000: # Slow request threshold logger.warning( @@ -153,7 +206,7 @@ async def log_requests(request, call_next): status_code=response.status_code, process_time_ms=round(process_time_ms, 2), user_id=user_id, - slow_request=True + slow_request=True, ) elif response.status_code >= 400: logger.error( @@ -164,7 +217,7 @@ async def log_requests(request, call_next): status_code=response.status_code, process_time_ms=round(process_time_ms, 2), user_id=user_id, - slow_request=False + slow_request=False, ) else: logger.info( @@ -175,20 +228,32 @@ async def log_requests(request, call_next): status_code=response.status_code, process_time_ms=round(process_time_ms, 2), user_id=user_id, - slow_request=False + slow_request=False, ) - + return response # Exception handlers @app.exception_handler(Exception) async def global_exception_handler(request, exc): - """Global exception handler.""" - logger.error(f"Unhandled exception: {exc}", exc_info=True) + """Global exception handler - logs details but returns generic message.""" + # Log full exception details for debugging (internal only) + logger.error( + "Unhandled exception", + exc_info=True, + path=request.url.path, + method=request.method, + error_type=type(exc).__name__, + error_message=str(exc), + ) + # Return generic error to client - don't expose internal details return JSONResponse( status_code=500, - content={"error": "Internal server error", "detail": str(exc)} + content={ + "error": "Internal server error", + "message": "An unexpected error occurred", + }, ) @@ -198,50 +263,59 @@ async def health_check(): """Check the health of all services.""" try: services = await mem0_manager.health_check() - overall_status = "healthy" if all("healthy" in status for status in services.values()) else "degraded" - + overall_status = ( + "healthy" + if all("healthy" in status for status in services.values()) + else "degraded" + ) + return HealthResponse( status=overall_status, services=services, - timestamp=datetime.utcnow().isoformat() + timestamp=datetime.utcnow().isoformat(), ) except Exception as e: - logger.error(f"Health check failed: {e}") + logger.error(f"Health check failed: {e}", exc_info=True) return HealthResponse( status="unhealthy", - services={"error": str(e)}, - timestamp=datetime.utcnow().isoformat() + services={"error": "Health check failed - see logs for details"}, + timestamp=datetime.utcnow().isoformat(), ) # Core chat endpoint with memory enhancement @app.post("/chat") +@limiter.limit("30/minute") # Chat is expensive - limit to 30/min async def chat_with_memory( - request: ChatRequest, - authenticated_user: str = Depends(get_current_user) + request: Request, + chat_request: ChatRequest, + authenticated_user: str = Depends(get_current_user), ): """Ultra-minimal chat endpoint - pure Mem0 + custom endpoint.""" try: # Verify user can only access their own data - if authenticated_user != request.user_id: + if authenticated_user != chat_request.user_id: raise HTTPException( status_code=403, - detail=f"Access denied: You can only chat as yourself (authenticated as '{authenticated_user}')" + detail=f"Access denied: You can only chat as yourself (authenticated as '{authenticated_user}')", ) - logger.info(f"Processing chat request for user: {request.user_id}") + logger.info(f"Processing chat request for user: {chat_request.user_id}") # Convert ChatMessage objects to dict format if context provided context_dict = None - if request.context: - context_dict = [{"role": msg.role, "content": msg.content} for msg in request.context] + if chat_request.context: + context_dict = [ + {"role": msg.role, "content": msg.content} + for msg in chat_request.context + ] result = await mem0_manager.chat_with_memory( - message=request.message, - user_id=request.user_id, - agent_id=request.agent_id, - run_id=request.run_id, - context=context_dict + message=chat_request.message, + user_id=chat_request.user_id, + agent_id=chat_request.agent_id, + run_id=chat_request.run_id, + context=context_dict, ) return result @@ -250,32 +324,37 @@ async def chat_with_memory( raise except Exception as e: logger.error(f"Error in chat endpoint: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail="An internal error occurred. Please try again later.", + ) # Memory management endpoints - pure Mem0 passthroughs @app.post("/memories") +@limiter.limit("60/minute") # Memory operations - 60/min async def add_memories( - request: MemoryAddRequest, - authenticated_user: str = Depends(get_current_user) + request: Request, + memory_request: MemoryAddRequest, + authenticated_user: str = Depends(get_current_user), ): """Add memories - pure Mem0 passthrough.""" try: # Verify user can only add to their own memories - if authenticated_user != request.user_id: + if authenticated_user != memory_request.user_id: raise HTTPException( status_code=403, - detail=f"Access denied: You can only add memories for yourself (authenticated as '{authenticated_user}')" + detail=f"Access denied: You can only add memories for yourself (authenticated as '{authenticated_user}')", ) - logger.info(f"Adding memories for user: {request.user_id}") + logger.info(f"Adding memories for user: {memory_request.user_id}") result = await mem0_manager.add_memories( - messages=request.messages, - user_id=request.user_id, - agent_id=request.agent_id, - run_id=request.run_id, - metadata=request.metadata + messages=memory_request.messages, + user_id=memory_request.user_id, + agent_id=memory_request.agent_id, + run_id=memory_request.run_id, + metadata=memory_request.metadata, ) return result @@ -284,33 +363,40 @@ async def add_memories( raise except Exception as e: logger.error(f"Error adding memories: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail="An internal error occurred. Please try again later.", + ) @app.post("/memories/search") +@limiter.limit("120/minute") # Search is lighter - 120/min async def search_memories( - request: MemorySearchRequest, - authenticated_user: str = Depends(get_current_user) + request: Request, + search_request: MemorySearchRequest, + authenticated_user: str = Depends(get_current_user), ): """Search memories - pure Mem0 passthrough.""" try: # Verify user can only search their own memories - if authenticated_user != request.user_id: + if authenticated_user != search_request.user_id: raise HTTPException( status_code=403, - detail=f"Access denied: You can only search your own memories (authenticated as '{authenticated_user}')" + detail=f"Access denied: You can only search your own memories (authenticated as '{authenticated_user}')", ) - logger.info(f"Searching memories for user: {request.user_id}, query: {request.query}") + logger.info( + f"Searching memories for user: {search_request.user_id}, query: {search_request.query}" + ) result = await mem0_manager.search_memories( - query=request.query, - user_id=request.user_id, - limit=request.limit, - threshold=request.threshold or 0.2, - filters=request.filters, - agent_id=request.agent_id, - run_id=request.run_id + query=search_request.query, + user_id=search_request.user_id, + limit=search_request.limit, + threshold=search_request.threshold or 0.2, + filters=search_request.filters, + agent_id=search_request.agent_id, + run_id=search_request.run_id, ) return result @@ -319,16 +405,21 @@ async def search_memories( raise except Exception as e: logger.error(f"Error searching memories: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail="An internal error occurred. Please try again later.", + ) @app.get("/memories/{user_id}") +@limiter.limit("120/minute") async def get_user_memories( + request: Request, user_id: str, authenticated_user: str = Depends(get_current_user), limit: int = 10, agent_id: Optional[str] = None, - run_id: Optional[str] = None + run_id: Optional[str] = None, ): """Get all memories for a user with hierarchy filtering - pure Mem0 passthrough.""" try: @@ -336,16 +427,13 @@ async def get_user_memories( if authenticated_user != user_id: raise HTTPException( status_code=403, - detail=f"Access denied: You can only retrieve your own memories (authenticated as '{authenticated_user}')" + detail=f"Access denied: You can only retrieve your own memories (authenticated as '{authenticated_user}')", ) logger.info(f"Retrieving memories for user: {user_id}") memories = await mem0_manager.get_user_memories( - user_id=user_id, - limit=limit, - agent_id=agent_id, - run_id=run_id + user_id=user_id, limit=limit, agent_id=agent_id, run_id=run_id ) return memories @@ -354,28 +442,47 @@ async def get_user_memories( raise except Exception as e: logger.error(f"Error retrieving user memories: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail="An internal error occurred. Please try again later.", + ) @app.put("/memories") +@limiter.limit("60/minute") async def update_memory( - request: MemoryUpdateRequest, - authenticated_user: str = Depends(get_current_user) + request: Request, + update_request: MemoryUpdateRequest, + authenticated_user: str = Depends(get_current_user), ): - """Update memory - pure Mem0 passthrough.""" + """Update memory - verifies ownership before update.""" try: # Verify user owns the memory being updated - if authenticated_user != request.user_id: + if authenticated_user != update_request.user_id: raise HTTPException( status_code=403, - detail=f"Access denied: You can only update your own memories (authenticated as '{authenticated_user}')" + detail=f"Access denied: You can only update your own memories (authenticated as '{authenticated_user}')", ) - logger.info(f"Updating memory: {request.memory_id}") + # Verify the memory actually belongs to the authenticated user + user_memories = await mem0_manager.get_user_memories( + user_id=authenticated_user, limit=10000 + ) + memory_ids = {m.get("id") for m in user_memories if m.get("id")} + + if update_request.memory_id not in memory_ids: + raise HTTPException( + status_code=404, + detail=f"Memory '{update_request.memory_id}' not found or access denied", + ) + + logger.info( + f"Updating memory: {update_request.memory_id}", user_id=authenticated_user + ) result = await mem0_manager.update_memory( - memory_id=request.memory_id, - content=request.content, + memory_id=update_request.memory_id, + content=update_request.content, ) return result @@ -384,25 +491,34 @@ async def update_memory( raise except Exception as e: logger.error(f"Error updating memory: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail="An internal error occurred. Please try again later.", + ) @app.delete("/memories/{memory_id}") +@limiter.limit("60/minute") async def delete_memory( + request: Request, memory_id: str, - user_id: str, # Add user_id as query parameter for verification - authenticated_user: str = Depends(get_current_user) + authenticated_user: str = Depends(get_current_user), ): - """Delete a specific memory.""" + """Delete a specific memory - verifies ownership before deletion.""" try: - # Verify user owns the memory being deleted - if authenticated_user != user_id: + # Verify the memory actually belongs to the authenticated user + user_memories = await mem0_manager.get_user_memories( + user_id=authenticated_user, limit=10000 + ) + memory_ids = {m.get("id") for m in user_memories if m.get("id")} + + if memory_id not in memory_ids: raise HTTPException( - status_code=403, - detail=f"Access denied: You can only delete your own memories (authenticated as '{authenticated_user}')" + status_code=404, + detail=f"Memory '{memory_id}' not found or access denied", ) - logger.info(f"Deleting memory: {memory_id}") + logger.info(f"Deleting memory: {memory_id}", user_id=authenticated_user) result = await mem0_manager.delete_memory(memory_id=memory_id) @@ -412,13 +528,16 @@ async def delete_memory( raise except Exception as e: logger.error(f"Error deleting memory: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail="An internal error occurred. Please try again later.", + ) @app.delete("/memories/user/{user_id}") +@limiter.limit("10/minute") # Dangerous bulk delete - heavily rate limited async def delete_user_memories( - user_id: str, - authenticated_user: str = Depends(get_current_user) + request: Request, user_id: str, authenticated_user: str = Depends(get_current_user) ): """Delete all memories for a specific user.""" try: @@ -426,7 +545,7 @@ async def delete_user_memories( if authenticated_user != user_id: raise HTTPException( status_code=403, - detail=f"Access denied: You can only delete your own memories (authenticated as '{authenticated_user}')" + detail=f"Access denied: You can only delete your own memories (authenticated as '{authenticated_user}')", ) logger.info(f"Deleting all memories for user: {user_id}") @@ -439,14 +558,17 @@ async def delete_user_memories( raise except Exception as e: logger.error(f"Error deleting user memories: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail="An internal error occurred. Please try again later.", + ) # Graph relationships endpoint - pure Mem0 passthrough @app.get("/graph/relationships/{user_id}") +@limiter.limit("60/minute") async def get_graph_relationships( - user_id: str, - authenticated_user: str = Depends(get_current_user) + request: Request, user_id: str, authenticated_user: str = Depends(get_current_user) ): """Get graph relationships - pure Mem0 passthrough.""" try: @@ -454,11 +576,13 @@ async def get_graph_relationships( if authenticated_user != user_id: raise HTTPException( status_code=403, - detail=f"Access denied: You can only view your own relationships (authenticated as '{authenticated_user}')" + detail=f"Access denied: You can only view your own relationships (authenticated as '{authenticated_user}')", ) logger.info(f"Retrieving graph relationships for user: {user_id}") - result = await mem0_manager.get_graph_relationships(user_id=user_id, agent_id=None, run_id=None, limit=10000) + result = await mem0_manager.get_graph_relationships( + user_id=user_id, agent_id=None, run_id=None, limit=10000 + ) return result @@ -466,68 +590,106 @@ async def get_graph_relationships( raise except Exception as e: logger.error(f"Error retrieving graph relationships: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail="An internal error occurred. Please try again later.", + ) # Memory history endpoint - new feature @app.get("/memories/{memory_id}/history") -async def get_memory_history(memory_id: str): +@limiter.limit("120/minute") +async def get_memory_history( + request: Request, + memory_id: str, + user_id: str, # Required query param to verify ownership + authenticated_user: str = Depends(get_current_user), +): """Get memory change history - pure Mem0 passthrough.""" try: - logger.info(f"Retrieving history for memory: {memory_id}") - + # Verify user can only access their own memory history + if authenticated_user != user_id: + raise HTTPException( + status_code=403, + detail=f"Access denied: You can only view your own memory history (authenticated as '{authenticated_user}')", + ) + + # Verify the memory belongs to this user before returning history + user_memories = await mem0_manager.get_user_memories( + user_id=user_id, limit=10000 + ) + memory_ids = {m.get("id") for m in user_memories if m.get("id")} + + if memory_id not in memory_ids: + raise HTTPException( + status_code=404, + detail=f"Memory '{memory_id}' not found or access denied", + ) + + logger.info(f"Retrieving history for memory: {memory_id}", user_id=user_id) + result = await mem0_manager.get_memory_history(memory_id=memory_id) - + return result - + + except HTTPException: + raise except Exception as e: logger.error(f"Error retrieving memory history: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail="An internal error occurred. Please try again later.", + ) # Statistics and monitoring endpoints @app.get("/stats", response_model=GlobalStatsResponse) -async def get_global_stats(): - """Get global application statistics.""" +@limiter.limit("60/minute") +async def get_global_stats( + request: Request, authenticated_user: str = Depends(get_current_user) +): + """Get global application statistics - requires authentication.""" try: from monitoring import stats - - # Get basic stats from monitoring + basic_stats = stats.get_global_stats() - - # Get actual memory count from Mem0 (simplified approach) + try: - # This is a rough estimate - in production you might want a more efficient method - sample_result = await mem0_manager.search_memories(query="*", user_id="__stats_check__", limit=1) - # For now, we'll use the basic stats total_memories value - # You could implement a more accurate count by querying the database directly - total_memories = basic_stats['total_memories'] # Will be 0 for now - except: + sample_result = await mem0_manager.search_memories( + query="*", user_id="__stats_check__", limit=1 + ) + total_memories = basic_stats["total_memories"] + except Exception: total_memories = 0 - + return GlobalStatsResponse( total_memories=total_memories, - total_users=basic_stats['total_users'], - api_calls_today=basic_stats['api_calls_today'], - avg_response_time_ms=basic_stats['avg_response_time_ms'], + total_users=basic_stats["total_users"], + api_calls_today=basic_stats["api_calls_today"], + avg_response_time_ms=basic_stats["avg_response_time_ms"], memory_operations={ - "add": basic_stats['memory_operations']['add'], - "search": basic_stats['memory_operations']['search'], - "update": basic_stats['memory_operations']['update'], - "delete": basic_stats['memory_operations']['delete'] + "add": basic_stats["memory_operations"]["add"], + "search": basic_stats["memory_operations"]["search"], + "update": basic_stats["memory_operations"]["update"], + "delete": basic_stats["memory_operations"]["delete"], }, - uptime_seconds=basic_stats['uptime_seconds'] + uptime_seconds=basic_stats["uptime_seconds"], ) - + + except HTTPException: + raise except Exception as e: logger.error(f"Error getting global stats: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail="An internal error occurred. Please try again later.", + ) @app.get("/stats/{user_id}", response_model=UserStatsResponse) +@limiter.limit("120/minute") async def get_user_stats( - user_id: str, - authenticated_user: str = Depends(get_current_user) + request: Request, user_id: str, authenticated_user: str = Depends(get_current_user) ): """Get user-specific statistics.""" try: @@ -535,7 +697,7 @@ async def get_user_stats( if authenticated_user != user_id: raise HTTPException( status_code=403, - detail=f"Access denied: You can only view your own statistics (authenticated as '{authenticated_user}')" + detail=f"Access denied: You can only view your own statistics (authenticated as '{authenticated_user}')", ) from monitoring import stats @@ -545,59 +707,75 @@ async def get_user_stats( # Get actual memory count for this user try: - user_memories = await mem0_manager.get_user_memories(user_id=user_id, limit=10000) + user_memories = await mem0_manager.get_user_memories( + user_id=user_id, limit=10000 + ) memory_count = len(user_memories) - except: + except Exception as e: + logger.warning(f"Failed to get memory count for user {user_id}: {e}") memory_count = 0 # Get relationship count for this user try: - graph_data = await mem0_manager.get_graph_relationships(user_id=user_id, agent_id=None, run_id=None) - relationship_count = len(graph_data.get('relationships', [])) - except: + graph_data = await mem0_manager.get_graph_relationships( + user_id=user_id, agent_id=None, run_id=None + ) + relationship_count = len(graph_data.get("relationships", [])) + except Exception as e: + logger.warning(f"Failed to get relationship count for user {user_id}: {e}") relationship_count = 0 return UserStatsResponse( user_id=user_id, memory_count=memory_count, relationship_count=relationship_count, - last_activity=basic_stats['last_activity'], - api_calls_today=basic_stats['api_calls_today'], - avg_response_time_ms=basic_stats['avg_response_time_ms'] + last_activity=basic_stats["last_activity"], + api_calls_today=basic_stats["api_calls_today"], + avg_response_time_ms=basic_stats["avg_response_time_ms"], ) except HTTPException: raise except Exception as e: logger.error(f"Error getting user stats for {user_id}: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException( + status_code=500, + detail="An internal error occurred. Please try again later.", + ) # Utility endpoints @app.get("/models") -async def get_available_models(): - """Get current model configuration.""" +@limiter.limit("120/minute") +async def get_available_models( + request: Request, authenticated_user: str = Depends(get_current_user) +): + """Get current model configuration - requires authentication.""" return { "current_model": settings.default_model, "endpoint": settings.openai_base_url, - "note": "Using single model with pure Mem0 intelligence" + "note": "Using single model with pure Mem0 intelligence", } @app.get("/users") -async def get_active_users(): - """Get list of users with memories (simplified implementation).""" +@limiter.limit("60/minute") +async def get_active_users( + request: Request, authenticated_user: str = Depends(get_current_user) +): + """Get list of users with memories (simplified implementation) - requires authentication.""" # This would typically query the database for users with memories # For now, return a placeholder return { "message": "This endpoint would return users with stored memories", - "note": "Implementation depends on direct database access or Mem0 user enumeration capabilities" + "note": "Implementation depends on direct database access or Mem0 user enumeration capabilities", } # Mount MCP server at /mcp endpoint try: from mcp_server import create_mcp_app + mcp_app = create_mcp_app() app.mount("/mcp", mcp_app) logger.info("MCP server mounted at /mcp") @@ -609,11 +787,12 @@ except Exception as e: if __name__ == "__main__": import uvicorn + print("Starting UVicorn server...") uvicorn.run( "main:app", host="0.0.0.0", port=8000, log_level=settings.log_level.lower(), - reload=True + reload=True, ) diff --git a/backend/mem0_manager.py b/backend/mem0_manager.py index c3c00c2..8fedb39 100644 --- a/backend/mem0_manager.py +++ b/backend/mem0_manager.py @@ -13,16 +13,23 @@ logger = logging.getLogger(__name__) # Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility from mem0.llms.openai import OpenAILLM + _original_generate_response = OpenAILLM.generate_response -def patched_generate_response(self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs): + +def patched_generate_response( + self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs +): # Remove 'store' parameter as LiteLLM doesn't support it - if hasattr(self.config, 'store'): + if hasattr(self.config, "store"): self.config.store = None # Remove 'top_p' to avoid conflict with temperature for Claude models - if hasattr(self.config, 'top_p'): + if hasattr(self.config, "top_p"): self.config.top_p = None - return _original_generate_response(self, messages, response_format, tools, tool_choice, **kwargs) + return _original_generate_response( + self, messages, response_format, tools, tool_choice, **kwargs + ) + OpenAILLM.generate_response = patched_generate_response logger.info("Applied LiteLLM compatibility patch: disabled 'store' parameter") @@ -33,10 +40,13 @@ class Mem0Manager: Ultra-minimal manager that bridges custom OpenAI endpoint with pure Mem0. No custom logic - let Mem0 handle all memory intelligence. """ - + def __init__(self): # Custom endpoint configuration with graph memory enabled - logger.info("Initializing ultra-minimal Mem0Manager with custom endpoint with settings:", settings) + logger.info( + "Initializing ultra-minimal Mem0Manager with custom endpoint with settings:", + settings, + ) config = { "enable_graph": True, "llm": { @@ -46,8 +56,8 @@ class Mem0Manager: "api_key": settings.openai_api_key, "openai_base_url": settings.openai_base_url, "temperature": 0.1, - "top_p": None # Don't use top_p with Claude models - } + "top_p": None, # Don't use top_p with Claude models + }, }, "embedder": { "provider": "ollama", @@ -55,8 +65,8 @@ class Mem0Manager: "model": "qwen3-embedding:4b-q8_0", # "api_key": settings.embedder_api_key, "ollama_base_url": "http://172.17.0.1:11434", - "embedding_dims": 2560 - } + "embedding_dims": 2560, + }, }, "vector_store": { "provider": "qdrant", @@ -65,36 +75,36 @@ class Mem0Manager: "host": settings.qdrant_host, "port": settings.qdrant_port, "embedding_model_dims": 2560, - "on_disk": True - } + "on_disk": True, + }, }, "graph_store": { "provider": "neo4j", "config": { "url": settings.neo4j_uri, "username": settings.neo4j_username, - "password": settings.neo4j_password - } + "password": settings.neo4j_password, + }, }, "reranker": { "provider": "cohere", "config": { "api_key": settings.cohere_api_key, "model": "rerank-english-v3.0", - "top_n": 10 - } - } + "top_n": 10, + }, + }, } - + self.memory = Memory.from_config(config) self.openai_client = OpenAI( api_key=settings.openai_api_key, - base_url=settings.openai_base_url + base_url=settings.openai_base_url, + timeout=60.0, # 60 second timeout for LLM calls + max_retries=2, # Retry failed requests up to 2 times ) logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint") - - - + # Pure passthrough methods - no custom logic @timed("add_memories") async def add_memories( @@ -103,18 +113,18 @@ class Mem0Manager: user_id: Optional[str] = "default", agent_id: Optional[str] = None, run_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Add memories - simplified native Mem0 pattern (10 lines vs 45).""" try: # Convert ChatMessage objects to dict if needed formatted_messages = [] for msg in messages: - if hasattr(msg, 'dict'): + if hasattr(msg, "dict"): formatted_messages.append(msg.dict()) else: formatted_messages.append(msg) - + # Auto-enhance metadata for better memory quality combined_metadata = metadata or {} @@ -123,26 +133,34 @@ class Mem0Manager: "timestamp": datetime.now().isoformat(), "source": "chat_conversation", "message_count": len(formatted_messages), - "auto_generated": True + "auto_generated": True, } # Merge user metadata with auto metadata (user metadata takes precedence) enhanced_metadata = {**auto_metadata, **combined_metadata} # Direct Mem0 add with enhanced metadata - result = self.memory.add(formatted_messages, user_id=user_id, - agent_id=agent_id, run_id=run_id, - metadata=enhanced_metadata) - + result = self.memory.add( + formatted_messages, + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + metadata=enhanced_metadata, + ) + return { "added_memories": result if isinstance(result, list) else [result], "message": "Memories added successfully", - "hierarchy": {"user_id": user_id, "agent_id": agent_id, "run_id": run_id} + "hierarchy": { + "user_id": user_id, + "agent_id": agent_id, + "run_id": run_id, + }, } except Exception as e: logger.error(f"Error adding memories: {e}") - raise e - + raise + @timed("search_memories") async def search_memories( self, @@ -155,37 +173,60 @@ class Mem0Manager: # rerank: bool = False, # filter_memories: bool = False, agent_id: Optional[str] = None, - run_id: Optional[str] = None + run_id: Optional[str] = None, ) -> Dict[str, Any]: """Search memories - native Mem0 pattern""" try: # Minimal empty query protection for API compatibility if not query or query.strip() == "": - return {"memories": [], "total_count": 0, "query": query, "note": "Empty query provided, no results returned. Use a specific query to search memories."} + return { + "memories": [], + "total_count": 0, + "query": query, + "note": "Empty query provided, no results returned. Use a specific query to search memories.", + } # Direct Mem0 search - trust native handling - result = self.memory.search(query=query, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit, threshold=threshold, filters=filters) - return {"memories": result.get("results", []), "total_count": len(result.get("results", [])), "query": query} + result = self.memory.search( + query=query, + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + limit=limit, + threshold=threshold, + filters=filters, + ) + return { + "memories": result.get("results", []), + "total_count": len(result.get("results", [])), + "query": query, + } except Exception as e: logger.error(f"Error searching memories: {e}") - raise e - + raise + async def get_user_memories( self, user_id: str, limit: int = 10, agent_id: Optional[str] = None, run_id: Optional[str] = None, - filters: Optional[Dict[str, Any]] = None + filters: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """Get all memories for a user - native Mem0 pattern.""" try: # Direct Mem0 get_all call - trust native parameter handling - result = self.memory.get_all(user_id=user_id, limit=limit, agent_id=agent_id, run_id=run_id, filters=filters) + result = self.memory.get_all( + user_id=user_id, + limit=limit, + agent_id=agent_id, + run_id=run_id, + filters=filters, + ) return result.get("results", []) except Exception as e: logger.error(f"Error getting user memories: {e}") - raise e - + raise + @timed("update_memory") async def update_memory( self, @@ -194,15 +235,12 @@ class Mem0Manager: ) -> Dict[str, Any]: """Update memory - pure Mem0 passthrough.""" try: - result = self.memory.update( - memory_id=memory_id, - data=content - ) + result = self.memory.update(memory_id=memory_id, data=content) return {"message": "Memory updated successfully", "result": result} except Exception as e: logger.error(f"Error updating memory: {e}") - raise e - + raise + @timed("delete_memory") async def delete_memory(self, memory_id: str) -> Dict[str, Any]: """Delete memory - pure Mem0 passthrough.""" @@ -211,7 +249,7 @@ class Mem0Manager: return {"message": "Memory deleted successfully"} except Exception as e: logger.error(f"Error deleting memory: {e}") - raise e + raise async def delete_user_memories(self, user_id: Optional[str]) -> Dict[str, Any]: """Delete all user memories - pure Mem0 passthrough.""" @@ -220,8 +258,8 @@ class Mem0Manager: return {"message": "All user memories deleted successfully"} except Exception as e: logger.error(f"Error deleting user memories: {e}") - raise e - + raise + async def get_memory_history(self, memory_id: str) -> Dict[str, Any]: """Get memory change history - pure Mem0 passthrough.""" try: @@ -229,42 +267,44 @@ class Mem0Manager: return { "memory_id": memory_id, "history": history, - "message": "Memory history retrieved successfully" + "message": "Memory history retrieved successfully", } except Exception as e: logger.error(f"Error getting memory history: {e}") - raise e + raise - - async def get_graph_relationships(self, user_id: Optional[str], agent_id: Optional[str], run_id: Optional[str], limit: int = 50) -> Dict[str, Any]: + async def get_graph_relationships( + self, + user_id: Optional[str], + agent_id: Optional[str], + run_id: Optional[str], + limit: int = 50, + ) -> Dict[str, Any]: """Get graph relationships - using correct Mem0 get_all() method.""" try: # Use get_all() to retrieve memories with graph relationships result = self.memory.get_all( - user_id=user_id, - agent_id=agent_id, - run_id=run_id, - limit=limit + user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit ) - + # Extract relationships from Mem0's response structure relationships = result.get("relations", []) - + # For entities, we can derive them from memory results or relations entities = [] if "results" in result: # Extract unique entities from memories and relationships entity_set = set() - + # Add entities from relationships for rel in relationships: if "source" in rel: entity_set.add(rel["source"]) if "target" in rel: entity_set.add(rel["target"]) - + entities = [{"name": entity} for entity in entity_set] - + return { "relationships": relationships, "entities": entities, @@ -272,9 +312,9 @@ class Mem0Manager: "agent_id": agent_id, "run_id": run_id, "total_memories": len(result.get("results", [])), - "total_relationships": len(relationships) + "total_relationships": len(relationships), } - + except Exception as e: logger.error(f"Error getting graph relationships: {e}") # Return empty but structured response on error @@ -286,9 +326,9 @@ class Mem0Manager: "run_id": run_id, "total_memories": 0, "total_relationships": 0, - "error": str(e) + "error": str(e), } - + @timed("chat_with_memory") async def chat_with_memory( self, @@ -301,53 +341,74 @@ class Mem0Manager: ) -> Dict[str, Any]: """Chat with memory - native Mem0 pattern with detailed timing.""" import time - + try: total_start_time = time.time() - print(f"\n๐Ÿš€ Starting chat request for user: {user_id}") - - # Stage 1: Memory Search + logger.info("Starting chat request", user_id=user_id) + search_start_time = time.time() - search_result = self.memory.search(query=message, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=10, threshold=0.3) + search_result = self.memory.search( + query=message, + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + limit=10, + threshold=0.3, + ) relevant_memories = search_result.get("results", []) - memories_str = "\n".join(f"- {entry['memory']}" for entry in relevant_memories) + memories_str = "\n".join( + f"- {entry['memory']}" for entry in relevant_memories + ) search_time = time.time() - search_start_time - print(f"๐Ÿ” Memory search took: {search_time:.2f}s (found {len(relevant_memories)} memories)") - - # Stage 2: Prepare LLM messages + logger.debug( + "Memory search completed", + search_time_s=round(search_time, 2), + memories_found=len(relevant_memories), + ) + prep_start_time = time.time() system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}" messages = [{"role": "system", "content": system_prompt}] - - # Add conversation context if provided (last 50 messages) + if context: messages.extend(context) - print(f"๐Ÿ“ Added {len(context)} context messages") - - # Add current user message + logger.debug("Added context messages", context_count=len(context)) + messages.append({"role": "user", "content": message}) prep_time = time.time() - prep_start_time - print(f"๐Ÿ“‹ Message preparation took: {prep_time:.3f}s") - - # Stage 3: LLM Call + llm_start_time = time.time() - response = self.openai_client.chat.completions.create(model=settings.default_model, messages=messages) + response = self.openai_client.chat.completions.create( + model=settings.default_model, messages=messages + ) assistant_response = response.choices[0].message.content llm_time = time.time() - llm_start_time - print(f"๐Ÿค– LLM call took: {llm_time:.2f}s (model: {settings.default_model})") - - # Stage 4: Memory Add + logger.debug( + "LLM call completed", + llm_time_s=round(llm_time, 2), + model=settings.default_model, + ) + add_start_time = time.time() - memory_messages = [{"role": "user", "content": message}, {"role": "assistant", "content": assistant_response}] + memory_messages = [ + {"role": "user", "content": message}, + {"role": "assistant", "content": assistant_response}, + ] self.memory.add(memory_messages, user_id=user_id) add_time = time.time() - add_start_time - print(f"๐Ÿ’พ Memory add took: {add_time:.2f}s") - - # Total timing summary + total_time = time.time() - total_start_time - print(f"โฑ๏ธ TOTAL: {total_time:.2f}s | Search: {search_time:.2f}s | LLM: {llm_time:.2f}s | Add: {add_time:.2f}s | Prep: {prep_time:.3f}s") - print(f"๐Ÿ“Š Breakdown: Search {(search_time/total_time)*100:.1f}% | LLM {(llm_time/total_time)*100:.1f}% | Add {(add_time/total_time)*100:.1f}%\n") - + logger.info( + "Chat request completed", + user_id=user_id, + total_time_s=round(total_time, 2), + search_time_s=round(search_time, 2), + llm_time_s=round(llm_time, 2), + add_time_s=round(add_time, 2), + memories_used=len(relevant_memories), + model=settings.default_model, + ) + return { "response": assistant_response, "memories_used": len(relevant_memories), @@ -356,37 +417,42 @@ class Mem0Manager: "total": round(total_time, 2), "search": round(search_time, 2), "llm": round(llm_time, 2), - "add": round(add_time, 2) - } + "add": round(add_time, 2), + }, } - + except Exception as e: - logger.error(f"Error in chat_with_memory: {e}") + logger.error( + "Error in chat_with_memory", + error=str(e), + user_id=user_id, + exc_info=True, + ) return { "error": str(e), "response": "I apologize, but I encountered an error processing your request.", "memories_used": 0, - "model_used": None + "model_used": None, } - + async def health_check(self) -> Dict[str, str]: """Basic health check - just connectivity.""" status = {} - + # Check custom OpenAI endpoint try: models = self.openai_client.models.list() status["openai_endpoint"] = "healthy" except Exception as e: status["openai_endpoint"] = f"unhealthy: {str(e)}" - + # Check Mem0 memory try: self.memory.search(query="test", user_id="health_check", limit=1) status["mem0_memory"] = "healthy" except Exception as e: status["mem0_memory"] = f"unhealthy: {str(e)}" - + return status diff --git a/backend/models.py b/backend/models.py index 82fe1fa..e832b1d 100644 --- a/backend/models.py +++ b/backend/models.py @@ -2,53 +2,115 @@ from typing import List, Optional, Dict, Any from pydantic import BaseModel, Field +import re + + +# Constants for input validation +MAX_MESSAGE_LENGTH = 50000 # ~12k tokens max per message +MAX_QUERY_LENGTH = 10000 # ~2.5k tokens max per query +MAX_USER_ID_LENGTH = 100 # Reasonable user ID length +MAX_MEMORY_ID_LENGTH = 100 # Memory IDs are typically UUIDs +MAX_CONTEXT_MESSAGES = 100 # Max conversation context messages +USER_ID_PATTERN = r"^[a-zA-Z0-9_\-\.@]+$" # Alphanumeric with common separators # Request Models class ChatMessage(BaseModel): """Chat message structure.""" - role: str = Field(..., description="Message role (user, assistant, system)") - content: str = Field(..., description="Message content") + + role: str = Field( + ..., max_length=20, description="Message role (user, assistant, system)" + ) + content: str = Field( + ..., max_length=MAX_MESSAGE_LENGTH, description="Message content" + ) class ChatRequest(BaseModel): """Ultra-minimal chat request.""" - message: str = Field(..., description="User message") - user_id: Optional[str] = Field("default", description="User identifier") - agent_id: Optional[str] = Field(None, description="Agent identifier") - run_id: Optional[str] = Field(None, description="Run identifier") - context: Optional[List[ChatMessage]] = Field(None, description="Previous conversation context") + + message: str = Field(..., max_length=MAX_MESSAGE_LENGTH, description="User message") + user_id: Optional[str] = Field( + "default", + max_length=MAX_USER_ID_LENGTH, + pattern=USER_ID_PATTERN, + description="User identifier (alphanumeric, _, -, ., @)", + ) + agent_id: Optional[str] = Field( + None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier" + ) + run_id: Optional[str] = Field( + None, max_length=MAX_USER_ID_LENGTH, description="Run identifier" + ) + context: Optional[List[ChatMessage]] = Field( + None, + max_length=MAX_CONTEXT_MESSAGES, + description="Previous conversation context (max 100 messages)", + ) metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata") class MemoryAddRequest(BaseModel): """Request to add memories with hierarchy support - open-source compatible.""" - messages: List[ChatMessage] = Field(..., description="Messages to process") - user_id: Optional[str] = Field("default", description="User identifier") - agent_id: Optional[str] = Field(None, description="Agent identifier") - run_id: Optional[str] = Field(None, description="Run identifier") + + messages: List[ChatMessage] = Field( + ..., + max_length=MAX_CONTEXT_MESSAGES, + description="Messages to process (max 100 messages)", + ) + user_id: Optional[str] = Field( + "default", + max_length=MAX_USER_ID_LENGTH, + pattern=USER_ID_PATTERN, + description="User identifier", + ) + agent_id: Optional[str] = Field( + None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier" + ) + run_id: Optional[str] = Field( + None, max_length=MAX_USER_ID_LENGTH, description="Run identifier" + ) metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata") class MemorySearchRequest(BaseModel): """Request to search memories with hierarchy filtering.""" - query: str = Field(..., description="Search query") - user_id: Optional[str] = Field("default", description="User identifier") - agent_id: Optional[str] = Field(None, description="Agent identifier") - run_id: Optional[str] = Field(None, description="Run identifier") - limit: int = Field(5, description="Maximum number of results") - threshold: Optional[float] = Field(None, description="Minimum relevance score") + + query: str = Field(..., max_length=MAX_QUERY_LENGTH, description="Search query") + user_id: Optional[str] = Field( + "default", + max_length=MAX_USER_ID_LENGTH, + pattern=USER_ID_PATTERN, + description="User identifier", + ) + agent_id: Optional[str] = Field( + None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier" + ) + run_id: Optional[str] = Field( + None, max_length=MAX_USER_ID_LENGTH, description="Run identifier" + ) + limit: int = Field(5, ge=1, le=100, description="Maximum number of results (1-100)") + threshold: Optional[float] = Field( + None, ge=0.0, le=1.0, description="Minimum relevance score (0-1)" + ) filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters") - - # Hierarchy filters (open-source compatible) - agent_id: Optional[str] = Field(None, description="Filter by agent identifier") - run_id: Optional[str] = Field(None, description="Filter by run identifier") class MemoryUpdateRequest(BaseModel): """Request to update a memory.""" - memory_id: str = Field(..., description="Memory ID to update") - content: str = Field(..., description="New memory content") + + memory_id: str = Field( + ..., max_length=MAX_MEMORY_ID_LENGTH, description="Memory ID to update" + ) + user_id: str = Field( + ..., + max_length=MAX_USER_ID_LENGTH, + pattern=USER_ID_PATTERN, + description="User identifier for ownership verification", + ) + content: str = Field( + ..., max_length=MAX_MESSAGE_LENGTH, description="New memory content" + ) metadata: Optional[Dict[str, Any]] = Field(None, description="Updated metadata") @@ -57,19 +119,23 @@ class MemoryUpdateRequest(BaseModel): class MemoryItem(BaseModel): """Individual memory item.""" + id: str = Field(..., description="Memory unique identifier") memory: str = Field(..., description="Memory content") user_id: Optional[str] = Field(None, description="Associated user ID") agent_id: Optional[str] = Field(None, description="Associated agent ID") run_id: Optional[str] = Field(None, description="Associated run ID") metadata: Optional[Dict[str, Any]] = Field(None, description="Memory metadata") - score: Optional[float] = Field(None, description="Relevance score (for search results)") + score: Optional[float] = Field( + None, description="Relevance score (for search results)" + ) created_at: Optional[str] = Field(None, description="Creation timestamp") updated_at: Optional[str] = Field(None, description="Last update timestamp") class MemorySearchResponse(BaseModel): """Memory search results - pure Mem0 structure.""" + memories: List[MemoryItem] = Field(..., description="Found memories") total_count: int = Field(..., description="Total number of memories found") query: str = Field(..., description="Original search query") @@ -77,27 +143,37 @@ class MemorySearchResponse(BaseModel): class MemoryAddResponse(BaseModel): """Response from adding memories - pure Mem0 structure.""" - added_memories: List[Dict[str, Any]] = Field(..., description="Memories that were added") + + added_memories: List[Dict[str, Any]] = Field( + ..., description="Memories that were added" + ) message: str = Field(..., description="Success message") class GraphRelationship(BaseModel): """Graph relationship structure.""" + source: str = Field(..., description="Source entity") relationship: str = Field(..., description="Relationship type") target: str = Field(..., description="Target entity") - properties: Optional[Dict[str, Any]] = Field(None, description="Relationship properties") + properties: Optional[Dict[str, Any]] = Field( + None, description="Relationship properties" + ) class GraphResponse(BaseModel): """Graph relationships - pure Mem0 structure.""" - relationships: List[GraphRelationship] = Field(..., description="Found relationships") + + relationships: List[GraphRelationship] = Field( + ..., description="Found relationships" + ) entities: List[str] = Field(..., description="Unique entities") user_id: str = Field(..., description="User identifier") class HealthResponse(BaseModel): """Health check response.""" + status: str = Field(..., description="Service status") services: Dict[str, str] = Field(..., description="Individual service statuses") timestamp: str = Field(..., description="Health check timestamp") @@ -105,6 +181,7 @@ class HealthResponse(BaseModel): class ErrorResponse(BaseModel): """Error response structure.""" + error: str = Field(..., description="Error message") detail: Optional[str] = Field(None, description="Detailed error information") status_code: int = Field(..., description="HTTP status code") @@ -112,8 +189,10 @@ class ErrorResponse(BaseModel): # Statistics and Monitoring Models + class MemoryOperationStats(BaseModel): """Memory operation statistics.""" + add: int = Field(..., description="Number of add operations") search: int = Field(..., description="Number of search operations") update: int = Field(..., description="Number of update operations") @@ -122,19 +201,29 @@ class MemoryOperationStats(BaseModel): class GlobalStatsResponse(BaseModel): """Global application statistics.""" + total_memories: int = Field(..., description="Total memories across all users") total_users: int = Field(..., description="Total number of users") api_calls_today: int = Field(..., description="Total API calls today") - avg_response_time_ms: float = Field(..., description="Average response time in milliseconds") - memory_operations: MemoryOperationStats = Field(..., description="Memory operation breakdown") + avg_response_time_ms: float = Field( + ..., description="Average response time in milliseconds" + ) + memory_operations: MemoryOperationStats = Field( + ..., description="Memory operation breakdown" + ) uptime_seconds: float = Field(..., description="Application uptime in seconds") class UserStatsResponse(BaseModel): """User-specific statistics.""" + user_id: str = Field(..., description="User identifier") memory_count: int = Field(..., description="Number of memories for this user") - relationship_count: int = Field(..., description="Number of graph relationships for this user") + relationship_count: int = Field( + ..., description="Number of graph relationships for this user" + ) last_activity: Optional[str] = Field(None, description="Last activity timestamp") api_calls_today: int = Field(..., description="API calls made by this user today") - avg_response_time_ms: float = Field(..., description="Average response time for this user's requests") \ No newline at end of file + avg_response_time_ms: float = Field( + ..., description="Average response time for this user's requests" + ) diff --git a/backend/requirements.txt b/backend/requirements.txt index 1ad77f8..4539504 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -32,5 +32,8 @@ python-json-logger python-jose[cryptography] passlib[bcrypt] +# Rate Limiting +slowapi + # MCP Server mcp[server]>=1.0.0 diff --git a/test_integration.py b/test_integration.py index 12c150e..3464c6f 100644 --- a/test_integration.py +++ b/test_integration.py @@ -19,27 +19,41 @@ import time BASE_URL = "http://localhost:8000" TEST_USER = f"test_user_{int(datetime.now().timestamp())}" +# API Key for authentication - set via environment or use default test key +import os + +API_KEY = os.environ.get("MEM0_API_KEY", "test-api-key") +AUTH_HEADERS = {"X-API-Key": API_KEY} + + def main(): parser = argparse.ArgumentParser( description="Mem0 Integration Tests - Real API Testing (Zero Mocking)", - formatter_class=argparse.RawDescriptionHelpFormatter + formatter_class=argparse.RawDescriptionHelpFormatter, ) - parser.add_argument("--verbose", "-v", action="store_true", - help="Show detailed output and API responses") - + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Show detailed output and API responses", + ) + args = parser.parse_args() verbose = args.verbose - + print("๐Ÿงช Mem0 Integration Tests - Real API Testing") print(f"๐ŸŽฏ Target: {BASE_URL}") print(f"๐Ÿ‘ค Test User: {TEST_USER}") print(f"โฐ Started: {datetime.now().strftime('%H:%M:%S')}") print("=" * 50) - + # Test sequence - order matters for data dependencies tests = [ test_health_check, - test_empty_search_protection, + test_auth_required_endpoints, + test_ownership_verification, + test_request_size_limit, + test_empty_search_protection, test_add_memories_with_hierarchy, test_search_memories_basic, test_search_memories_hierarchy_filters, @@ -51,30 +65,30 @@ def main(): test_graph_relationships, test_delete_specific_memory, test_delete_all_user_memories, - test_cleanup_verification + test_cleanup_verification, ] - + results = [] start_time = time.time() - + for test in tests: result = run_test(test.__name__, test, verbose) results.append(result) - + # Small delay between tests for API stability time.sleep(0.5) - + # Summary end_time = time.time() duration = end_time - start_time - + passed = sum(1 for r in results if r) total = len(results) - + print("=" * 50) print(f"๐Ÿ“Š Test Results: {passed}/{total} tests passed") print(f"โฑ๏ธ Duration: {duration:.2f} seconds") - + if passed == total: print("โœ… All tests passed! System is working correctly.") sys.exit(0) @@ -82,16 +96,17 @@ def main(): print("โŒ Some tests failed! Check the output above.") sys.exit(1) + def run_test(name, test_func, verbose): """Run a single test with error handling""" try: if verbose: print(f"\n๐Ÿ” Running {name}...") - + test_func(verbose) print(f"โœ… {name}") return True - + except AssertionError as e: print(f"โŒ {name}: Assertion failed - {e}") return False @@ -102,6 +117,7 @@ def run_test(name, test_func, verbose): print(f"โŒ {name}: {e}") return False + def log_response(response, verbose, context=""): """Log API response details if verbose""" if verbose: @@ -111,78 +127,101 @@ def log_response(response, verbose, context=""): if isinstance(data, dict) and len(data) < 5: print(f" {context} Response: {data}") else: - print(f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}") + print( + f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}" + ) except: print(f" {context} Response: {response.text[:100]}...") + # ================== TEST FUNCTIONS ================== + def test_health_check(verbose): """Test service health endpoint""" - response = requests.get(f"{BASE_URL}/health", timeout=10) + response = requests.get( + f"{BASE_URL}/health", timeout=10 + ) # Health doesn't require auth log_response(response, verbose, "Health") - + assert response.status_code == 200, f"Expected 200, got {response.status_code}" - + data = response.json() assert "status" in data, "Health response missing 'status' field" - assert data["status"] in ["healthy", "degraded"], f"Invalid status: {data['status']}" - + assert data["status"] in ["healthy", "degraded"], ( + f"Invalid status: {data['status']}" + ) + # Check individual services assert "services" in data, "Health response missing 'services' field" - + if verbose: print(f" Overall status: {data['status']}") for service, status in data["services"].items(): print(f" {service}: {status}") + def test_empty_search_protection(verbose): """Test empty query protection (should not return 500 error)""" - payload = { - "query": "", - "user_id": TEST_USER, - "limit": 5 - } - - response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=10) + payload = {"query": "", "user_id": TEST_USER, "limit": 5} + + response = requests.post( + f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=10 + ) log_response(response, verbose, "Empty Search") - - assert response.status_code == 200, f"Empty query failed with {response.status_code}" - + + assert response.status_code == 200, ( + f"Empty query failed with {response.status_code}" + ) + data = response.json() assert data["memories"] == [], "Empty query should return empty memories list" assert "note" in data, "Empty query response should include explanatory note" assert data["query"] == "", "Query should be echoed back" - + if verbose: print(f" Empty search note: {data['note']}") print(f" Total count: {data.get('total_count', 0)}") + def test_add_memories_with_hierarchy(verbose): """Test adding memories with multi-level hierarchy support""" payload = { "messages": [ - {"role": "user", "content": "I work at TechCorp as a Senior Software Engineer"}, - {"role": "user", "content": "My colleague Sarah from Marketing team helped with Q3 presentation"}, - {"role": "user", "content": "Meeting with John the Product Manager tomorrow about new feature development"} + { + "role": "user", + "content": "I work at TechCorp as a Senior Software Engineer", + }, + { + "role": "user", + "content": "My colleague Sarah from Marketing team helped with Q3 presentation", + }, + { + "role": "user", + "content": "Meeting with John the Product Manager tomorrow about new feature development", + }, ], "user_id": TEST_USER, "agent_id": "test_agent", "run_id": "test_run_001", "session_id": "test_session_001", - "metadata": {"test": "integration", "scenario": "work_context"} + "metadata": {"test": "integration", "scenario": "work_context"}, } - - response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60) + + response = requests.post( + f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60 + ) log_response(response, verbose, "Add Memories") - - assert response.status_code == 200, f"Add memories failed with {response.status_code}" - + + assert response.status_code == 200, ( + f"Add memories failed with {response.status_code}" + ) + data = response.json() assert "added_memories" in data, "Response missing 'added_memories'" assert "message" in data, "Response missing success message" assert len(data["added_memories"]) > 0, "No memories were added" - + # Verify graph extraction (if available) memories = data["added_memories"] if isinstance(memories, list) and len(memories) > 0: @@ -191,47 +230,51 @@ def test_add_memories_with_hierarchy(verbose): relations = first_memory["relations"] if "added_entities" in relations and relations["added_entities"]: if verbose: - print(f" Graph extracted: {len(relations['added_entities'])} relationships") + print( + f" Graph extracted: {len(relations['added_entities'])} relationships" + ) print(f" Sample relations: {relations['added_entities'][:3]}") - + if verbose: print(f" Added {len(memories)} memory blocks") - print(f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001") + print( + f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001" + ) + def test_search_memories_basic(verbose): """Test basic memory search functionality""" # Test meaningful search - payload = { - "query": "TechCorp", - "user_id": TEST_USER, - "limit": 10 - } - - response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15) + payload = {"query": "TechCorp", "user_id": TEST_USER, "limit": 10} + + response = requests.post( + f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15 + ) log_response(response, verbose, "Search") - + assert response.status_code == 200, f"Search failed with {response.status_code}" - + data = response.json() assert "memories" in data, "Search response missing 'memories'" assert "total_count" in data, "Search response missing 'total_count'" assert "query" in data, "Search response missing 'query'" assert data["query"] == "TechCorp", "Query not echoed correctly" - + # Should find memories since we just added some assert data["total_count"] > 0, "Search should find previously added memories" assert len(data["memories"]) > 0, "Search should return memory results" - + # Verify memory structure memory = data["memories"][0] assert "id" in memory, "Memory missing 'id'" assert "memory" in memory, "Memory missing 'memory' content" assert "user_id" in memory, "Memory missing 'user_id'" - + if verbose: print(f" Found {data['total_count']} memories") print(f" First memory: {memory['memory'][:50]}...") + def test_search_memories_hierarchy_filters(verbose): """Test multi-level hierarchy filtering in search""" # Test with hierarchy filters @@ -241,23 +284,30 @@ def test_search_memories_hierarchy_filters(verbose): "agent_id": "test_agent", "run_id": "test_run_001", "session_id": "test_session_001", - "limit": 10 + "limit": 10, } - - response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15) + + response = requests.post( + f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15 + ) log_response(response, verbose, "Hierarchy Search") - - assert response.status_code == 200, f"Hierarchy search failed with {response.status_code}" - + + assert response.status_code == 200, ( + f"Hierarchy search failed with {response.status_code}" + ) + data = response.json() assert "memories" in data, "Hierarchy search response missing 'memories'" - + # Should find memories since we added with these exact hierarchy values assert len(data["memories"]) > 0, "Should find memories with matching hierarchy" - + if verbose: print(f" Found {len(data['memories'])} memories with hierarchy filters") - print(f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001") + print( + f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001" + ) + def test_get_user_memories_with_hierarchy(verbose): """Test retrieving user memories with hierarchy filtering""" @@ -266,23 +316,30 @@ def test_get_user_memories_with_hierarchy(verbose): "limit": 20, "agent_id": "test_agent", "run_id": "test_run_001", - "session_id": "test_session_001" + "session_id": "test_session_001", } - - response = requests.get(f"{BASE_URL}/memories/{TEST_USER}", params=params, timeout=15) + + response = requests.get( + f"{BASE_URL}/memories/{TEST_USER}", + params=params, + headers=AUTH_HEADERS, + timeout=15, + ) log_response(response, verbose, "Get User Memories with Hierarchy") - - assert response.status_code == 200, f"Get user memories with hierarchy failed with {response.status_code}" - + + assert response.status_code == 200, ( + f"Get user memories with hierarchy failed with {response.status_code}" + ) + memories = response.json() assert isinstance(memories, list), "User memories should return a list" - + if len(memories) > 0: memory = memories[0] assert "id" in memory, "Memory missing 'id'" assert "memory" in memory, "Memory missing 'memory' content" assert memory["user_id"] == TEST_USER, f"Wrong user_id: {memory['user_id']}" - + if verbose: print(f" Retrieved {len(memories)} memories with hierarchy filters") print(f" First memory: {memory['memory'][:40]}...") @@ -290,248 +347,320 @@ def test_get_user_memories_with_hierarchy(verbose): if verbose: print(" No memories found with hierarchy filters (may be expected)") + def test_memory_history(verbose): """Test memory history endpoint""" # First get a memory to check history for - response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10) + response = requests.get( + f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10 + ) assert response.status_code == 200, "Failed to get memory for history test" - + memories = response.json() if len(memories) == 0: if verbose: print(" No memories available for history test (skipping)") return - + memory_id = memories[0]["id"] - + # Test memory history endpoint - response = requests.get(f"{BASE_URL}/memories/{memory_id}/history", timeout=15) + response = requests.get( + f"{BASE_URL}/memories/{memory_id}/history?user_id={TEST_USER}", + headers=AUTH_HEADERS, + timeout=15, + ) log_response(response, verbose, "Memory History") - - assert response.status_code == 200, f"Memory history failed with {response.status_code}" - + + assert response.status_code == 200, ( + f"Memory history failed with {response.status_code}" + ) + data = response.json() assert "memory_id" in data, "History response missing 'memory_id'" assert "history" in data, "History response missing 'history'" assert "message" in data, "History response missing success message" - assert data["memory_id"] == memory_id, f"Wrong memory_id in response: {data['memory_id']}" - + assert data["memory_id"] == memory_id, ( + f"Wrong memory_id in response: {data['memory_id']}" + ) + if verbose: print(f" Retrieved history for memory {memory_id}") - print(f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}") - + print( + f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}" + ) def test_update_memory(verbose): """Test updating a specific memory""" # First get a memory to update - response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10) + response = requests.get( + f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10 + ) assert response.status_code == 200, "Failed to get memory for update test" - + memories = response.json() assert len(memories) > 0, "No memories available to update" - + memory_id = memories[0]["id"] original_content = memories[0]["memory"] - + # Update the memory payload = { "memory_id": memory_id, - "content": f"UPDATED: {original_content}" + "user_id": TEST_USER, + "content": f"UPDATED: {original_content}", } - - response = requests.put(f"{BASE_URL}/memories", json=payload, timeout=10) + + response = requests.put( + f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=10 + ) log_response(response, verbose, "Update") - + assert response.status_code == 200, f"Update failed with {response.status_code}" - + data = response.json() assert "message" in data, "Update response missing success message" - + if verbose: print(f" Updated memory {memory_id}") print(f" Original: {original_content[:30]}...") + def test_chat_with_memory(verbose): """Test memory-enhanced chat functionality""" - payload = { - "message": "What company do I work for?", - "user_id": TEST_USER - } - + payload = {"message": "What company do I work for?", "user_id": TEST_USER} + try: - response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=90) + response = requests.post( + f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=90 + ) log_response(response, verbose, "Chat") - + assert response.status_code == 200, f"Chat failed with {response.status_code}" - + data = response.json() assert "response" in data, "Chat response missing 'response'" assert "memories_used" in data, "Chat response missing 'memories_used'" assert "model_used" in data, "Chat response missing 'model_used'" - + # Should use some memories for context assert data["memories_used"] >= 0, "Memories used should be non-negative" - + if verbose: print(f" Chat response: {data['response'][:60]}...") print(f" Memories used: {data['memories_used']}") print(f" Model: {data['model_used']}") - + except requests.exceptions.ReadTimeout: if verbose: print(" Chat endpoint timed out (LLM API may be slow)") # Still test that the endpoint exists and accepts requests try: - response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=5) + response = requests.post( + f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=5 + ) except requests.exceptions.ReadTimeout: # This is expected - endpoint exists but processing is slow if verbose: print(" Chat endpoint confirmed active (processing timeout expected)") + def test_graph_relationships_creation(verbose): """Test graph relationships creation with entity-rich memories""" # Create a separate test user for graph relationship testing graph_test_user = f"graph_test_user_{int(datetime.now().timestamp())}" - + # Add memories with clear entity relationships payload = { "messages": [ - {"role": "user", "content": "John Smith works at Microsoft as a Senior Software Engineer"}, - {"role": "user", "content": "John Smith is friends with Sarah Johnson who works at Google"}, - {"role": "user", "content": "Sarah Johnson lives in Seattle and loves hiking"}, + { + "role": "user", + "content": "John Smith works at Microsoft as a Senior Software Engineer", + }, + { + "role": "user", + "content": "John Smith is friends with Sarah Johnson who works at Google", + }, + { + "role": "user", + "content": "Sarah Johnson lives in Seattle and loves hiking", + }, {"role": "user", "content": "Microsoft is located in Redmond, Washington"}, - {"role": "user", "content": "John Smith and Sarah Johnson both graduated from Stanford University"} + { + "role": "user", + "content": "John Smith and Sarah Johnson both graduated from Stanford University", + }, ], "user_id": graph_test_user, - "metadata": {"test": "graph_relationships", "scenario": "entity_creation"} + "metadata": {"test": "graph_relationships", "scenario": "entity_creation"}, } - - response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60) + + response = requests.post( + f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60 + ) log_response(response, verbose, "Add Graph Memories") - - assert response.status_code == 200, f"Add graph memories failed with {response.status_code}" - + + assert response.status_code == 200, ( + f"Add graph memories failed with {response.status_code}" + ) + data = response.json() assert "added_memories" in data, "Response missing 'added_memories'" - + if verbose: - print(f" Added {len(data['added_memories'])} memories for graph relationship testing") - + print( + f" Added {len(data['added_memories'])} memories for graph relationship testing" + ) + # Wait a moment for graph processing (Mem0 graph extraction can be async) time.sleep(2) - + # Test graph relationships endpoint - response = requests.get(f"{BASE_URL}/graph/relationships/{graph_test_user}", timeout=15) + response = requests.get( + f"{BASE_URL}/graph/relationships/{graph_test_user}", + headers=AUTH_HEADERS, + timeout=15, + ) log_response(response, verbose, "Graph Relationships") - - assert response.status_code == 200, f"Graph relationships failed with {response.status_code}" - + + assert response.status_code == 200, ( + f"Graph relationships failed with {response.status_code}" + ) + graph_data = response.json() assert "relationships" in graph_data, "Graph response missing 'relationships'" assert "entities" in graph_data, "Graph response missing 'entities'" assert "user_id" in graph_data, "Graph response missing 'user_id'" - assert graph_data["user_id"] == graph_test_user, f"Wrong user_id in graph: {graph_data['user_id']}" - + assert graph_data["user_id"] == graph_test_user, ( + f"Wrong user_id in graph: {graph_data['user_id']}" + ) + relationships = graph_data["relationships"] entities = graph_data["entities"] - + if verbose: print(f" Found {len(relationships)} relationships") print(f" Found {len(entities)} entities") - + # Print sample relationships if they exist if relationships: print(f" Sample relationships:") for i, rel in enumerate(relationships[:3]): # Show first 3 source = rel.get("source", "unknown") - target = rel.get("target", "unknown") + target = rel.get("target", "unknown") relationship = rel.get("relationship", "unknown") - print(f" {i+1}. {source} --{relationship}--> {target}") - + print(f" {i + 1}. {source} --{relationship}--> {target}") + # Print sample entities if they exist if entities: - print(f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}") - + print( + f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}" + ) + # Verify relationship structure (if relationships exist) for rel in relationships: - assert "source" in rel or "from" in rel, f"Relationship missing source/from: {rel}" + assert "source" in rel or "from" in rel, ( + f"Relationship missing source/from: {rel}" + ) assert "target" in rel or "to" in rel, f"Relationship missing target/to: {rel}" - assert "relationship" in rel or "type" in rel, f"Relationship missing type: {rel}" - + assert "relationship" in rel or "type" in rel, ( + f"Relationship missing type: {rel}" + ) + # Clean up graph test user memories - cleanup_response = requests.delete(f"{BASE_URL}/memories/user/{graph_test_user}", timeout=15) + cleanup_response = requests.delete( + f"{BASE_URL}/memories/user/{graph_test_user}", headers=AUTH_HEADERS, timeout=15 + ) assert cleanup_response.status_code == 200, "Failed to cleanup graph test memories" - + if verbose: print(f" Cleaned up graph test user: {graph_test_user}") - + # Note: We expect some relationships even if graph extraction is basic # The test passes if the endpoint works and returns proper structure - + + def test_graph_relationships(verbose): """Test graph relationships endpoint""" - response = requests.get(f"{BASE_URL}/graph/relationships/{TEST_USER}", timeout=15) + response = requests.get( + f"{BASE_URL}/graph/relationships/{TEST_USER}", headers=AUTH_HEADERS, timeout=15 + ) log_response(response, verbose, "Graph") - - assert response.status_code == 200, f"Graph endpoint failed with {response.status_code}" - + + assert response.status_code == 200, ( + f"Graph endpoint failed with {response.status_code}" + ) + data = response.json() assert "relationships" in data, "Graph response missing 'relationships'" assert "entities" in data, "Graph response missing 'entities'" assert "user_id" in data, "Graph response missing 'user_id'" assert data["user_id"] == TEST_USER, f"Wrong user_id in graph: {data['user_id']}" - + if verbose: print(f" Relationships: {len(data['relationships'])}") print(f" Entities: {len(data['entities'])}") + def test_delete_specific_memory(verbose): """Test deleting a specific memory""" # Get a memory to delete - response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10) + response = requests.get( + f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10 + ) assert response.status_code == 200, "Failed to get memory for deletion test" - + memories = response.json() assert len(memories) > 0, "No memories available to delete" - + memory_id = memories[0]["id"] - + # Delete the memory - response = requests.delete(f"{BASE_URL}/memories/{memory_id}", timeout=10) + response = requests.delete( + f"{BASE_URL}/memories/{memory_id}", headers=AUTH_HEADERS, timeout=10 + ) log_response(response, verbose, "Delete") - + assert response.status_code == 200, f"Delete failed with {response.status_code}" - + data = response.json() assert "message" in data, "Delete response missing success message" - + if verbose: print(f" Deleted memory {memory_id}") + def test_delete_all_user_memories(verbose): """Test deleting all memories for a user""" - response = requests.delete(f"{BASE_URL}/memories/user/{TEST_USER}", timeout=15) + response = requests.delete( + f"{BASE_URL}/memories/user/{TEST_USER}", headers=AUTH_HEADERS, timeout=15 + ) log_response(response, verbose, "Delete All") - + assert response.status_code == 200, f"Delete all failed with {response.status_code}" - + data = response.json() assert "message" in data, "Delete all response missing success message" - + if verbose: print(f"Deleted all memories for {TEST_USER}") + def test_cleanup_verification(verbose): """Verify cleanup was successful""" - response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=10", timeout=10) + response = requests.get( + f"{BASE_URL}/memories/{TEST_USER}?limit=10", headers=AUTH_HEADERS, timeout=10 + ) log_response(response, verbose, "Cleanup Check") - - assert response.status_code == 200, f"Cleanup verification failed with {response.status_code}" - + + assert response.status_code == 200, ( + f"Cleanup verification failed with {response.status_code}" + ) + memories = response.json() assert isinstance(memories, list), "Should return list even if empty" - + # Should be empty after deletion if len(memories) > 0: print(f" Warning: {len(memories)} memories still exist after cleanup") @@ -539,5 +668,79 @@ def test_cleanup_verification(verbose): if verbose: print(" Cleanup successful - no memories remain") + +# ================== SECURITY TEST FUNCTIONS ================== + + +def test_auth_required_endpoints(verbose): + """Test that protected endpoints require authentication""" + endpoints_requiring_auth = [ + ("GET", f"{BASE_URL}/memories/{TEST_USER}"), + ("POST", f"{BASE_URL}/memories/search"), + ("GET", f"{BASE_URL}/stats"), + ("GET", f"{BASE_URL}/models"), + ("GET", f"{BASE_URL}/users"), + ] + + for method, url in endpoints_requiring_auth: + if method == "GET": + response = requests.get(url, timeout=5) + else: + response = requests.post( + url, json={"query": "test", "user_id": TEST_USER}, timeout=5 + ) + + assert response.status_code in [401, 403], ( + f"{method} {url} should require auth, got {response.status_code}" + ) + + if verbose: + print(f" {method} {url}: {response.status_code} (auth required)") + + +def test_ownership_verification(verbose): + """Test that users can only access their own data""" + other_user = "other_user_not_me" + + response = requests.get( + f"{BASE_URL}/memories/{other_user}", headers=AUTH_HEADERS, timeout=5 + ) + + assert response.status_code in [403, 404], ( + f"Accessing other user's memories should be denied, got {response.status_code}" + ) + + if verbose: + print(f" Ownership check passed: {response.status_code}") + + +def test_request_size_limit(verbose): + """Test request size limit enforcement (10MB max)""" + large_payload = { + "messages": [{"role": "user", "content": "x" * (11 * 1024 * 1024)}], + "user_id": TEST_USER, + } + + try: + response = requests.post( + f"{BASE_URL}/memories", + json=large_payload, + headers={**AUTH_HEADERS, "Content-Length": str(11 * 1024 * 1024)}, + timeout=5, + ) + + assert response.status_code == 413, ( + f"Large request should return 413, got {response.status_code}" + ) + + if verbose: + print(f" Request size limit enforced: {response.status_code}") + except requests.exceptions.RequestException as e: + if verbose: + print( + f" Request size limit test: connection issue (expected for large payload)" + ) + + if __name__ == "__main__": - main() \ No newline at end of file + main()