"""Main FastAPI application for Mem0 Interface POC.""" import logging import time from datetime import datetime from typing import List, Dict, Any, Optional from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException, BackgroundTasks from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse import structlog from config import settings from models import ( ChatRequest, MemoryAddRequest, MemoryAddResponse, MemorySearchRequest, MemorySearchResponse, MemoryUpdateRequest, MemoryItem, GraphResponse, HealthResponse, ErrorResponse, GlobalStatsResponse, UserStatsResponse ) from mem0_manager import mem0_manager # Configure structured logging structlog.configure( processors=[ structlog.stdlib.filter_by_level, structlog.stdlib.add_logger_name, structlog.stdlib.add_log_level, structlog.stdlib.PositionalArgumentsFormatter(), structlog.processors.TimeStamper(fmt="iso"), structlog.processors.StackInfoRenderer(), structlog.processors.format_exc_info, structlog.processors.UnicodeDecoder(), structlog.processors.JSONRenderer() ], context_class=dict, logger_factory=structlog.stdlib.LoggerFactory(), wrapper_class=structlog.stdlib.BoundLogger, cache_logger_on_first_use=True, ) logger = structlog.get_logger(__name__) @asynccontextmanager async def lifespan(app: FastAPI): """Application lifespan manager.""" # Startup logger.info("Starting Mem0 Interface POC") # Perform health check on startup health_status = await mem0_manager.health_check() unhealthy_services = [k for k, v in health_status.items() if "unhealthy" in v] if unhealthy_services: logger.warning(f"Some services are unhealthy: {unhealthy_services}") else: logger.info("All services are healthy") yield # Shutdown logger.info("Shutting down Mem0 Interface POC") # Initialize FastAPI app app = FastAPI( title="Mem0 Interface POC", description="Minimal but fully functional Mem0 interface with PostgreSQL and Neo4j integration", version="1.0.0", lifespan=lifespan ) # Add CORS 
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Requests slower than this are logged at WARNING with slow_request=True.
_SLOW_REQUEST_MS = 2000

# (HTTP method, path segment preceding the user_id) for routes declared below
# whose LAST path segment is a user_id. Keep in sync when adding user routes.
_USER_ID_PATH_ROUTES = {
    ("GET", "memories"),       # GET /memories/{user_id}
    ("GET", "stats"),          # GET /stats/{user_id}
    ("GET", "relationships"),  # GET /graph/relationships/{user_id}
    ("DELETE", "user"),        # DELETE /memories/user/{user_id}
}


def _user_id_from_path(method: str, path: str) -> Optional[str]:
    """Return the user_id embedded in a user-scoped route path, else None.

    Matching on the method as well as the path keeps e.g.
    ``DELETE /memories/{memory_id}`` from being mistaken for a user route.
    """
    parts = path.rstrip("/").split("/")
    if len(parts) > 2 and (method, parts[-2]) in _USER_ID_PATH_ROUTES:
        return parts[-1]
    return None


async def _user_id_from_body(request) -> Optional[str]:
    """Best-effort: return ``user_id`` from a JSON object request body, else None."""
    import json

    try:
        body = await request.body()
        if not body:
            return None
        data = json.loads(body)
    except Exception:
        # Non-JSON / undecodable body or client disconnect: attribution is
        # best-effort and must never fail the request. (Not a bare except, so
        # task cancellation still propagates.)
        return None
    return data.get("user_id") if isinstance(data, dict) else None


# Request logging middleware with monitoring
@app.middleware("http")
async def log_requests(request, call_next):
    """Log all HTTP requests with correlation ID and timing.

    Also records each call (user_id, duration in ms) via
    ``monitoring.stats.record_api_call``. That includes calls that end in an
    unhandled exception, which is then re-raised.
    """
    from monitoring import generate_correlation_id, stats

    correlation_id = generate_correlation_id()
    start_time = time.perf_counter()  # monotonic; immune to wall-clock jumps

    # NOTE(review): reading the body inside an @app.middleware("http")
    # (BaseHTTPMiddleware) handler relies on Starlette replaying the consumed
    # body to the endpoint. Confirm the pinned Starlette version supports this.
    if request.method == "POST":
        user_id = await _user_id_from_body(request)
    else:
        user_id = _user_id_from_path(request.method, request.url.path)

    # Log start of request
    logger.info(
        "HTTP request started",
        correlation_id=correlation_id,
        method=request.method,
        path=request.url.path,
        user_id=user_id
    )

    try:
        response = await call_next(request)
    except Exception:
        # Previously an exception here skipped both stats and completion logs.
        process_time_ms = (time.perf_counter() - start_time) * 1000
        stats.record_api_call(user_id, process_time_ms)
        logger.error(
            "HTTP request failed (UNHANDLED EXCEPTION)",
            correlation_id=correlation_id,
            method=request.method,
            path=request.url.path,
            process_time_ms=round(process_time_ms, 2),
            user_id=user_id,
            exc_info=True
        )
        raise

    process_time_ms = (time.perf_counter() - start_time) * 1000

    # Record statistics
    stats.record_api_call(user_id, process_time_ms)

    # Log completion; slowness takes precedence over error status.
    slow_request = process_time_ms > _SLOW_REQUEST_MS
    if slow_request:
        log, message = logger.warning, "HTTP request completed (SLOW)"
    elif response.status_code >= 400:
        log, message = logger.error, "HTTP request completed (ERROR)"
    else:
        log, message = logger.info, "HTTP request completed"
    log(
        message,
        correlation_id=correlation_id,
        method=request.method,
        path=request.url.path,
        status_code=response.status_code,
        process_time_ms=round(process_time_ms, 2),
        user_id=user_id,
        slow_request=slow_request
    )

    return response


# Exception handlers
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Global exception handler."""
    logger.error(f"Unhandled exception: {exc}", exc_info=True)
    return JSONResponse(
        status_code=500,
        content={"error": "Internal server error", "detail": str(exc)}
    )


def _service_is_healthy(status: Any) -> bool:
    """Return True when a service status reports healthy.

    A plain ``"healthy" in status`` is not enough, because the substring also
    matches ``"unhealthy"``. The startup check likewise treats any status
    containing "unhealthy" as a failure.
    """
    return "healthy" in status and "unhealthy" not in status


# Health check endpoint
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Check the health of all services."""
    try:
        services = await mem0_manager.health_check()
        overall_status = "healthy" if all(_service_is_healthy(status) for status in services.values()) else "degraded"

        return HealthResponse(
            status=overall_status,
            services=services,
            timestamp=datetime.utcnow().isoformat()
        )
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return HealthResponse(
            status="unhealthy",
            services={"error": str(e)},
            timestamp=datetime.utcnow().isoformat()
        )


# Core chat endpoint with memory enhancement
@app.post("/chat")
async def chat_with_memory(request: ChatRequest):
    """Ultra-minimal chat endpoint - pure Mem0 + custom endpoint."""
    try:
        logger.info(f"Processing chat request for user: {request.user_id}")
        result = await mem0_manager.chat_with_memory(
            message=request.message,
            user_id=request.user_id,
            context=request.context,
            metadata=request.metadata
        )
        return result
    except Exception as e:
        logger.error(f"Error in chat endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# Memory management endpoints - pure Mem0 passthroughs
@app.post("/memories")
async def add_memories(request: MemoryAddRequest):
    """Add memories - pure Mem0 passthrough."""
    try:
        logger.info(f"Adding memories for user: {request.user_id}")
        result = await mem0_manager.add_memories(
            messages=request.messages,
            user_id=request.user_id,
            agent_id=request.agent_id,
            run_id=request.run_id,
            metadata=request.metadata
        )
        return result
    except Exception as e:
        logger.error(f"Error adding memories: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/memories/search")
async def search_memories(request: MemorySearchRequest):
    """Search memories - pure Mem0 passthrough."""
    try:
        logger.info(f"Searching memories for user: {request.user_id}, query: {request.query}")
        result = await mem0_manager.search_memories(
            query=request.query,
            user_id=request.user_id,
            limit=request.limit,
            threshold=request.threshold,
            filters=request.filters,
            agent_id=request.agent_id,
            run_id=request.run_id
        )
        return result
    except Exception as e:
        logger.error(f"Error searching memories: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/memories/{user_id}")
async def get_user_memories(
    user_id: str,
    limit: int = 10,
    agent_id: Optional[str] = None,
    run_id: Optional[str] = None
):
    """Get all memories for a user with hierarchy filtering - pure Mem0 passthrough."""
    try:
        logger.info(f"Retrieving memories for user: {user_id}")
        memories = await mem0_manager.get_user_memories(
            user_id=user_id,
            limit=limit,
            agent_id=agent_id,
            run_id=run_id
        )
        return memories
    except Exception as e:
        logger.error(f"Error retrieving user memories: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.put("/memories")
async def update_memory(request: MemoryUpdateRequest):
    """Update memory - pure Mem0 passthrough."""
    try:
        logger.info(f"Updating memory: {request.memory_id}")
        result = await mem0_manager.update_memory(
            memory_id=request.memory_id,
            content=request.content,
        )
        return result
    except Exception as e:
        logger.error(f"Error updating memory: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.delete("/memories/{memory_id}")
async def delete_memory(memory_id: str):
    """Delete a specific memory."""
    try:
        logger.info(f"Deleting memory: {memory_id}")
        result = await mem0_manager.delete_memory(memory_id=memory_id)
        return result
    except Exception as e:
        logger.error(f"Error deleting memory: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.delete("/memories/user/{user_id}")
async def delete_user_memories(user_id: str):
    """Delete all memories for a specific user."""
    try:
        logger.info(f"Deleting all memories for user: {user_id}")
        result = await mem0_manager.delete_user_memories(user_id=user_id)
        return result
    except Exception as e:
        logger.error(f"Error deleting user memories: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# Graph relationships endpoint - pure Mem0 passthrough
@app.get("/graph/relationships/{user_id}")
async def get_graph_relationships(user_id: str):
    """Get graph relationships - pure Mem0 passthrough."""
    try:
        logger.info(f"Retrieving graph relationships for user: {user_id}")
        result = await mem0_manager.get_graph_relationships(user_id=user_id)
        return result
    except Exception as e:
        logger.error(f"Error retrieving graph relationships: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# Memory history endpoint - new feature
@app.get("/memories/{memory_id}/history")
async def get_memory_history(memory_id: str):
    """Get memory change history - pure Mem0 passthrough."""
    try:
        logger.info(f"Retrieving history for memory: {memory_id}")
        result = await mem0_manager.get_memory_history(memory_id=memory_id)
        return result
    except Exception as e:
        logger.error(f"Error retrieving memory history: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# Statistics and monitoring endpoints
@app.get("/stats", response_model=GlobalStatsResponse)
async def get_global_stats():
    """Get global application statistics."""
    try:
        from monitoring import stats

        # All figures come from the in-process monitoring counters. The former
        # search_memories("*", user_id="__stats_check__") probe was removed: its
        # result was discarded, yet every /stats call paid a backend round-trip.
        # TODO: query the store directly for an accurate total_memories.
        basic_stats = stats.get_global_stats()
        operations = basic_stats['memory_operations']

        return GlobalStatsResponse(
            total_memories=basic_stats['total_memories'],
            total_users=basic_stats['total_users'],
            api_calls_today=basic_stats['api_calls_today'],
            avg_response_time_ms=basic_stats['avg_response_time_ms'],
            memory_operations={
                "add": operations['add'],
                "search": operations['search'],
                "update": operations['update'],
                "delete": operations['delete']
            },
            uptime_seconds=basic_stats['uptime_seconds']
        )
    except Exception as e:
        logger.error(f"Error getting global stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/stats/{user_id}", response_model=UserStatsResponse)
async def get_user_stats(user_id: str):
    """Get user-specific statistics."""
    try:
        from monitoring import stats

        # Get basic user stats from monitoring
        basic_stats = stats.get_user_stats(user_id)

        # Memory count for this user (capped by limit=1000). Backend failures
        # degrade to 0 rather than failing the whole stats response.
        try:
            user_memories = await mem0_manager.get_user_memories(user_id=user_id, limit=1000)
            # NOTE(review): Mem0's get_all may return {"results": [...]} rather
            # than a bare list; len() of that dict would count keys. Confirm
            # against mem0_manager.get_user_memories.
            if isinstance(user_memories, dict):
                user_memories = user_memories.get("results", [])
            memory_count = len(user_memories)
        except Exception as e:
            logger.warning(f"Could not count memories for user {user_id}: {e}")
            memory_count = 0

        # Relationship count for this user (best-effort, as above)
        try:
            graph_data = await mem0_manager.get_graph_relationships(user_id=user_id)
            relationship_count = len(graph_data.get('relationships', []))
        except Exception as e:
            logger.warning(f"Could not count relationships for user {user_id}: {e}")
            relationship_count = 0

        return UserStatsResponse(
            user_id=user_id,
            memory_count=memory_count,
            relationship_count=relationship_count,
            last_activity=basic_stats['last_activity'],
            api_calls_today=basic_stats['api_calls_today'],
            avg_response_time_ms=basic_stats['avg_response_time_ms']
        )
    except Exception as e:
        logger.error(f"Error getting user stats for {user_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# Utility endpoints
@app.get("/models")
async def get_available_models():
    """Get current model configuration."""
    return {
        "current_model": settings.default_model,
        "endpoint": settings.openai_base_url,
        "note": "Using single model with pure Mem0 intelligence"
    }


@app.get("/users")
async def get_active_users():
    """Get list of users with memories (simplified implementation)."""
    # Placeholder: listing users needs direct database access or a Mem0
    # user-enumeration API, neither of which is wired up here.
    return {
        "message": "This endpoint would return users with stored memories",
        "note": "Implementation depends on direct database access or Mem0 user enumeration capabilities"
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        log_level=settings.log_level.lower(),
        reload=True
    )