From 35c1bbec4e6bec73a3606174985d0abc433a9b7f Mon Sep 17 00:00:00 2001 From: Pratik Narola Date: Sun, 11 Jan 2026 14:00:16 +0530 Subject: [PATCH] added MCP HTTP endpoint with auth Exposes memory operations as MCP tools over /mcp endpoint: - add_memory, search_memory, remove_memory, chat - API key auth via x-api-key or Authorization header - User isolation enforced via contextvars --- backend/main.py | 37 +++++- backend/mcp_server.py | 240 +++++++++++++++++++++++++++++++++++++++ backend/requirements.txt | 3 + 3 files changed, 276 insertions(+), 4 deletions(-) create mode 100644 backend/mcp_server.py diff --git a/backend/main.py b/backend/main.py index ca8d6ea..3a6d6db 100644 --- a/backend/main.py +++ b/backend/main.py @@ -48,19 +48,36 @@ async def lifespan(app: FastAPI): """Application lifespan manager.""" # Startup logger.info("Starting Mem0 Interface POC") - + # Perform health check on startup health_status = await mem0_manager.health_check() unhealthy_services = [k for k, v in health_status.items() if "unhealthy" in v] - + if unhealthy_services: logger.warning(f"Some services are unhealthy: {unhealthy_services}") else: logger.info("All services are healthy") - + + # Start MCP session manager if available + mcp_context = None + try: + from mcp_server import mcp_lifespan + mcp_context = mcp_lifespan() + await mcp_context.__aenter__() + except ImportError: + logger.warning("MCP server not available") + except Exception as e: + logger.error(f"Failed to start MCP session manager: {e}") + yield - + # Shutdown + if mcp_context: + try: + await mcp_context.__aexit__(None, None, None) + except Exception as e: + logger.error(f"Error stopping MCP session manager: {e}") + logger.info("Shutting down Mem0 Interface POC") @@ -578,6 +595,18 @@ async def get_active_users(): } +# Mount MCP server at /mcp endpoint +try: + from mcp_server import create_mcp_app + mcp_app = create_mcp_app() + app.mount("/mcp", mcp_app) + logger.info("MCP server mounted at /mcp") +except ImportError as e: + logger.warning(f"MCP server not available (missing dependencies): {e}") +except Exception as e: + logger.error(f"Failed to mount MCP server: {e}") + + if __name__ == "__main__": import uvicorn print("Starting UVicorn server...") diff --git a/backend/mcp_server.py b/backend/mcp_server.py new file mode 100644 index 0000000..a7b7ae4 --- /dev/null +++ b/backend/mcp_server.py @@ -0,0 +1,240 @@ +"""MCP Server for Mem0 Memory Service. + +Exposes memory operations as MCP tools over HTTP with API key authentication. +""" + +import contextlib +import logging +from contextvars import ContextVar +from typing import Optional + +from pydantic import Field +from starlette.applications import Starlette +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.cors import CORSMiddleware +from starlette.responses import JSONResponse +from starlette.routing import Mount +from mcp.server.fastmcp import FastMCP + +from config import settings + +logger = logging.getLogger(__name__) + +# Context variable for authenticated user_id +# Set by middleware, read by tools +current_user_id: ContextVar[str] = ContextVar("current_user_id", default="") + + +def get_authenticated_user() -> str: + """Get the authenticated user_id from context. + + Raises: + ValueError: If no authenticated user in context. + """ + user_id = current_user_id.get() + if not user_id: + raise ValueError("No authenticated user in context") + return user_id + + +class MCPAuthMiddleware(BaseHTTPMiddleware): + """Middleware to authenticate MCP requests using API key.""" + + async def dispatch(self, request, call_next): + # Extract API key from headers + api_key = ( + request.headers.get("x-api-key") or + request.headers.get("X-API-Key") or + request.headers.get("Authorization", "").replace("Bearer ", "") + ) + + if not api_key: + return JSONResponse( + {"error": "Missing API key. Provide x-api-key or Authorization header."}, + status_code=401 + ) + + # Map API key to user_id + user_id = settings.api_key_mapping.get(api_key) + if not user_id: + return JSONResponse( + {"error": "Invalid API key"}, + status_code=401 + ) + + # Store user_id in context for tools to access + token = current_user_id.set(user_id) + try: + response = await call_next(request) + return response + finally: + current_user_id.reset(token) + + +# Create FastMCP server +# streamable_http_path="/" since we mount at /mcp in main.py +mcp = FastMCP( + "Mem0 Memory Service", + stateless_http=True, + json_response=True, + streamable_http_path="/" +) + + +@mcp.tool() +async def add_memory( + content: str = Field(description="Content to add to memory"), + agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"), + run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"), + metadata: Optional[dict] = Field(default=None, description="Optional metadata to attach to the memory"), +) -> dict: + """Add content to the user's memory. + + The user_id is automatically determined from the API key authentication. + Use agent_id and run_id for multi-agent or session-based memory organization. + """ + from mem0_manager import mem0_manager + + user_id = get_authenticated_user() + logger.info(f"MCP add_memory: user={user_id}, agent={agent_id}, run={run_id}") + + result = await mem0_manager.add_memories( + messages=[{"role": "user", "content": content}], + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + metadata=metadata + ) + + return result + + +@mcp.tool() +async def search_memory( + query: str = Field(description="Search query to find relevant memories"), + agent_id: Optional[str] = Field(default=None, description="Optional agent identifier to filter memories"), + run_id: Optional[str] = Field(default=None, description="Optional run identifier to filter memories"), + limit: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return"), +) -> dict: + """Search the user's memories. + + Returns memories most relevant to the query. The user_id is automatically + determined from API key authentication. + """ + from mem0_manager import mem0_manager + + user_id = get_authenticated_user() + logger.info(f"MCP search_memory: user={user_id}, query={query[:50]}..., limit={limit}") + + result = await mem0_manager.search_memories( + query=query, + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + limit=limit + ) + + return result + + +@mcp.tool() +async def remove_memory( + memory_id: str = Field(description="The ID of the memory to remove"), +) -> dict: + """Remove a specific memory by its ID. + + Only memories belonging to the authenticated user can be deleted. + Verifies ownership before deletion. + """ + from mem0_manager import mem0_manager + + user_id = get_authenticated_user() + logger.info(f"MCP remove_memory: user={user_id}, memory_id={memory_id}") + + # Verify ownership: get user's memories and check if memory_id exists + user_memories = await mem0_manager.get_user_memories( + user_id=user_id, + limit=10000 # Get all to check ownership + ) + + memory_ids = {m.get("id") for m in user_memories if m.get("id")} + + if memory_id not in memory_ids: + raise ValueError(f"Memory '{memory_id}' not found or access denied") + + result = await mem0_manager.delete_memory(memory_id=memory_id) + + return result + + +@mcp.tool() +async def chat( + message: str = Field(description="The user's message to chat with"), + agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"), + run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"), +) -> str: + """Chat with memory context. + + Retrieves relevant memories based on the message, generates a response + using the configured LLM, and stores the conversation in memory. + The user_id is automatically determined from API key authentication. + """ + from mem0_manager import mem0_manager + + user_id = get_authenticated_user() + logger.info(f"MCP chat: user={user_id}, agent={agent_id}, message={message[:50]}...") + + result = await mem0_manager.chat_with_memory( + message=message, + user_id=user_id, + agent_id=agent_id, + run_id=run_id + ) + + return result.get("response", "") + + +@contextlib.asynccontextmanager +async def mcp_lifespan(): + """Context manager for MCP session lifecycle. + + Must be used in the main FastAPI lifespan since mounted app + lifespans don't run automatically. + """ + async with mcp.session_manager.run(): + logger.info("MCP session manager started") + yield + logger.info("MCP session manager stopped") + + +def create_mcp_app() -> Starlette: + """Create and configure the MCP Starlette application. + + Returns a Starlette app with MCP endpoints, authentication middleware, + and CORS support. + + Note: The MCP session manager must be started via mcp_lifespan() in + the main FastAPI lifespan, not here. + """ + # Get the StreamableHTTP app - it handles requests at "/" since we mount at /mcp + streamable_app = mcp.streamable_http_app() + + # Create Starlette app with MCP routes + app = Starlette( + routes=[Mount("/", app=streamable_app)], + ) + + # Add authentication middleware + app.add_middleware(MCPAuthMiddleware) + + # Add CORS middleware for browser clients + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=False, + allow_methods=["GET", "POST", "DELETE", "OPTIONS"], + allow_headers=["*"], + expose_headers=["Mcp-Session-Id"], + ) + + return app diff --git a/backend/requirements.txt b/backend/requirements.txt index a918526..1ad77f8 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -31,3 +31,6 @@ python-json-logger # CORS and Security python-jose[cryptography] passlib[bcrypt] + +# MCP Server +mcp[server]>=1.0.0