knowledge-base/backend/main.py
Pratik Narola a228780146 production improvements: configurable embeddings, v1.1, O(1) ownership, retries
- Make Ollama URL configurable via OLLAMA_BASE_URL env var
- Add version: v1.1 to Mem0 config (required for latest features)
- Make embedding model and dimensions configurable
- Fix ownership check: O(1) lookup instead of fetching 10k records
- Add tenacity retry logic for database operations
2026-01-15 23:01:18 +05:30

787 lines
26 KiB
Python

"""Main FastAPI application for Mem0 Interface POC."""
import json
import logging
import time
from datetime import datetime
from typing import List, Dict, Any, Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import structlog
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from config import settings
# Rate limiter key: requests carrying an X-API-Key header are bucketed per API key,
# all other requests fall back to the client IP address.
def get_rate_limit_key(request: Request) -> str:
    """Return the slowapi bucket key for *request*.

    Prefers the ``X-API-Key`` header when present, so that clients sharing an IP
    address get independent limits; otherwise uses the remote address.

    The key is built from a SHA-256 digest of the *whole* API key instead of a raw
    prefix, for two reasons:
    - Limiter keys may be kept in limiter storage and may appear in rate-limit log
      lines, so they must not contain secret material.
    - Two different API keys that share their first 16 characters must not share
      one bucket.

    Args:
        request: The incoming request.

    Returns:
        ``"apikey:<sha256 hex>"`` when an API key header is present, otherwise the
        value returned by ``get_remote_address(request)``.
    """
    import hashlib

    api_key = request.headers.get("x-api-key", "")
    if api_key:
        digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
        return f"apikey:{digest}"
    return get_remote_address(request)


limiter = Limiter(key_func=get_rate_limit_key)
from models import (
ChatRequest,
MemoryAddRequest,
MemoryAddResponse,
MemorySearchRequest,
MemorySearchResponse,
MemoryUpdateRequest,
MemoryItem,
GraphResponse,
HealthResponse,
ErrorResponse,
GlobalStatsResponse,
UserStatsResponse,
)
from mem0_manager import mem0_manager
from auth import get_current_user, auth_service
# Configure structured logging.
#
# The structlog pipeline below starts with structlog.stdlib.filter_by_level and
# emits through stdlib loggers (stdlib.LoggerFactory). Events are therefore
# filtered by the stdlib logger's effective level. With no stdlib configuration
# the root level is WARNING, and every logger.info(...) in this module would be
# silently dropped. basicConfig is a no-op when the root logger already has
# handlers, so any existing logging setup is left untouched.
_stdlib_log_level = logging.getLevelName(settings.log_level.upper())
if not isinstance(_stdlib_log_level, int):
    # Unrecognised level name: getLevelName returned a "Level X" string.
    _stdlib_log_level = logging.INFO
# JSONRenderer produces the complete line, so the stdlib format is just the message.
logging.basicConfig(format="%(message)s", level=_stdlib_log_level)

structlog.configure(
    processors=[
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        structlog.processors.JSONRenderer(),
    ],
    context_class=dict,
    logger_factory=structlog.stdlib.LoggerFactory(),
    wrapper_class=structlog.stdlib.BoundLogger,
    cache_logger_on_first_use=True,
)
logger = structlog.get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Startup:
    - Runs ``mem0_manager.health_check()`` and logs which services report
      "unhealthy". This only produces a warning; startup continues.
    - Enters the optional MCP session-manager context from ``mcp_server``.

    Shutdown:
    - Exits the MCP context, but only if it was successfully entered.
    - The shutdown steps run in a ``finally``, so they also run when the lifespan
      is torn down by an exception.
    """
    # Startup
    logger.info("Starting Mem0 Interface POC")

    # Perform health check on startup
    health_status = await mem0_manager.health_check()
    unhealthy_services = [k for k, v in health_status.items() if "unhealthy" in v]
    if unhealthy_services:
        logger.warning(f"Some services are unhealthy: {unhealthy_services}")
    else:
        logger.info("All services are healthy")

    # Start MCP session manager if available
    mcp_context = None
    try:
        from mcp_server import mcp_lifespan

        candidate = mcp_lifespan()
        await candidate.__aenter__()
        # Track the context only after __aenter__ succeeded, so shutdown never
        # calls __aexit__ on a context manager that was never entered.
        mcp_context = candidate
    except ImportError:
        logger.warning("MCP server not available")
    except Exception as e:
        logger.error(f"Failed to start MCP session manager: {e}")

    try:
        yield
    finally:
        # Shutdown
        if mcp_context is not None:
            try:
                await mcp_context.__aexit__(None, None, None)
            except Exception as e:
                logger.error(f"Error stopping MCP session manager: {e}")
        logger.info("Shutting down Mem0 Interface POC")
# FastAPI application instance; startup and shutdown are delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="Mem0 Interface POC",
    version="1.0.0",
    description="Minimal but fully functional Mem0 interface with PostgreSQL and Neo4j integration",
)

# Expose the slowapi limiter on app.state and route RateLimitExceeded errors
# to slowapi's default handler.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# CORS: every origin, method and header is accepted; access control relies on
# the API key. Credentials stay disabled because they cannot be combined with
# a wildcard origin.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=False,
    allow_origins=["*"],
)
# Largest request body (in bytes) accepted, judged from the Content-Length header.
MAX_REQUEST_SIZE = 10 * 1024 * 1024  # 10MB limit


@app.middleware("http")
async def limit_request_size(request, call_next):
    """Answer 413 when the declared Content-Length exceeds MAX_REQUEST_SIZE.

    Only the header is inspected. A missing or empty header, or one that does not
    parse as an integer, lets the request continue unchanged.

    NOTE(review):
    - Bodies sent without a Content-Length (e.g. chunked transfer encoding) are
      not limited here. Confirm whether the server enforces a limit for them.
    - Starlette runs middleware registered later *before* this one, so such
      middleware sees the request before this check.
    """
    declared = request.headers.get("content-length")
    if not declared:
        return await call_next(request)

    try:
        too_large = int(declared) > MAX_REQUEST_SIZE
    except ValueError:
        # Unparseable header: pass it on and let other validation deal with it.
        too_large = False

    if too_large:
        payload = {
            "error": "Request payload too large",
            "max_size_bytes": MAX_REQUEST_SIZE,
            "max_size_mb": MAX_REQUEST_SIZE / (1024 * 1024),
        }
        return JSONResponse(status_code=413, content=payload)
    return await call_next(request)
# Request logging middleware with monitoring

# Requests slower than this many milliseconds are logged at WARNING as "SLOW".
_SLOW_REQUEST_THRESHOLD_MS = 2000

# On GET routes, a path whose second-to-last segment is one of these carries a
# user id as its last segment: /memories/{user_id}, /stats/{user_id},
# /graph/relationships/{user_id}.
_USER_ID_PARENT_SEGMENTS = ("memories", "stats", "relationships")


async def _extract_user_id(request: Request) -> Optional[Any]:
    """Best-effort lookup of the user_id a request refers to, for logs and stats.

    POST requests:
    - Uses the ``user_id`` field of a JSON-object body.
    - The body is read only when a Content-Length no larger than MAX_REQUEST_SIZE
      is declared. ``log_requests`` is registered after ``limit_request_size``,
      and Starlette runs the last-registered middleware first. Reading an
      undeclared or oversized body here would therefore buffer it in memory
      before the 413 check ever runs.

    GET requests:
    - Uses the last path segment of the routes in _USER_ID_PARENT_SEGMENTS.

    The value is caller-supplied and not authenticated; use it for telemetry only.

    Returns:
        The extracted user id, or None when nothing can be extracted.
    """
    if request.method == "POST":
        declared_length = request.headers.get("content-length", "")
        if not declared_length.isdigit() or int(declared_length) > MAX_REQUEST_SIZE:
            return None
        try:
            body = await request.body()
            data = json.loads(body) if body else None
        except ValueError:
            # JSONDecodeError / UnicodeDecodeError: not a JSON body.
            return None
        except Exception as e:
            logger.debug("Could not extract user_id from request body", error=str(e))
            return None
        # Only JSON objects can carry a user_id field.
        return data.get("user_id") if isinstance(data, dict) else None

    if request.method == "GET":
        path_parts = request.url.path.split("/")
        if len(path_parts) > 2 and path_parts[-2] in _USER_ID_PARENT_SEGMENTS:
            # A trailing slash leaves an empty last segment; report no user then.
            return path_parts[-1] or None
    return None


@app.middleware("http")
async def log_requests(request, call_next):
    """Log all HTTP requests with correlation ID and timing.

    Emits an "HTTP request started" event before the request is handled. After the
    response is produced it emits a completion event, whose level depends on the
    outcome:
    - WARNING when slower than _SLOW_REQUEST_THRESHOLD_MS;
    - ERROR for status codes >= 400;
    - INFO otherwise.

    It also records the call with ``monitoring.stats.record_api_call``.
    """
    from monitoring import generate_correlation_id, stats

    correlation_id = generate_correlation_id()
    # perf_counter is monotonic; time.time() can jump with wall-clock changes.
    start_time = time.perf_counter()

    user_id = await _extract_user_id(request)

    logger.info(
        "HTTP request started",
        correlation_id=correlation_id,
        method=request.method,
        path=request.url.path,
        user_id=user_id,
    )

    response = await call_next(request)

    process_time_ms = (time.perf_counter() - start_time) * 1000

    # Record statistics
    stats.record_api_call(user_id, process_time_ms)

    slow_request = process_time_ms > _SLOW_REQUEST_THRESHOLD_MS
    if slow_request:
        log_method, event = logger.warning, "HTTP request completed (SLOW)"
    elif response.status_code >= 400:
        log_method, event = logger.error, "HTTP request completed (ERROR)"
    else:
        log_method, event = logger.info, "HTTP request completed"

    log_method(
        event,
        correlation_id=correlation_id,
        method=request.method,
        path=request.url.path,
        status_code=response.status_code,
        process_time_ms=round(process_time_ms, 2),
        user_id=user_id,
        slow_request=slow_request,
    )
    return response
# Exception handlers
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Catch-all for exceptions that escape the route handlers.

    The server log receives the exception type, message and traceback. The client
    only ever gets a fixed, generic 500 payload, so internal details are not
    exposed.
    """
    request_info = {"path": request.url.path, "method": request.method}
    logger.error(
        "Unhandled exception",
        exc_info=True,
        **request_info,
        error_type=type(exc).__name__,
        error_message=str(exc),
    )

    generic_body = {
        "error": "Internal server error",
        "message": "An unexpected error occurred",
    }
    return JSONResponse(content=generic_body, status_code=500)
# Health check endpoint
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Check the health of all services.

    Returns status "healthy" when every service reports healthy and "degraded"
    otherwise. If the check itself fails, returns "unhealthy" with a generic
    error entry; the details go to the server log.
    """
    try:
        services = await mem0_manager.health_check()
        # "healthy" is a substring of "unhealthy", so a bare `"healthy" in status`
        # test would count failing services as healthy. Require the positive
        # marker and reject the negative one.
        all_healthy = all(
            "healthy" in status and "unhealthy" not in status
            for status in services.values()
        )
        overall_status = "healthy" if all_healthy else "degraded"
        return HealthResponse(
            status=overall_status,
            services=services,
            timestamp=datetime.utcnow().isoformat(),
        )
    except Exception as e:
        logger.error(f"Health check failed: {e}", exc_info=True)
        return HealthResponse(
            status="unhealthy",
            services={"error": "Health check failed - see logs for details"},
            timestamp=datetime.utcnow().isoformat(),
        )
# Core chat endpoint with memory enhancement
@app.post("/chat")
@limiter.limit("30/minute")  # Chat is expensive - limit to 30/min
async def chat_with_memory(
    request: Request,
    chat_request: ChatRequest,
    authenticated_user: str = Depends(get_current_user),
):
    """Ultra-minimal chat endpoint - pure Mem0 + custom endpoint.

    Forwards the message and any prior conversation context to
    ``mem0_manager.chat_with_memory`` and returns its result unchanged.

    Responds with:
    - 403 if the body's ``user_id`` is not the authenticated user;
    - 500 with a generic message if the backend call fails.
    """
    # Users may only chat as themselves.
    if authenticated_user != chat_request.user_id:
        raise HTTPException(
            status_code=403,
            detail=f"Access denied: You can only chat as yourself (authenticated as '{authenticated_user}')",
        )

    try:
        logger.info(f"Processing chat request for user: {chat_request.user_id}")

        # Convert ChatMessage objects to plain role/content dicts, if any were sent.
        context_dict = None
        if chat_request.context:
            context_dict = [
                {"role": msg.role, "content": msg.content}
                for msg in chat_request.context
            ]

        return await mem0_manager.chat_with_memory(
            message=chat_request.message,
            user_id=chat_request.user_id,
            agent_id=chat_request.agent_id,
            run_id=chat_request.run_id,
            context=context_dict,
        )
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees a generic 500.
        logger.error(f"Error in chat endpoint: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
# Memory management endpoints - pure Mem0 passthroughs
@app.post("/memories")
@limiter.limit("60/minute")  # Memory operations - 60/min
async def add_memories(
    request: Request,
    memory_request: MemoryAddRequest,
    authenticated_user: str = Depends(get_current_user),
):
    """Add memories - pure Mem0 passthrough.

    Stores the given messages (with optional agent/run scoping and metadata) for
    the authenticated user via ``mem0_manager.add_memories``.

    Responds with:
    - 403 if ``user_id`` is not the authenticated user;
    - 500 with a generic message on backend failure.
    """
    # Users may only add memories to their own store.
    if authenticated_user != memory_request.user_id:
        raise HTTPException(
            status_code=403,
            detail=f"Access denied: You can only add memories for yourself (authenticated as '{authenticated_user}')",
        )

    try:
        logger.info(f"Adding memories for user: {memory_request.user_id}")
        return await mem0_manager.add_memories(
            messages=memory_request.messages,
            user_id=memory_request.user_id,
            agent_id=memory_request.agent_id,
            run_id=memory_request.run_id,
            metadata=memory_request.metadata,
        )
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees a generic 500.
        logger.error(f"Error adding memories: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
@app.post("/memories/search")
@limiter.limit("120/minute")  # Search is lighter - 120/min
async def search_memories(
    request: Request,
    search_request: MemorySearchRequest,
    authenticated_user: str = Depends(get_current_user),
):
    """Search memories - pure Mem0 passthrough.

    Runs ``mem0_manager.search_memories`` for the authenticated user. The request's
    limit, filters and agent/run scoping are passed through. ``threshold`` defaults
    to 0.2 only when it is omitted.

    Responds with:
    - 403 if ``user_id`` is not the authenticated user;
    - 500 with a generic message on backend failure.
    """
    # Users may only search their own memories.
    if authenticated_user != search_request.user_id:
        raise HTTPException(
            status_code=403,
            detail=f"Access denied: You can only search your own memories (authenticated as '{authenticated_user}')",
        )

    # An explicit 0.0 is a valid threshold. `threshold or 0.2` would silently
    # replace it with 0.2, so fall back to the default only when it is None.
    threshold = (
        search_request.threshold if search_request.threshold is not None else 0.2
    )

    try:
        logger.info(
            f"Searching memories for user: {search_request.user_id}, query: {search_request.query}"
        )
        return await mem0_manager.search_memories(
            query=search_request.query,
            user_id=search_request.user_id,
            limit=search_request.limit,
            threshold=threshold,
            filters=search_request.filters,
            agent_id=search_request.agent_id,
            run_id=search_request.run_id,
        )
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees a generic 500.
        logger.error(f"Error searching memories: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
@app.get("/memories/{user_id}")
@limiter.limit("120/minute")
async def get_user_memories(
    request: Request,
    user_id: str,
    authenticated_user: str = Depends(get_current_user),
    limit: int = 10,
    agent_id: Optional[str] = None,
    run_id: Optional[str] = None,
):
    """Get all memories for a user with hierarchy filtering - pure Mem0 passthrough.

    Returns whatever ``mem0_manager.get_user_memories`` produces for the
    authenticated user, optionally narrowed by ``agent_id`` / ``run_id``.

    Responds with:
    - 403 if the path ``user_id`` is not the authenticated user;
    - 500 with a generic message on backend failure.
    """
    # Users may only read their own memories.
    if authenticated_user != user_id:
        raise HTTPException(
            status_code=403,
            detail=f"Access denied: You can only retrieve your own memories (authenticated as '{authenticated_user}')",
        )

    try:
        logger.info(f"Retrieving memories for user: {user_id}")
        # NOTE(review): `limit` is passed through without bounds checking; consider
        # validating it (e.g. fastapi.Query(ge=1, le=...)).
        return await mem0_manager.get_user_memories(
            user_id=user_id, limit=limit, agent_id=agent_id, run_id=run_id
        )
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees a generic 500.
        logger.error(f"Error retrieving user memories: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
@app.put("/memories")
@limiter.limit("60/minute")
async def update_memory(
    request: Request,
    update_request: MemoryUpdateRequest,
    authenticated_user: str = Depends(get_current_user),
):
    """Update memory - verifies ownership before update.

    Replaces the content of ``memory_id`` via ``mem0_manager.update_memory``.

    Responds with:
    - 403 if ``user_id`` is not the authenticated user;
    - 404 if the memory does not exist or belongs to someone else;
    - 500 with a generic message on backend failure.
    """
    # Users may only update their own memories.
    if authenticated_user != update_request.user_id:
        raise HTTPException(
            status_code=403,
            detail=f"Access denied: You can only update your own memories (authenticated as '{authenticated_user}')",
        )

    try:
        # Per-record ownership check. The 404 detail is deliberately ambiguous
        # ("not found or access denied").
        if not await mem0_manager.verify_memory_ownership(
            update_request.memory_id, authenticated_user
        ):
            raise HTTPException(
                status_code=404,
                detail=f"Memory '{update_request.memory_id}' not found or access denied",
            )

        logger.info(
            f"Updating memory: {update_request.memory_id}", user_id=authenticated_user
        )
        return await mem0_manager.update_memory(
            memory_id=update_request.memory_id,
            content=update_request.content,
        )
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees a generic 500.
        logger.error(f"Error updating memory: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
@app.delete("/memories/{memory_id}")
@limiter.limit("60/minute")
async def delete_memory(
    request: Request,
    memory_id: str,
    authenticated_user: str = Depends(get_current_user),
):
    """Delete a specific memory - verifies ownership before deletion.

    Responds with:
    - 404 if the memory does not exist or is not owned by the authenticated user;
    - 500 with a generic message on backend failure.
    """
    try:
        # Per-record ownership check. The 404 detail is deliberately ambiguous
        # ("not found or access denied").
        if not await mem0_manager.verify_memory_ownership(
            memory_id, authenticated_user
        ):
            raise HTTPException(
                status_code=404,
                detail=f"Memory '{memory_id}' not found or access denied",
            )

        logger.info(f"Deleting memory: {memory_id}", user_id=authenticated_user)
        return await mem0_manager.delete_memory(memory_id=memory_id)
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees a generic 500.
        logger.error(f"Error deleting memory: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
@app.delete("/memories/user/{user_id}")
@limiter.limit("10/minute")  # Dangerous bulk delete - heavily rate limited
async def delete_user_memories(
    request: Request, user_id: str, authenticated_user: str = Depends(get_current_user)
):
    """Delete all memories for a specific user.

    Delegates to ``mem0_manager.delete_user_memories``.

    Responds with:
    - 403 if the path ``user_id`` is not the authenticated user;
    - 500 with a generic message on backend failure.
    """
    # Users may only wipe their own memories.
    if authenticated_user != user_id:
        raise HTTPException(
            status_code=403,
            detail=f"Access denied: You can only delete your own memories (authenticated as '{authenticated_user}')",
        )

    try:
        logger.info(f"Deleting all memories for user: {user_id}")
        return await mem0_manager.delete_user_memories(user_id=user_id)
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees a generic 500.
        logger.error(f"Error deleting user memories: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
# Graph relationships endpoint - pure Mem0 passthrough
@app.get("/graph/relationships/{user_id}")
@limiter.limit("60/minute")
async def get_graph_relationships(
    request: Request, user_id: str, authenticated_user: str = Depends(get_current_user)
):
    """Get graph relationships - pure Mem0 passthrough.

    Returns up to 10000 relationships for the authenticated user, unscoped by
    agent or run.

    Responds with:
    - 403 if the path ``user_id`` is not the authenticated user;
    - 500 with a generic message on backend failure.
    """
    # Users may only view their own relationship graph.
    if authenticated_user != user_id:
        raise HTTPException(
            status_code=403,
            detail=f"Access denied: You can only view your own relationships (authenticated as '{authenticated_user}')",
        )

    try:
        logger.info(f"Retrieving graph relationships for user: {user_id}")
        return await mem0_manager.get_graph_relationships(
            user_id=user_id, agent_id=None, run_id=None, limit=10000
        )
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees a generic 500.
        logger.error(f"Error retrieving graph relationships: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
# Memory history endpoint - new feature
@app.get("/memories/{memory_id}/history")
@limiter.limit("120/minute")
async def get_memory_history(
    request: Request,
    memory_id: str,
    user_id: str,  # Required query param to verify ownership
    authenticated_user: str = Depends(get_current_user),
):
    """Get memory change history - pure Mem0 passthrough.

    Responds with:
    - 403 if the ``user_id`` query parameter is not the authenticated user;
    - 404 if the memory does not exist or is not owned by that user;
    - 500 with a generic message on backend failure.
    """
    # Users may only view the history of their own memories.
    if authenticated_user != user_id:
        raise HTTPException(
            status_code=403,
            detail=f"Access denied: You can only view your own memory history (authenticated as '{authenticated_user}')",
        )

    try:
        # Per-record ownership check. The 404 detail is deliberately ambiguous
        # ("not found or access denied").
        if not await mem0_manager.verify_memory_ownership(memory_id, user_id):
            raise HTTPException(
                status_code=404,
                detail=f"Memory '{memory_id}' not found or access denied",
            )

        logger.info(f"Retrieving history for memory: {memory_id}", user_id=user_id)
        return await mem0_manager.get_memory_history(memory_id=memory_id)
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees a generic 500.
        logger.error(f"Error retrieving memory history: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
# Statistics and monitoring endpoints
@app.get("/stats", response_model=GlobalStatsResponse)
@limiter.limit("60/minute")
async def get_global_stats(
    request: Request, authenticated_user: str = Depends(get_current_user)
):
    """Get global application statistics - requires authentication.

    Every figure comes from the in-process ``monitoring.stats`` snapshot. No Mem0
    backend call is made. A backend probe whose result was never used only added
    latency, and zeroed ``total_memories`` whenever the backend was unreachable.
    """
    # NOTE(review): any authenticated user can read service-wide totals
    # (e.g. total_users); confirm this cross-tenant visibility is intended.
    try:
        from monitoring import stats

        basic_stats = stats.get_global_stats()
        operations = basic_stats["memory_operations"]
        return GlobalStatsResponse(
            # Defaults to 0 when the monitoring snapshot carries no total.
            total_memories=basic_stats.get("total_memories", 0),
            total_users=basic_stats["total_users"],
            api_calls_today=basic_stats["api_calls_today"],
            avg_response_time_ms=basic_stats["avg_response_time_ms"],
            memory_operations={
                op: operations[op] for op in ("add", "search", "update", "delete")
            },
            uptime_seconds=basic_stats["uptime_seconds"],
        )
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees a generic 500.
        logger.error(f"Error getting global stats: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
@app.get("/stats/{user_id}", response_model=UserStatsResponse)
@limiter.limit("120/minute")
async def get_user_stats(
    request: Request, user_id: str, authenticated_user: str = Depends(get_current_user)
):
    """Get user-specific statistics.

    Combines the per-user monitoring snapshot with memory and relationship counts
    taken from Mem0. Either count falls back to 0, with a warning logged, if its
    backend call fails.

    Responds with:
    - 403 if the path ``user_id`` is not the authenticated user;
    - 500 with a generic message on any other failure.
    """
    # Users may only view their own statistics.
    if authenticated_user != user_id:
        raise HTTPException(
            status_code=403,
            detail=f"Access denied: You can only view your own statistics (authenticated as '{authenticated_user}')",
        )

    try:
        from monitoring import stats

        # Get basic user stats from monitoring
        basic_stats = stats.get_user_stats(user_id)

        # NOTE(review): the count is len() of up to 10000 fetched records, so it
        # is presumably capped at 10000 and fetches every record just to count.
        # It also assumes get_user_memories returns a list; confirm, and consider
        # a dedicated count query.
        try:
            user_memories = await mem0_manager.get_user_memories(
                user_id=user_id, limit=10000
            )
            memory_count = len(user_memories)
        except Exception as e:
            logger.warning(f"Failed to get memory count for user {user_id}: {e}")
            memory_count = 0

        # NOTE(review): no explicit limit is passed here, unlike the
        # /graph/relationships endpoint (limit=10000), so this count depends on
        # get_graph_relationships' default limit.
        try:
            graph_data = await mem0_manager.get_graph_relationships(
                user_id=user_id, agent_id=None, run_id=None
            )
            relationship_count = len(graph_data.get("relationships", []))
        except Exception as e:
            logger.warning(f"Failed to get relationship count for user {user_id}: {e}")
            relationship_count = 0

        return UserStatsResponse(
            user_id=user_id,
            memory_count=memory_count,
            relationship_count=relationship_count,
            last_activity=basic_stats["last_activity"],
            api_calls_today=basic_stats["api_calls_today"],
            avg_response_time_ms=basic_stats["avg_response_time_ms"],
        )
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees a generic 500.
        logger.error(f"Error getting user stats for {user_id}: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
# Utility endpoints
@app.get("/models")
@limiter.limit("120/minute")
async def get_available_models(
    request: Request, authenticated_user: str = Depends(get_current_user)
):
    """Report the configured LLM model and its endpoint - requires authentication.

    The response is built solely from ``settings``; no backend is contacted.
    """
    # NOTE(review): this exposes the configured upstream base URL to every
    # authenticated caller; confirm that is intended.
    model_info = dict(
        current_model=settings.default_model,
        endpoint=settings.openai_base_url,
        note="Using single model with pure Mem0 intelligence",
    )
    return model_info
@app.get("/users")
@limiter.limit("60/minute")
async def get_active_users(
    request: Request, authenticated_user: str = Depends(get_current_user)
):
    """Placeholder for listing users that have stored memories - requires authentication.

    Always returns a static explanatory payload. Real enumeration would need
    direct database access or a Mem0 user-listing capability.
    """
    placeholder = {}
    placeholder["message"] = "This endpoint would return users with stored memories"
    placeholder["note"] = (
        "Implementation depends on direct database access or Mem0 user enumeration capabilities"
    )
    return placeholder
# Mount the MCP server under /mcp when its optional dependencies are importable.
# Failures are logged and the API keeps running without the MCP mount.
try:
    from mcp_server import create_mcp_app

    mcp_app = create_mcp_app()
    app.mount("/mcp", mcp_app)
except ImportError as e:
    logger.warning(f"MCP server not available (missing dependencies): {e}")
except Exception as e:
    logger.error(f"Failed to mount MCP server: {e}")
else:
    logger.info("MCP server mounted at /mcp")
if __name__ == "__main__":
    import os

    import uvicorn

    # Launcher settings are overridable via environment variables. The defaults
    # reproduce the previous hard-coded values (0.0.0.0:8000, reload on).
    # Auto-reload is a development convenience, so set UVICORN_RELOAD=false for
    # production-like runs.
    host = os.getenv("UVICORN_HOST", "0.0.0.0")
    port = int(os.getenv("UVICORN_PORT", "8000"))
    reload_enabled = os.getenv("UVICORN_RELOAD", "true").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }

    print("Starting UVicorn server...")
    uvicorn.run(
        "main:app",
        host=host,
        port=port,
        log_level=settings.log_level.lower(),
        reload=reload_enabled,
    )