"""MCP Server for Mem0 Memory Service. Exposes memory operations as MCP tools over HTTP with API key authentication. """ import contextlib import logging from contextvars import ContextVar from typing import Optional from pydantic import Field from starlette.applications import Starlette from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.cors import CORSMiddleware from starlette.responses import JSONResponse from starlette.routing import Mount from mcp.server.fastmcp import FastMCP from config import settings logger = logging.getLogger(__name__) # Context variable for authenticated user_id # Set by middleware, read by tools current_user_id: ContextVar[str] = ContextVar("current_user_id", default="") def get_authenticated_user() -> str: """Get the authenticated user_id from context. Raises: ValueError: If no authenticated user in context. """ user_id = current_user_id.get() if not user_id: raise ValueError("No authenticated user in context") return user_id class MCPAuthMiddleware(BaseHTTPMiddleware): """Middleware to authenticate MCP requests using API key.""" async def dispatch(self, request, call_next): # Extract API key from headers api_key = ( request.headers.get("x-api-key") or request.headers.get("X-API-Key") or request.headers.get("Authorization", "").replace("Bearer ", "") ) if not api_key: return JSONResponse( {"error": "Missing API key. Provide x-api-key or Authorization header."}, status_code=401 ) # Map API key to user_id user_id = settings.api_key_mapping.get(api_key) if not user_id: return JSONResponse( {"error": "Invalid API key"}, status_code=401 ) # Store user_id in context for tools to access token = current_user_id.set(user_id) try: response = await call_next(request) return response finally: current_user_id.reset(token) # Create FastMCP server # streamable_http_path="/" since we mount at /mcp in main.py mcp = FastMCP( "Mem0 Memory Service", stateless_http=True, json_response=True, streamable_http_path="/" ) @mcp.tool() async def add_memory( content: str = Field(description="Content to add to memory"), agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"), run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"), metadata: Optional[dict] = Field(default=None, description="Optional metadata to attach to the memory"), ) -> dict: """Add content to the user's memory. The user_id is automatically determined from the API key authentication. Use agent_id and run_id for multi-agent or session-based memory organization. """ from mem0_manager import mem0_manager user_id = get_authenticated_user() logger.info(f"MCP add_memory: user={user_id}, agent={agent_id}, run={run_id}") result = await mem0_manager.add_memories( messages=[{"role": "user", "content": content}], user_id=user_id, agent_id=agent_id, run_id=run_id, metadata=metadata ) return result @mcp.tool() async def search_memory( query: str = Field(description="Search query to find relevant memories"), agent_id: Optional[str] = Field(default=None, description="Optional agent identifier to filter memories"), run_id: Optional[str] = Field(default=None, description="Optional run identifier to filter memories"), limit: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return"), ) -> dict: """Search the user's memories. Returns memories most relevant to the query. The user_id is automatically determined from API key authentication. """ from mem0_manager import mem0_manager user_id = get_authenticated_user() logger.info(f"MCP search_memory: user={user_id}, query={query[:50]}..., limit={limit}") result = await mem0_manager.search_memories( query=query, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit ) return result @mcp.tool() async def remove_memory( memory_id: str = Field(description="The ID of the memory to remove"), ) -> dict: """Remove a specific memory by its ID. Only memories belonging to the authenticated user can be deleted. Verifies ownership before deletion. """ from mem0_manager import mem0_manager user_id = get_authenticated_user() logger.info(f"MCP remove_memory: user={user_id}, memory_id={memory_id}") # Verify ownership: get user's memories and check if memory_id exists user_memories = await mem0_manager.get_user_memories( user_id=user_id, limit=10000 # Get all to check ownership ) memory_ids = {m.get("id") for m in user_memories if m.get("id")} if memory_id not in memory_ids: raise ValueError(f"Memory '{memory_id}' not found or access denied") result = await mem0_manager.delete_memory(memory_id=memory_id) return result @mcp.tool() async def chat( message: str = Field(description="The user's message to chat with"), agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"), run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"), ) -> str: """Chat with memory context. Retrieves relevant memories based on the message, generates a response using the configured LLM, and stores the conversation in memory. The user_id is automatically determined from API key authentication. """ from mem0_manager import mem0_manager user_id = get_authenticated_user() logger.info(f"MCP chat: user={user_id}, agent={agent_id}, message={message[:50]}...") result = await mem0_manager.chat_with_memory( message=message, user_id=user_id, agent_id=agent_id, run_id=run_id ) return result.get("response", "") @contextlib.asynccontextmanager async def mcp_lifespan(): """Context manager for MCP session lifecycle. Must be used in the main FastAPI lifespan since mounted app lifespans don't run automatically. """ async with mcp.session_manager.run(): logger.info("MCP session manager started") yield logger.info("MCP session manager stopped") def create_mcp_app() -> Starlette: """Create and configure the MCP Starlette application. Returns a Starlette app with MCP endpoints, authentication middleware, and CORS support. Note: The MCP session manager must be started via mcp_lifespan() in the main FastAPI lifespan, not here. """ # Get the StreamableHTTP app - it handles requests at "/" since we mount at /mcp streamable_app = mcp.streamable_http_app() # Create Starlette app with MCP routes app = Starlette( routes=[Mount("/", app=streamable_app)], ) # Add authentication middleware app.add_middleware(MCPAuthMiddleware) # Add CORS middleware for browser clients app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=False, allow_methods=["GET", "POST", "DELETE", "OPTIONS"], allow_headers=["*"], expose_headers=["Mcp-Session-Id"], ) return app