knowledge-base/backend/main.py
Pratik Narola 7689409950 Initial commit: Production-ready Mem0 interface with monitoring
- Complete Mem0 OSS integration with hybrid datastore
- PostgreSQL + pgvector for vector storage
- Neo4j 5.18 for graph relationships
- Google Gemini embeddings integration
- Comprehensive monitoring with correlation IDs
- Real-time statistics and performance tracking
- Production-grade observability features
- Clean repository with no exposed secrets
2025-08-10 17:34:41 +05:30

476 lines
No EOL
15 KiB
Python

"""Main FastAPI application for Mem0 Interface POC."""
import json
import logging
import time
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import List, Dict, Any, Optional

import structlog
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from config import settings
from mem0_manager import mem0_manager
from models import (
    ChatRequest, MemoryAddRequest, MemoryAddResponse,
    MemorySearchRequest, MemorySearchResponse, MemoryUpdateRequest,
    MemoryItem, GraphResponse, HealthResponse, ErrorResponse,
    GlobalStatsResponse, UserStatsResponse
)
# Configure structured logging.
#
# structlog is wired to the standard library here (stdlib.LoggerFactory +
# stdlib.filter_by_level), so level filtering and output are decided by the
# stdlib ``logging`` module. The root logger therefore has to be configured as
# well: with Python's default root level (WARNING) and no handler,
# ``filter_by_level`` silently drops every ``logger.info(...)`` call in this
# module. ``basicConfig`` is a no-op if the root logger already has handlers.
_configured_log_level = logging.getLevelName(settings.log_level.upper())
if not isinstance(_configured_log_level, int):
    # Unknown level name in settings: fall back instead of failing at import.
    _configured_log_level = logging.INFO
# "%(message)s" because JSONRenderer below already produces the full record.
logging.basicConfig(format="%(message)s", level=_configured_log_level)

structlog.configure(
    processors=[
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        structlog.processors.JSONRenderer()
    ],
    context_class=dict,
    logger_factory=structlog.stdlib.LoggerFactory(),
    wrapper_class=structlog.stdlib.BoundLogger,
    cache_logger_on_first_use=True,
)
logger = structlog.get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the application's startup and shutdown hooks.

    On startup, ``mem0_manager.health_check()`` is queried and every service
    whose status string contains "unhealthy" is logged as a warning. Neither
    an unhealthy service nor a failing health check blocks startup, so the
    API can come up and report its state through ``/health``.

    Args:
        app: The application being started. It is not used here, but FastAPI's
            ``lifespan=`` hook passes it in.
    """
    # Startup
    logger.info("Starting Mem0 Interface POC")
    try:
        health_status = await mem0_manager.health_check()
    except Exception as exc:
        # /health treats a raising health_check() as "unhealthy" instead of
        # crashing; do the same here so a down dependency doesn't abort boot.
        logger.warning(f"Startup health check failed: {exc}", exc_info=True)
    else:
        unhealthy_services = [
            name for name, status in health_status.items() if "unhealthy" in status
        ]
        if unhealthy_services:
            logger.warning(f"Some services are unhealthy: {unhealthy_services}")
        else:
            logger.info("All services are healthy")
    yield
    # Shutdown
    logger.info("Shutting down Mem0 Interface POC")
# Build the FastAPI application; `lifespan` supplies the startup/shutdown hooks.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="Mem0 Interface POC",
    description="Minimal but fully functional Mem0 interface with PostgreSQL and Neo4j integration",
)

# CORS policy: allowed origins come from settings; credentials plus every
# method and header are permitted for those origins.
_cors_options = dict(
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# Request logging middleware with monitoring

# Requests slower than this many milliseconds are logged at WARNING level.
SLOW_REQUEST_THRESHOLD_MS = 2000


def _user_id_from_body(body: bytes) -> Optional[Any]:
    """Return the top-level ``user_id`` field of a JSON request body.

    Best-effort only: an empty, undecodable, non-JSON or non-object body
    yields ``None`` rather than an error.
    """
    if not body:
        return None
    try:
        data = json.loads(body)
    except (ValueError, RecursionError):
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        return None
    return data.get("user_id") if isinstance(data, dict) else None


def _user_id_from_path(method: str, path: str) -> Optional[str]:
    """Return the user id embedded in a user-scoped route path, if any.

    Recognised routes (all declared in this module):
    GET /memories/{user_id}, GET /stats/{user_id},
    GET /graph/relationships/{user_id}, DELETE /memories/user/{user_id}.
    """
    segments = [segment for segment in path.split("/") if segment]
    if method == "GET":
        if len(segments) == 2 and segments[0] in ("memories", "stats"):
            return segments[1]
        if len(segments) == 3 and segments[:2] == ["graph", "relationships"]:
            return segments[2]
    elif method == "DELETE":
        if len(segments) == 3 and segments[:2] == ["memories", "user"]:
            return segments[2]
    return None


@app.middleware("http")
async def log_requests(request, call_next):
    """Log every HTTP request with a correlation id, timing and user id.

    The user id is taken from the JSON body for POST requests and from the
    route path for user-scoped GET/DELETE routes. Each request is recorded in
    ``monitoring.stats`` (including requests whose handler raises) and logged
    at INFO, WARNING (slower than ``SLOW_REQUEST_THRESHOLD_MS``) or ERROR
    (status >= 400) level.
    """
    from monitoring import generate_correlation_id, stats

    correlation_id = generate_correlation_id()
    # perf_counter is monotonic, so wall-clock adjustments can't skew timings.
    start = time.perf_counter()

    if request.method == "POST":
        try:
            body = await request.body()
        except Exception:
            # e.g. client disconnected; user id extraction is best-effort.
            body = b""
        user_id = _user_id_from_body(body)
    else:
        user_id = _user_id_from_path(request.method, request.url.path)

    logger.info(
        "HTTP request started",
        correlation_id=correlation_id,
        method=request.method,
        path=request.url.path,
        user_id=user_id
    )

    try:
        response = await call_next(request)
    except Exception as exc:
        # Still count and log the request; the global exception handler logs
        # the traceback, so only the exception type is recorded here.
        failed_ms = (time.perf_counter() - start) * 1000
        stats.record_api_call(user_id, failed_ms)
        logger.error(
            "HTTP request failed",
            correlation_id=correlation_id,
            method=request.method,
            path=request.url.path,
            process_time_ms=round(failed_ms, 2),
            user_id=user_id,
            error_type=type(exc).__name__
        )
        raise

    process_time_ms = (time.perf_counter() - start) * 1000
    stats.record_api_call(user_id, process_time_ms)

    # Slowness takes precedence over error status when choosing the level.
    is_slow = process_time_ms > SLOW_REQUEST_THRESHOLD_MS
    if is_slow:
        log_method, message = logger.warning, "HTTP request completed (SLOW)"
    elif response.status_code >= 400:
        log_method, message = logger.error, "HTTP request completed (ERROR)"
    else:
        log_method, message = logger.info, "HTTP request completed"
    log_method(
        message,
        correlation_id=correlation_id,
        method=request.method,
        path=request.url.path,
        status_code=response.status_code,
        process_time_ms=round(process_time_ms, 2),
        user_id=user_id,
        slow_request=is_slow
    )
    return response
# Exception handlers
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Convert any otherwise-unhandled exception into a JSON 500 response.

    The exception and its traceback are logged. The response body holds a
    generic error label plus ``str(exc)`` as its detail.
    NOTE(review): ``str(exc)`` is exposed to clients; consider hiding it in
    production.
    """
    logger.error(f"Unhandled exception: {exc}", exc_info=True)
    error_body = {"error": "Internal server error", "detail": str(exc)}
    return JSONResponse(content=error_body, status_code=500)
# Health check endpoint
def _is_service_healthy(status: str) -> bool:
    """Return True when a service status string reports a healthy state.

    A plain ``"healthy" in status`` test is wrong because "healthy" is a
    substring of "unhealthy"; that would report a down service as healthy.
    """
    return "healthy" in status and "unhealthy" not in status


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report the health of all backing services.

    Returns:
        HealthResponse whose ``status`` is "healthy" when every service is
        healthy and "degraded" otherwise. If the health check itself raises,
        the status is "unhealthy" and ``services`` carries the error text. The
        timestamp is a timezone-aware UTC ISO-8601 string.
    """
    try:
        services = await mem0_manager.health_check()
        all_healthy = all(_is_service_healthy(status) for status in services.values())
        return HealthResponse(
            status="healthy" if all_healthy else "degraded",
            services=services,
            timestamp=datetime.now(timezone.utc).isoformat()
        )
    except Exception as e:
        logger.error(f"Health check failed: {e}", exc_info=True)
        return HealthResponse(
            status="unhealthy",
            services={"error": str(e)},
            timestamp=datetime.now(timezone.utc).isoformat()
        )
# Core chat endpoint with memory enhancement
@app.post("/chat")
async def chat_with_memory(request: ChatRequest):
    """Answer a chat message using the caller's stored memories.

    Forwards message, user id, context and metadata to
    ``mem0_manager.chat_with_memory`` and returns its result unchanged.

    Raises:
        HTTPException: 500 carrying the error text if the manager call fails.
    """
    try:
        logger.info(f"Processing chat request for user: {request.user_id}")
        return await mem0_manager.chat_with_memory(
            message=request.message,
            user_id=request.user_id,
            context=request.context,
            metadata=request.metadata
        )
    except Exception as e:
        # exc_info keeps the traceback in the logs; clients only see str(e).
        logger.error(f"Error in chat endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Memory management endpoints - pure Mem0 passthroughs
@app.post("/memories")
async def add_memories(request: MemoryAddRequest):
    """Store new memories extracted from the given messages.

    Forwards messages and the user/agent/run hierarchy ids plus metadata to
    ``mem0_manager.add_memories`` and returns its result unchanged.

    Raises:
        HTTPException: 500 carrying the error text if the manager call fails.
    """
    try:
        logger.info(f"Adding memories for user: {request.user_id}")
        return await mem0_manager.add_memories(
            messages=request.messages,
            user_id=request.user_id,
            agent_id=request.agent_id,
            run_id=request.run_id,
            metadata=request.metadata
        )
    except Exception as e:
        # exc_info keeps the traceback in the logs; clients only see str(e).
        logger.error(f"Error adding memories: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/memories/search")
async def search_memories(request: MemorySearchRequest):
    """Search a user's memories.

    Forwards query, limit, threshold, filters and the user/agent/run ids to
    ``mem0_manager.search_memories`` and returns its result unchanged.

    Raises:
        HTTPException: 500 carrying the error text if the manager call fails.
    """
    try:
        logger.info(f"Searching memories for user: {request.user_id}, query: {request.query}")
        return await mem0_manager.search_memories(
            query=request.query,
            user_id=request.user_id,
            limit=request.limit,
            threshold=request.threshold,
            filters=request.filters,
            agent_id=request.agent_id,
            run_id=request.run_id
        )
    except Exception as e:
        # exc_info keeps the traceback in the logs; clients only see str(e).
        logger.error(f"Error searching memories: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/memories/{user_id}")
async def get_user_memories(
    user_id: str,
    limit: int = 10,
    agent_id: Optional[str] = None,
    run_id: Optional[str] = None
):
    """List a user's memories, optionally narrowed by agent and run id.

    Forwards to ``mem0_manager.get_user_memories`` and returns its result
    unchanged. ``limit`` defaults to 10.

    Raises:
        HTTPException: 500 carrying the error text if the manager call fails.
    """
    try:
        logger.info(f"Retrieving memories for user: {user_id}")
        return await mem0_manager.get_user_memories(
            user_id=user_id,
            limit=limit,
            agent_id=agent_id,
            run_id=run_id
        )
    except Exception as e:
        # exc_info keeps the traceback in the logs; clients only see str(e).
        logger.error(f"Error retrieving user memories: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.put("/memories")
async def update_memory(request: MemoryUpdateRequest):
    """Replace the content of an existing memory.

    Forwards ``memory_id`` and ``content`` to ``mem0_manager.update_memory``
    and returns its result unchanged.

    Raises:
        HTTPException: 500 carrying the error text if the manager call fails.
    """
    try:
        logger.info(f"Updating memory: {request.memory_id}")
        return await mem0_manager.update_memory(
            memory_id=request.memory_id,
            content=request.content,
        )
    except Exception as e:
        # exc_info keeps the traceback in the logs; clients only see str(e).
        logger.error(f"Error updating memory: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.delete("/memories/{memory_id}")
async def delete_memory(memory_id: str):
    """Delete a single memory by id.

    Forwards to ``mem0_manager.delete_memory`` and returns its result.

    Raises:
        HTTPException: 500 carrying the error text if the manager call fails.
    """
    try:
        logger.info(f"Deleting memory: {memory_id}")
        return await mem0_manager.delete_memory(memory_id=memory_id)
    except Exception as e:
        # exc_info keeps the traceback in the logs; clients only see str(e).
        logger.error(f"Error deleting memory: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.delete("/memories/user/{user_id}")
async def delete_user_memories(user_id: str):
    """Delete every memory belonging to one user.

    Forwards to ``mem0_manager.delete_user_memories`` and returns its result.

    Raises:
        HTTPException: 500 carrying the error text if the manager call fails.
    """
    try:
        logger.info(f"Deleting all memories for user: {user_id}")
        return await mem0_manager.delete_user_memories(user_id=user_id)
    except Exception as e:
        # exc_info keeps the traceback in the logs; clients only see str(e).
        logger.error(f"Error deleting user memories: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Graph relationships endpoint - pure Mem0 passthrough
@app.get("/graph/relationships/{user_id}")
async def get_graph_relationships(user_id: str):
    """Return the graph relationships stored for a user.

    Forwards to ``mem0_manager.get_graph_relationships`` and returns its
    result unchanged.

    Raises:
        HTTPException: 500 carrying the error text if the manager call fails.
    """
    try:
        logger.info(f"Retrieving graph relationships for user: {user_id}")
        return await mem0_manager.get_graph_relationships(user_id=user_id)
    except Exception as e:
        # exc_info keeps the traceback in the logs; clients only see str(e).
        logger.error(f"Error retrieving graph relationships: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Memory history endpoint - new feature
@app.get("/memories/{memory_id}/history")
async def get_memory_history(memory_id: str):
    """Return the change history of a single memory.

    Forwards to ``mem0_manager.get_memory_history`` and returns its result
    unchanged.

    Raises:
        HTTPException: 500 carrying the error text if the manager call fails.
    """
    try:
        logger.info(f"Retrieving history for memory: {memory_id}")
        return await mem0_manager.get_memory_history(memory_id=memory_id)
    except Exception as e:
        # exc_info keeps the traceback in the logs; clients only see str(e).
        logger.error(f"Error retrieving memory history: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Statistics and monitoring endpoints
@app.get("/stats", response_model=GlobalStatsResponse)
async def get_global_stats():
    """Return application-wide usage statistics.

    Every figure comes from the in-process ``monitoring.stats`` collector.
    ``total_memories`` is whatever that collector reports (0 if it has no
    such entry); it is not a live count from the memory store.

    The previous implementation also ran a ``search_memories(query="*")``
    call against the backend and discarded the result. That call is removed:
    it cost a backend round-trip on every /stats request and contributed
    nothing to the response.

    Raises:
        HTTPException: 500 carrying the error text if collecting stats fails.
    """
    try:
        from monitoring import stats

        basic_stats = stats.get_global_stats()
        operations = basic_stats['memory_operations']
        return GlobalStatsResponse(
            total_memories=basic_stats.get('total_memories', 0),
            total_users=basic_stats['total_users'],
            api_calls_today=basic_stats['api_calls_today'],
            avg_response_time_ms=basic_stats['avg_response_time_ms'],
            memory_operations={
                op: operations[op] for op in ("add", "search", "update", "delete")
            },
            uptime_seconds=basic_stats['uptime_seconds']
        )
    except Exception as e:
        logger.error(f"Error getting global stats: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
def _count_memories(result: Any) -> int:
    """Count the memories in a ``mem0_manager.get_user_memories`` result.

    NOTE(review): mem0's ``get_all`` returns ``{"results": [...]}`` in its
    v1.1 output format, and ``len()`` of that dict would count keys rather
    than memories. Confirm which shape ``mem0_manager`` returns. A list (or
    other sized object) is counted directly.
    """
    if isinstance(result, dict):
        return len(result.get("results", []))
    return len(result)


@app.get("/stats/{user_id}", response_model=UserStatsResponse)
async def get_user_stats(user_id: str):
    """Return usage statistics for a single user.

    Activity figures come from ``monitoring.stats``. Memory and relationship
    counts are looked up from ``mem0_manager`` on a best-effort basis: a
    failing lookup is logged and reported as 0 rather than failing the whole
    response. The memory count is capped by the ``limit=1000`` fetch.

    Raises:
        HTTPException: 500 carrying the error text if the monitoring stats
            cannot be collected.
    """
    try:
        from monitoring import stats

        basic_stats = stats.get_user_stats(user_id)

        # `except Exception` (not a bare except) so task cancellation is never
        # swallowed, and failures are logged instead of silently hidden.
        try:
            user_memories = await mem0_manager.get_user_memories(user_id=user_id, limit=1000)
            memory_count = _count_memories(user_memories)
        except Exception as exc:
            logger.warning(f"Could not count memories for user {user_id}: {exc}")
            memory_count = 0

        try:
            graph_data = await mem0_manager.get_graph_relationships(user_id=user_id)
            relationship_count = len(graph_data.get('relationships', []))
        except Exception as exc:
            logger.warning(f"Could not count graph relationships for user {user_id}: {exc}")
            relationship_count = 0

        return UserStatsResponse(
            user_id=user_id,
            memory_count=memory_count,
            relationship_count=relationship_count,
            last_activity=basic_stats['last_activity'],
            api_calls_today=basic_stats['api_calls_today'],
            avg_response_time_ms=basic_stats['avg_response_time_ms']
        )
    except Exception as e:
        logger.error(f"Error getting user stats for {user_id}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Utility endpoints
@app.get("/models")
async def get_available_models():
    """Describe the configured chat model and the API endpoint serving it."""
    model_info = {}
    model_info["current_model"] = settings.default_model
    model_info["endpoint"] = settings.openai_base_url
    model_info["note"] = "Using single model with pure Mem0 intelligence"
    return model_info
@app.get("/users")
async def get_active_users():
    """Placeholder for listing users who have stored memories.

    No lookup is performed; the response is a fixed explanatory message.
    """
    # TODO: enumerate users via direct database access or Mem0, once supported.
    return dict(
        message="This endpoint would return users with stored memories",
        note="Implementation depends on direct database access or Mem0 user enumeration capabilities",
    )
if __name__ == "__main__":
    import uvicorn

    # Development entry point: listen on all interfaces at port 8000 with
    # auto-reload. The app is passed as the import string "main:app" because
    # uvicorn needs an import string to enable reload.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "log_level": settings.log_level.lower(),
        "reload": True,
    }
    uvicorn.run("main:app", **server_options)