added MCP HTTP endpoint with auth
Exposes memory operations as MCP tools over the /mcp endpoint: the tools add_memory, search_memory, remove_memory and chat; API key auth via the x-api-key or Authorization header; user isolation enforced via contextvars.
This commit is contained in:
parent
997865283f
commit
35c1bbec4e
3 changed files with 276 additions and 4 deletions
|
|
@ -58,9 +58,26 @@ async def lifespan(app: FastAPI):
|
||||||
else:
|
else:
|
||||||
logger.info("All services are healthy")
|
logger.info("All services are healthy")
|
||||||
|
|
||||||
|
# Start MCP session manager if available
|
||||||
|
mcp_context = None
|
||||||
|
try:
|
||||||
|
from mcp_server import mcp_lifespan
|
||||||
|
mcp_context = mcp_lifespan()
|
||||||
|
await mcp_context.__aenter__()
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("MCP server not available")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start MCP session manager: {e}")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown
|
# Shutdown
|
||||||
|
if mcp_context:
|
||||||
|
try:
|
||||||
|
await mcp_context.__aexit__(None, None, None)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error stopping MCP session manager: {e}")
|
||||||
|
|
||||||
logger.info("Shutting down Mem0 Interface POC")
|
logger.info("Shutting down Mem0 Interface POC")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -578,6 +595,18 @@ async def get_active_users():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Mount MCP server at /mcp endpoint.
# This is best-effort: when mcp_server cannot be imported, or building or
# mounting the sub-application fails, the problem is logged and the main app
# keeps running without the /mcp mount instead of raising.
try:
    from mcp_server import create_mcp_app
    mcp_app = create_mcp_app()
    app.mount("/mcp", mcp_app)
    logger.info("MCP server mounted at /mcp")
except ImportError as e:
    # Raised by the import above when mcp_server or one of its imports is missing.
    logger.warning(f"MCP server not available (missing dependencies): {e}")
except Exception as e:
    # Any other failure while creating or mounting the sub-application.
    logger.error(f"Failed to mount MCP server: {e}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
print("Starting UVicorn server...")
|
print("Starting UVicorn server...")
|
||||||
|
|
|
||||||
240
backend/mcp_server.py
Normal file
240
backend/mcp_server.py
Normal file
|
|
@ -0,0 +1,240 @@
|
||||||
|
"""MCP Server for Mem0 Memory Service.
|
||||||
|
|
||||||
|
Exposes memory operations as MCP tools over HTTP with API key authentication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.routing import Mount
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Holder for the authenticated user_id of the request being processed.
# MCPAuthMiddleware writes it and the MCP tools read it back through
# get_authenticated_user(). The empty-string default means "no user".
current_user_id: ContextVar[str] = ContextVar("current_user_id", default="")


def get_authenticated_user() -> str:
    """Return the user_id stored in the current context.

    Returns:
        The non-empty user_id held by ``current_user_id``.

    Raises:
        ValueError: If the context holds no (or an empty) user_id.
    """
    authenticated = current_user_id.get()
    if authenticated:
        return authenticated
    raise ValueError("No authenticated user in context")
|
||||||
|
|
||||||
|
|
||||||
|
class MCPAuthMiddleware(BaseHTTPMiddleware):
    """Authenticate MCP requests with an API key and bind the caller's user_id.

    The key is read from the ``x-api-key`` header or, failing that, from the
    ``Authorization`` header (with or without a ``Bearer`` scheme prefix).
    It is resolved to a user_id through ``settings.api_key_mapping``; the
    user_id is stored in ``current_user_id`` while the request is handled so
    that the MCP tools can read it.
    """

    # Compared case-insensitively: HTTP auth scheme names are not case-sensitive.
    _BEARER_PREFIX = "bearer "

    @classmethod
    def _extract_api_key(cls, headers) -> str:
        """Return the API key carried by *headers*, or "" when there is none.

        Args:
            headers: The request headers (anything exposing ``.get(name)``).
        """
        # Starlette header lookups are case-insensitive, so this single
        # lookup also matches "X-API-Key".
        api_key = headers.get("x-api-key")
        if api_key:
            return api_key

        authorization = (headers.get("Authorization") or "").strip()
        prefix_len = len(cls._BEARER_PREFIX)
        # Only strip the scheme when it really is a leading prefix; a blanket
        # str.replace() would also delete "Bearer " from inside the key itself
        # and would miss lower-case "bearer".
        if authorization[:prefix_len].lower() == cls._BEARER_PREFIX:
            return authorization[prefix_len:].strip()
        # A bare key in the Authorization header is still accepted.
        return authorization

    async def dispatch(self, request, call_next):
        """Reject unauthenticated requests; otherwise run the request as the mapped user.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request downstream.

        Returns:
            A 401 JSON response when the key is missing or unknown, otherwise
            the downstream response.
        """
        api_key = self._extract_api_key(request.headers)
        if not api_key:
            return JSONResponse(
                {"error": "Missing API key. Provide x-api-key or Authorization header."},
                status_code=401,
            )

        # Map API key to user_id
        user_id = settings.api_key_mapping.get(api_key)
        if not user_id:
            return JSONResponse(
                {"error": "Invalid API key"},
                status_code=401,
            )

        # Store user_id in context for tools to access
        token = current_user_id.set(user_id)
        try:
            return await call_next(request)
        finally:
            # Restore the previous value even when the downstream app raises.
            current_user_id.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
# Create FastMCP server
# streamable_http_path="/" since we mount at /mcp in main.py
# NOTE(review): stateless_http / json_response are presumably the FastMCP
# switches for session-less operation and plain JSON (non-streaming)
# responses — confirm against the MCP Python SDK docs for the pinned version.
mcp = FastMCP(
    "Mem0 Memory Service",
    stateless_http=True,
    json_response=True,
    streamable_http_path="/"
)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def add_memory(
    content: str = Field(description="Content to add to memory"),
    agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
    run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
    metadata: Optional[dict] = Field(default=None, description="Optional metadata to attach to the memory"),
) -> dict:
    """Add content to the user's memory.

    The user_id is automatically determined from the API key authentication.
    Use agent_id and run_id for multi-agent or session-based memory organization.
    """
    # NOTE: the docstring above is presumably what FastMCP publishes as the
    # tool description, so it is written for clients — keep it that way.
    # The manager is imported at call time rather than at module import time.
    from mem0_manager import mem0_manager

    user_id = get_authenticated_user()
    logger.info(f"MCP add_memory: user={user_id}, agent={agent_id}, run={run_id}")

    # The content is submitted as a single user-role message.
    user_message = {"role": "user", "content": content}
    scope = {"user_id": user_id, "agent_id": agent_id, "run_id": run_id}
    return await mem0_manager.add_memories(
        messages=[user_message],
        metadata=metadata,
        **scope,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def search_memory(
    query: str = Field(description="Search query to find relevant memories"),
    agent_id: Optional[str] = Field(default=None, description="Optional agent identifier to filter memories"),
    run_id: Optional[str] = Field(default=None, description="Optional run identifier to filter memories"),
    limit: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return"),
) -> dict:
    """Search the user's memories.

    Returns memories most relevant to the query. The user_id is automatically
    determined from API key authentication.
    """
    # NOTE: the docstring above is presumably what FastMCP publishes as the
    # tool description, so it is written for clients — keep it that way.
    # The manager is imported at call time rather than at module import time.
    from mem0_manager import mem0_manager

    user_id = get_authenticated_user()
    # Only the first 50 characters of the query are written to the log.
    logger.info(f"MCP search_memory: user={user_id}, query={query[:50]}..., limit={limit}")

    scope = {"user_id": user_id, "agent_id": agent_id, "run_id": run_id}
    return await mem0_manager.search_memories(query=query, limit=limit, **scope)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def remove_memory(
    memory_id: str = Field(description="The ID of the memory to remove"),
) -> dict:
    """Remove a specific memory by its ID.

    Only memories belonging to the authenticated user can be deleted.
    Verifies ownership before deletion.
    """
    # NOTE: the docstring above is presumably what FastMCP publishes as the
    # tool description, so it is written for clients — keep it that way.
    # The manager is imported at call time rather than at module import time.
    from mem0_manager import mem0_manager

    user_id = get_authenticated_user()
    logger.info(f"MCP remove_memory: user={user_id}, memory_id={memory_id}")

    # Ownership check: list the caller's memories and look for memory_id.
    # NOTE(review): only the first 10000 memories are requested, so a memory
    # outside that window would presumably be reported as not found —
    # confirm against mem0_manager.get_user_memories.
    user_memories = await mem0_manager.get_user_memories(
        user_id=user_id,
        limit=10000,
    )

    owned_ids = set()
    for entry in user_memories:
        entry_id = entry.get("id")
        if entry_id:  # entries without a usable id are ignored
            owned_ids.add(entry_id)

    if memory_id in owned_ids:
        return await mem0_manager.delete_memory(memory_id=memory_id)

    # One message for both "unknown id" and "someone else's memory".
    raise ValueError(f"Memory '{memory_id}' not found or access denied")
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def chat(
    message: str = Field(description="The user's message to chat with"),
    agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
    run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
) -> str:
    """Chat with memory context.

    Retrieves relevant memories based on the message, generates a response
    using the configured LLM, and stores the conversation in memory.
    The user_id is automatically determined from API key authentication.
    """
    # NOTE: the docstring above is presumably what FastMCP publishes as the
    # tool description, so it is written for clients — keep it that way.
    # The manager is imported at call time rather than at module import time.
    from mem0_manager import mem0_manager

    user_id = get_authenticated_user()
    # Only the first 50 characters of the message are written to the log.
    logger.info(f"MCP chat: user={user_id}, agent={agent_id}, message={message[:50]}...")

    scope = {"user_id": user_id, "agent_id": agent_id, "run_id": run_id}
    outcome = await mem0_manager.chat_with_memory(message=message, **scope)

    # Hand back only the reply text; an absent "response" key yields "".
    return outcome.get("response", "")
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
async def mcp_lifespan():
    """Context manager for MCP session lifecycle.

    Runs ``mcp.session_manager.run()`` for as long as the context is held.
    Must be used in the main FastAPI lifespan since mounted app
    lifespans don't run automatically.

    Yields:
        None. The session manager is running while the caller holds the
        context.
    """
    async with mcp.session_manager.run():
        logger.info("MCP session manager started")
        yield
    # Logged only after run() has exited cleanly, so the message is true
    # when it appears (it used to be emitted while still inside the context).
    logger.info("MCP session manager stopped")
|
||||||
|
|
||||||
|
|
||||||
|
def create_mcp_app() -> Starlette:
    """Build the Starlette sub-application that serves the MCP endpoints.

    The returned app wraps FastMCP's streamable-HTTP app with API key
    authentication and CORS handling.

    Note: this function does not start the MCP session manager. That has to
    happen through mcp_lifespan() inside the main FastAPI lifespan.

    Returns:
        The configured Starlette application.
    """
    # The streamable-HTTP app answers at "/" because the whole sub-app is
    # mounted under /mcp by the caller.
    mcp_routes = [Mount("/", app=mcp.streamable_http_app())]
    wrapper = Starlette(routes=mcp_routes)

    # Authentication first, CORS second.
    # NOTE(review): Starlette's add_middleware is understood to place the most
    # recently added middleware outermost, which would let CORS answer
    # preflight requests before the API key check — confirm before reordering.
    wrapper.add_middleware(MCPAuthMiddleware)

    # CORS settings for browser clients.
    cors_settings = dict(
        allow_origins=["*"],
        allow_credentials=False,
        allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
        allow_headers=["*"],
        expose_headers=["Mcp-Session-Id"],
    )
    wrapper.add_middleware(CORSMiddleware, **cors_settings)

    return wrapper
|
||||||
|
|
@ -31,3 +31,6 @@ python-json-logger
|
||||||
# CORS and Security
|
# CORS and Security
|
||||||
python-jose[cryptography]
|
python-jose[cryptography]
|
||||||
passlib[bcrypt]
|
passlib[bcrypt]
|
||||||
|
|
||||||
|
# MCP Server
# mcp_server.py uses the streamable HTTP transport (stateless_http,
# json_response, streamable_http_path), which needs mcp >= 1.8.0.
# The package defines no "server" extra, so none is requested.
mcp>=1.8.0
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue