added MCP HTTP endpoint with auth

Exposes memory operations as MCP tools over /mcp endpoint:
- add_memory, search_memory, remove_memory, chat
- API key auth via x-api-key or Authorization header
- User isolation enforced via contextvars
This commit is contained in:
Pratik Narola 2026-01-11 14:00:16 +05:30
parent 997865283f
commit 35c1bbec4e
3 changed files with 276 additions and 4 deletions

View file

@ -58,9 +58,26 @@ async def lifespan(app: FastAPI):
else: else:
logger.info("All services are healthy") logger.info("All services are healthy")
# Start MCP session manager if available
mcp_context = None
try:
from mcp_server import mcp_lifespan
mcp_context = mcp_lifespan()
await mcp_context.__aenter__()
except ImportError:
logger.warning("MCP server not available")
except Exception as e:
logger.error(f"Failed to start MCP session manager: {e}")
yield yield
# Shutdown # Shutdown
if mcp_context:
try:
await mcp_context.__aexit__(None, None, None)
except Exception as e:
logger.error(f"Error stopping MCP session manager: {e}")
logger.info("Shutting down Mem0 Interface POC") logger.info("Shutting down Mem0 Interface POC")
@ -578,6 +595,18 @@ async def get_active_users():
} }
# Mount MCP server at /mcp endpoint
try:
from mcp_server import create_mcp_app
mcp_app = create_mcp_app()
app.mount("/mcp", mcp_app)
logger.info("MCP server mounted at /mcp")
except ImportError as e:
logger.warning(f"MCP server not available (missing dependencies): {e}")
except Exception as e:
logger.error(f"Failed to mount MCP server: {e}")
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
print("Starting UVicorn server...") print("Starting UVicorn server...")

240
backend/mcp_server.py Normal file
View file

@ -0,0 +1,240 @@
"""MCP Server for Mem0 Memory Service.
Exposes memory operations as MCP tools over HTTP with API key authentication.
"""
import contextlib
import logging
from contextvars import ContextVar
from typing import Optional
from pydantic import Field
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse
from starlette.routing import Mount
from mcp.server.fastmcp import FastMCP
from config import settings
logger = logging.getLogger(__name__)
# Context variable for authenticated user_id
# Set by middleware, read by tools
current_user_id: ContextVar[str] = ContextVar("current_user_id", default="")
def get_authenticated_user() -> str:
"""Get the authenticated user_id from context.
Raises:
ValueError: If no authenticated user in context.
"""
user_id = current_user_id.get()
if not user_id:
raise ValueError("No authenticated user in context")
return user_id
class MCPAuthMiddleware(BaseHTTPMiddleware):
"""Middleware to authenticate MCP requests using API key."""
async def dispatch(self, request, call_next):
# Extract API key from headers
api_key = (
request.headers.get("x-api-key") or
request.headers.get("X-API-Key") or
request.headers.get("Authorization", "").replace("Bearer ", "")
)
if not api_key:
return JSONResponse(
{"error": "Missing API key. Provide x-api-key or Authorization header."},
status_code=401
)
# Map API key to user_id
user_id = settings.api_key_mapping.get(api_key)
if not user_id:
return JSONResponse(
{"error": "Invalid API key"},
status_code=401
)
# Store user_id in context for tools to access
token = current_user_id.set(user_id)
try:
response = await call_next(request)
return response
finally:
current_user_id.reset(token)
# Create FastMCP server
# streamable_http_path="/" since we mount at /mcp in main.py
mcp = FastMCP(
"Mem0 Memory Service",
stateless_http=True,
json_response=True,
streamable_http_path="/"
)
@mcp.tool()
async def add_memory(
content: str = Field(description="Content to add to memory"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
metadata: Optional[dict] = Field(default=None, description="Optional metadata to attach to the memory"),
) -> dict:
"""Add content to the user's memory.
The user_id is automatically determined from the API key authentication.
Use agent_id and run_id for multi-agent or session-based memory organization.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP add_memory: user={user_id}, agent={agent_id}, run={run_id}")
result = await mem0_manager.add_memories(
messages=[{"role": "user", "content": content}],
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
metadata=metadata
)
return result
@mcp.tool()
async def search_memory(
query: str = Field(description="Search query to find relevant memories"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier to filter memories"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier to filter memories"),
limit: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return"),
) -> dict:
"""Search the user's memories.
Returns memories most relevant to the query. The user_id is automatically
determined from API key authentication.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP search_memory: user={user_id}, query={query[:50]}..., limit={limit}")
result = await mem0_manager.search_memories(
query=query,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=limit
)
return result
@mcp.tool()
async def remove_memory(
memory_id: str = Field(description="The ID of the memory to remove"),
) -> dict:
"""Remove a specific memory by its ID.
Only memories belonging to the authenticated user can be deleted.
Verifies ownership before deletion.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP remove_memory: user={user_id}, memory_id={memory_id}")
# Verify ownership: get user's memories and check if memory_id exists
user_memories = await mem0_manager.get_user_memories(
user_id=user_id,
limit=10000 # Get all to check ownership
)
memory_ids = {m.get("id") for m in user_memories if m.get("id")}
if memory_id not in memory_ids:
raise ValueError(f"Memory '{memory_id}' not found or access denied")
result = await mem0_manager.delete_memory(memory_id=memory_id)
return result
@mcp.tool()
async def chat(
message: str = Field(description="The user's message to chat with"),
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
) -> str:
"""Chat with memory context.
Retrieves relevant memories based on the message, generates a response
using the configured LLM, and stores the conversation in memory.
The user_id is automatically determined from API key authentication.
"""
from mem0_manager import mem0_manager
user_id = get_authenticated_user()
logger.info(f"MCP chat: user={user_id}, agent={agent_id}, message={message[:50]}...")
result = await mem0_manager.chat_with_memory(
message=message,
user_id=user_id,
agent_id=agent_id,
run_id=run_id
)
return result.get("response", "")
@contextlib.asynccontextmanager
async def mcp_lifespan():
"""Context manager for MCP session lifecycle.
Must be used in the main FastAPI lifespan since mounted app
lifespans don't run automatically.
"""
async with mcp.session_manager.run():
logger.info("MCP session manager started")
yield
logger.info("MCP session manager stopped")
def create_mcp_app() -> Starlette:
"""Create and configure the MCP Starlette application.
Returns a Starlette app with MCP endpoints, authentication middleware,
and CORS support.
Note: The MCP session manager must be started via mcp_lifespan() in
the main FastAPI lifespan, not here.
"""
# Get the StreamableHTTP app - it handles requests at "/" since we mount at /mcp
streamable_app = mcp.streamable_http_app()
# Create Starlette app with MCP routes
app = Starlette(
routes=[Mount("/", app=streamable_app)],
)
# Add authentication middleware
app.add_middleware(MCPAuthMiddleware)
# Add CORS middleware for browser clients
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False,
allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
allow_headers=["*"],
expose_headers=["Mcp-Session-Id"],
)
return app

View file

@ -31,3 +31,6 @@ python-json-logger
# CORS and Security # CORS and Security
python-jose[cryptography] python-jose[cryptography]
passlib[bcrypt] passlib[bcrypt]
# MCP Server
mcp[server]>=1.0.0