Exposes memory operations as MCP tools over the /mcp endpoint: add_memory, search_memory, remove_memory, and chat. API key authentication is accepted via the x-api-key or Authorization header. User isolation is enforced via contextvars.
240 lines
7.4 KiB
Python
240 lines
7.4 KiB
Python
"""MCP Server for Mem0 Memory Service.
|
|
|
|
Exposes memory operations as MCP tools over HTTP with API key authentication.
|
|
"""
|
|
|
|
import contextlib
|
|
import logging
|
|
from contextvars import ContextVar
|
|
from typing import Optional
|
|
|
|
from pydantic import Field
|
|
from starlette.applications import Starlette
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
from starlette.responses import JSONResponse
|
|
from starlette.routing import Mount
|
|
from mcp.server.fastmcp import FastMCP
|
|
|
|
from config import settings
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Holder for the authenticated user_id of the request being processed.
# The auth middleware sets it and the tool functions read it; the
# empty-string default stands for "no authenticated user".
current_user_id: ContextVar[str] = ContextVar("current_user_id", default="")
|
|
|
|
|
|
def get_authenticated_user() -> str:
    """Return the user_id stored in the current context.

    Returns:
        The non-empty value currently held by ``current_user_id``.

    Raises:
        ValueError: If the context holds no user_id (the value is empty).
    """
    authenticated = current_user_id.get()
    if authenticated:
        return authenticated
    # An empty value (the context variable's default) means nobody is
    # authenticated for this call.
    raise ValueError("No authenticated user in context")
|
|
|
|
|
|
class MCPAuthMiddleware(BaseHTTPMiddleware):
    """Middleware that authenticates MCP requests using an API key.

    The key is taken from the ``x-api-key`` header or, failing that, from the
    ``Authorization`` header. It is looked up in ``settings.api_key_mapping``;
    on success the mapped user_id is stored in ``current_user_id`` while the
    downstream application handles the request.
    """

    @staticmethod
    def _extract_api_key(headers) -> str:
        """Return the API key carried by *headers*, or "" when there is none.

        Args:
            headers: The request's header mapping (Starlette header lookup is
                case-insensitive, so a single ``x-api-key`` lookup also covers
                ``X-API-Key``).

        Returns:
            The ``x-api-key`` value if present; otherwise the ``Authorization``
            value with a leading ``Bearer`` scheme removed. A bare key sent in
            ``Authorization`` without a scheme is returned as-is.
        """
        api_key = headers.get("x-api-key")
        if api_key:
            return api_key

        authorization = headers.get("authorization", "").strip()
        # Parse the scheme as a *prefix* and compare it case-insensitively:
        # a plain str.replace("Bearer ", "") rejected "bearer <key>" and would
        # also strip the text "Bearer " from the middle of a value.
        scheme, _, credentials = authorization.partition(" ")
        credentials = credentials.strip()
        if credentials and scheme.lower() == "bearer":
            return credentials
        return authorization

    async def dispatch(self, request, call_next):
        """Authenticate *request* and forward it with the user_id in context.

        Args:
            request: The incoming request.
            call_next: Callable that passes the request to the next handler.

        Returns:
            A 401 ``JSONResponse`` when the API key is missing or unknown,
            otherwise the response produced by ``call_next``.
        """
        api_key = self._extract_api_key(request.headers)

        if not api_key:
            return JSONResponse(
                {"error": "Missing API key. Provide x-api-key or Authorization header."},
                status_code=401,
            )

        # Map API key to user_id
        user_id = settings.api_key_mapping.get(api_key)
        if not user_id:
            return JSONResponse(
                {"error": "Invalid API key"},
                status_code=401,
            )

        # Store user_id in context for tools to access, and always restore
        # the previous value once the request has been handled.
        token = current_user_id.set(user_id)
        try:
            return await call_next(request)
        finally:
            current_user_id.reset(token)
|
|
|
|
|
|
# FastMCP server instance on which the tool functions in this module are
# registered. streamable_http_path is "/" because the resulting app is
# mounted at /mcp in main.py.
mcp = FastMCP(
    "Mem0 Memory Service",
    streamable_http_path="/",
    json_response=True,
    stateless_http=True,
)
|
|
|
|
|
|
# NOTE(review): FastMCP presumably publishes the function docstring as the
# tool description, so the docstring wording is kept unchanged.
@mcp.tool()
async def add_memory(
    content: str = Field(description="Content to add to memory"),
    agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
    run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
    metadata: Optional[dict] = Field(default=None, description="Optional metadata to attach to the memory"),
) -> dict:
    """Add content to the user's memory.

    The user_id is automatically determined from the API key authentication.
    Use agent_id and run_id for multi-agent or session-based memory organization.
    """
    # Imported at call time rather than at module import time.
    from mem0_manager import mem0_manager

    user_id = get_authenticated_user()
    logger.info(f"MCP add_memory: user={user_id}, agent={agent_id}, run={run_id}")

    # The content is submitted as a single message with the "user" role;
    # whatever the manager returns is handed back unchanged.
    user_message = {"role": "user", "content": content}
    return await mem0_manager.add_memories(
        messages=[user_message],
        user_id=user_id,
        agent_id=agent_id,
        run_id=run_id,
        metadata=metadata,
    )
|
|
|
|
|
|
# NOTE(review): FastMCP presumably publishes the function docstring as the
# tool description, so the docstring wording is kept unchanged.
@mcp.tool()
async def search_memory(
    query: str = Field(description="Search query to find relevant memories"),
    agent_id: Optional[str] = Field(default=None, description="Optional agent identifier to filter memories"),
    run_id: Optional[str] = Field(default=None, description="Optional run identifier to filter memories"),
    limit: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return"),
) -> dict:
    """Search the user's memories.

    Returns memories most relevant to the query. The user_id is automatically
    determined from API key authentication.
    """
    # Imported at call time rather than at module import time.
    from mem0_manager import mem0_manager

    user_id = get_authenticated_user()
    # Only the first 50 characters of the query are written to the log.
    logger.info(f"MCP search_memory: user={user_id}, query={query[:50]}..., limit={limit}")

    # The search is always scoped to the authenticated user; the manager's
    # return value is handed back unchanged.
    return await mem0_manager.search_memories(
        query=query,
        user_id=user_id,
        agent_id=agent_id,
        run_id=run_id,
        limit=limit,
    )
|
|
|
|
|
|
# NOTE(review): FastMCP presumably publishes the function docstring as the
# tool description, so the docstring wording is kept unchanged.
@mcp.tool()
async def remove_memory(
    memory_id: str = Field(description="The ID of the memory to remove"),
) -> dict:
    """Remove a specific memory by its ID.

    Only memories belonging to the authenticated user can be deleted.
    Verifies ownership before deletion.
    """
    # Imported at call time rather than at module import time.
    from mem0_manager import mem0_manager

    user_id = get_authenticated_user()
    logger.info(f"MCP remove_memory: user={user_id}, memory_id={memory_id}")

    # Ownership check: fetch the user's memories and collect their ids.
    # NOTE(review): at most 10000 memories are requested, so a user holding
    # more than that could presumably be denied deletion of a memory they
    # own - confirm against get_user_memories.
    fetched = await mem0_manager.get_user_memories(user_id=user_id, limit=10000)

    owned_ids = set()
    for entry in fetched:
        entry_id = entry.get("id")
        if entry_id:
            owned_ids.add(entry_id)

    # The same error covers "does not exist" and "belongs to someone else".
    if memory_id not in owned_ids:
        raise ValueError(f"Memory '{memory_id}' not found or access denied")

    return await mem0_manager.delete_memory(memory_id=memory_id)
|
|
|
|
|
|
# NOTE(review): FastMCP presumably publishes the function docstring as the
# tool description, so the docstring wording is kept unchanged.
@mcp.tool()
async def chat(
    message: str = Field(description="The user's message to chat with"),
    agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
    run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
) -> str:
    """Chat with memory context.

    Retrieves relevant memories based on the message, generates a response
    using the configured LLM, and stores the conversation in memory.
    The user_id is automatically determined from API key authentication.
    """
    # Imported at call time rather than at module import time.
    from mem0_manager import mem0_manager

    user_id = get_authenticated_user()
    # Only the first 50 characters of the message are written to the log.
    logger.info(f"MCP chat: user={user_id}, agent={agent_id}, message={message[:50]}...")

    outcome = await mem0_manager.chat_with_memory(
        message=message,
        user_id=user_id,
        agent_id=agent_id,
        run_id=run_id,
    )

    # Only the "response" entry is returned; an empty string is used when the
    # manager's result has no such key.
    return outcome.get("response", "")
|
|
|
|
|
|
@contextlib.asynccontextmanager
async def mcp_lifespan():
    """Run the MCP session manager for the duration of the ``async with`` body.

    Must be used in the main FastAPI lifespan since mounted app
    lifespans don't run automatically.

    The "stopped" message is logged only when the body finishes without
    raising; the session manager context itself is exited either way.
    """
    session_run = mcp.session_manager.run()
    async with session_run:
        logger.info("MCP session manager started")
        yield
        logger.info("MCP session manager stopped")
|
|
|
|
|
|
def create_mcp_app() -> Starlette:
    """Build the Starlette application that serves the MCP endpoints.

    The returned app mounts the FastMCP streamable HTTP app at "/", then
    registers the API key authentication middleware followed by the CORS
    middleware.

    Note: The MCP session manager must be started via mcp_lifespan() in
    the main FastAPI lifespan, not here.

    Returns:
        The configured Starlette application.
    """
    # The StreamableHTTP app handles requests at "/" since this app is
    # mounted at /mcp.
    app = Starlette(routes=[Mount("/", app=mcp.streamable_http_app())])

    # Authentication first, then CORS (for browser clients); the registration
    # order is kept exactly as is.
    app.add_middleware(MCPAuthMiddleware)

    cors_options = {
        "allow_origins": ["*"],
        "allow_credentials": False,
        "allow_methods": ["GET", "POST", "DELETE", "OPTIONS"],
        "allow_headers": ["*"],
        "expose_headers": ["Mcp-Session-Id"],
    }
    app.add_middleware(CORSMiddleware, **cors_options)

    return app
|