added MCP HTTP endpoint with auth
Exposes memory operations as MCP tools over /mcp endpoint: - add_memory, search_memory, remove_memory, chat - API key auth via x-api-key or Authorization header - User isolation enforced via contextvars
This commit is contained in:
parent
997865283f
commit
35c1bbec4e
3 changed files with 276 additions and 4 deletions
|
|
@ -48,19 +48,36 @@ async def lifespan(app: FastAPI):
|
|||
"""Application lifespan manager."""
|
||||
# Startup
|
||||
logger.info("Starting Mem0 Interface POC")
|
||||
|
||||
|
||||
# Perform health check on startup
|
||||
health_status = await mem0_manager.health_check()
|
||||
unhealthy_services = [k for k, v in health_status.items() if "unhealthy" in v]
|
||||
|
||||
|
||||
if unhealthy_services:
|
||||
logger.warning(f"Some services are unhealthy: {unhealthy_services}")
|
||||
else:
|
||||
logger.info("All services are healthy")
|
||||
|
||||
|
||||
# Start MCP session manager if available
|
||||
mcp_context = None
|
||||
try:
|
||||
from mcp_server import mcp_lifespan
|
||||
mcp_context = mcp_lifespan()
|
||||
await mcp_context.__aenter__()
|
||||
except ImportError:
|
||||
logger.warning("MCP server not available")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP session manager: {e}")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
# Shutdown
|
||||
if mcp_context:
|
||||
try:
|
||||
await mcp_context.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping MCP session manager: {e}")
|
||||
|
||||
logger.info("Shutting down Mem0 Interface POC")
|
||||
|
||||
|
||||
|
|
@ -578,6 +595,18 @@ async def get_active_users():
|
|||
}
|
||||
|
||||
|
||||
# Mount MCP server at /mcp endpoint
|
||||
try:
|
||||
from mcp_server import create_mcp_app
|
||||
mcp_app = create_mcp_app()
|
||||
app.mount("/mcp", mcp_app)
|
||||
logger.info("MCP server mounted at /mcp")
|
||||
except ImportError as e:
|
||||
logger.warning(f"MCP server not available (missing dependencies): {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to mount MCP server: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
print("Starting UVicorn server...")
|
||||
|
|
|
|||
240
backend/mcp_server.py
Normal file
240
backend/mcp_server.py
Normal file
|
|
@ -0,0 +1,240 @@
|
|||
"""MCP Server for Mem0 Memory Service.
|
||||
|
||||
Exposes memory operations as MCP tools over HTTP with API key authentication.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Mount
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variable for authenticated user_id
|
||||
# Set by middleware, read by tools
|
||||
current_user_id: ContextVar[str] = ContextVar("current_user_id", default="")
|
||||
|
||||
|
||||
def get_authenticated_user() -> str:
|
||||
"""Get the authenticated user_id from context.
|
||||
|
||||
Raises:
|
||||
ValueError: If no authenticated user in context.
|
||||
"""
|
||||
user_id = current_user_id.get()
|
||||
if not user_id:
|
||||
raise ValueError("No authenticated user in context")
|
||||
return user_id
|
||||
|
||||
|
||||
class MCPAuthMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to authenticate MCP requests using API key."""
|
||||
|
||||
async def dispatch(self, request, call_next):
|
||||
# Extract API key from headers
|
||||
api_key = (
|
||||
request.headers.get("x-api-key") or
|
||||
request.headers.get("X-API-Key") or
|
||||
request.headers.get("Authorization", "").replace("Bearer ", "")
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
return JSONResponse(
|
||||
{"error": "Missing API key. Provide x-api-key or Authorization header."},
|
||||
status_code=401
|
||||
)
|
||||
|
||||
# Map API key to user_id
|
||||
user_id = settings.api_key_mapping.get(api_key)
|
||||
if not user_id:
|
||||
return JSONResponse(
|
||||
{"error": "Invalid API key"},
|
||||
status_code=401
|
||||
)
|
||||
|
||||
# Store user_id in context for tools to access
|
||||
token = current_user_id.set(user_id)
|
||||
try:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
finally:
|
||||
current_user_id.reset(token)
|
||||
|
||||
|
||||
# Create FastMCP server
|
||||
# streamable_http_path="/" since we mount at /mcp in main.py
|
||||
mcp = FastMCP(
|
||||
"Mem0 Memory Service",
|
||||
stateless_http=True,
|
||||
json_response=True,
|
||||
streamable_http_path="/"
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def add_memory(
|
||||
content: str = Field(description="Content to add to memory"),
|
||||
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
|
||||
run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
|
||||
metadata: Optional[dict] = Field(default=None, description="Optional metadata to attach to the memory"),
|
||||
) -> dict:
|
||||
"""Add content to the user's memory.
|
||||
|
||||
The user_id is automatically determined from the API key authentication.
|
||||
Use agent_id and run_id for multi-agent or session-based memory organization.
|
||||
"""
|
||||
from mem0_manager import mem0_manager
|
||||
|
||||
user_id = get_authenticated_user()
|
||||
logger.info(f"MCP add_memory: user={user_id}, agent={agent_id}, run={run_id}")
|
||||
|
||||
result = await mem0_manager.add_memories(
|
||||
messages=[{"role": "user", "content": content}],
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
run_id=run_id,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def search_memory(
|
||||
query: str = Field(description="Search query to find relevant memories"),
|
||||
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier to filter memories"),
|
||||
run_id: Optional[str] = Field(default=None, description="Optional run identifier to filter memories"),
|
||||
limit: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return"),
|
||||
) -> dict:
|
||||
"""Search the user's memories.
|
||||
|
||||
Returns memories most relevant to the query. The user_id is automatically
|
||||
determined from API key authentication.
|
||||
"""
|
||||
from mem0_manager import mem0_manager
|
||||
|
||||
user_id = get_authenticated_user()
|
||||
logger.info(f"MCP search_memory: user={user_id}, query={query[:50]}..., limit={limit}")
|
||||
|
||||
result = await mem0_manager.search_memories(
|
||||
query=query,
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
run_id=run_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def remove_memory(
|
||||
memory_id: str = Field(description="The ID of the memory to remove"),
|
||||
) -> dict:
|
||||
"""Remove a specific memory by its ID.
|
||||
|
||||
Only memories belonging to the authenticated user can be deleted.
|
||||
Verifies ownership before deletion.
|
||||
"""
|
||||
from mem0_manager import mem0_manager
|
||||
|
||||
user_id = get_authenticated_user()
|
||||
logger.info(f"MCP remove_memory: user={user_id}, memory_id={memory_id}")
|
||||
|
||||
# Verify ownership: get user's memories and check if memory_id exists
|
||||
user_memories = await mem0_manager.get_user_memories(
|
||||
user_id=user_id,
|
||||
limit=10000 # Get all to check ownership
|
||||
)
|
||||
|
||||
memory_ids = {m.get("id") for m in user_memories if m.get("id")}
|
||||
|
||||
if memory_id not in memory_ids:
|
||||
raise ValueError(f"Memory '{memory_id}' not found or access denied")
|
||||
|
||||
result = await mem0_manager.delete_memory(memory_id=memory_id)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def chat(
|
||||
message: str = Field(description="The user's message to chat with"),
|
||||
agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
|
||||
run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
|
||||
) -> str:
|
||||
"""Chat with memory context.
|
||||
|
||||
Retrieves relevant memories based on the message, generates a response
|
||||
using the configured LLM, and stores the conversation in memory.
|
||||
The user_id is automatically determined from API key authentication.
|
||||
"""
|
||||
from mem0_manager import mem0_manager
|
||||
|
||||
user_id = get_authenticated_user()
|
||||
logger.info(f"MCP chat: user={user_id}, agent={agent_id}, message={message[:50]}...")
|
||||
|
||||
result = await mem0_manager.chat_with_memory(
|
||||
message=message,
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
run_id=run_id
|
||||
)
|
||||
|
||||
return result.get("response", "")
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def mcp_lifespan():
|
||||
"""Context manager for MCP session lifecycle.
|
||||
|
||||
Must be used in the main FastAPI lifespan since mounted app
|
||||
lifespans don't run automatically.
|
||||
"""
|
||||
async with mcp.session_manager.run():
|
||||
logger.info("MCP session manager started")
|
||||
yield
|
||||
logger.info("MCP session manager stopped")
|
||||
|
||||
|
||||
def create_mcp_app() -> Starlette:
|
||||
"""Create and configure the MCP Starlette application.
|
||||
|
||||
Returns a Starlette app with MCP endpoints, authentication middleware,
|
||||
and CORS support.
|
||||
|
||||
Note: The MCP session manager must be started via mcp_lifespan() in
|
||||
the main FastAPI lifespan, not here.
|
||||
"""
|
||||
# Get the StreamableHTTP app - it handles requests at "/" since we mount at /mcp
|
||||
streamable_app = mcp.streamable_http_app()
|
||||
|
||||
# Create Starlette app with MCP routes
|
||||
app = Starlette(
|
||||
routes=[Mount("/", app=streamable_app)],
|
||||
)
|
||||
|
||||
# Add authentication middleware
|
||||
app.add_middleware(MCPAuthMiddleware)
|
||||
|
||||
# Add CORS middleware for browser clients
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=False,
|
||||
allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["Mcp-Session-Id"],
|
||||
)
|
||||
|
||||
return app
|
||||
|
|
@ -31,3 +31,6 @@ python-json-logger
|
|||
# CORS and Security
|
||||
python-jose[cryptography]
|
||||
passlib[bcrypt]
|
||||
|
||||
# MCP Server
|
||||
mcp[server]>=1.0.0
|
||||
|
|
|
|||
Loading…
Reference in a new issue