def _extract_platform_api_key(
    authorization: Optional[str], x_api_key: Optional[str]
) -> Optional[str]:
    """Pick the API key out of the raw header values.

    Args:
        authorization: Raw ``Authorization`` header value, if any.
        x_api_key: Raw ``X-API-Key`` header value, if any.

    Returns:
        The stripped key, or ``None`` when neither header carries one.
    """
    if authorization:
        value = authorization.strip()
        scheme, _, rest = value.partition(" ")
        if scheme.lower() in ("token", "bearer"):
            # Recognised scheme: the credential is whatever follows it. A lone
            # "Token"/"Bearer" yields "" here, so the scheme word is never
            # mistaken for the key and X-API-Key still gets a chance below.
            candidate = rest.strip()
        else:
            # Bare value with no recognised scheme -- treat the whole header
            # as the key.
            candidate = value
        if candidate:
            return candidate

    if x_api_key:
        return x_api_key.strip() or None
    return None


async def get_current_user_platform(
    authorization: Optional[str] = Header(None),
    x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> str:
    """FastAPI dependency for the mem0 platform SDK (``MemoryClient``).

    ``MemoryClient`` authenticates with ``Authorization: Token <key>``. We also
    accept ``Bearer <key>`` and ``X-API-Key`` so the same routes work from curl
    and OpenAI-style clients.

    Args:
        authorization: Value of the ``Authorization`` header, if sent.
        x_api_key: Value of the ``X-API-Key`` header, if sent. Used only when
            ``Authorization`` does not yield a key.

    Returns:
        Whatever ``auth_service.verify_api_key`` returns for the key (the
        mapped user_id).

    Raises:
        HTTPException: 401 when no key is present in either header. Errors for
            an invalid key come from ``auth_service.verify_api_key``.
    """
    api_key = _extract_platform_api_key(authorization, x_api_key)

    if not api_key:
        logger.warning("No API key provided in Authorization or X-API-Key headers")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=(
                "Missing API key. Provide 'Authorization: Token <key>' (mem0 SDK), "
                "'Bearer <key>', or 'X-API-Key'."
            ),
        )

    return auth_service.verify_api_key(api_key)
"""mem0 platform API compatibility layer.

Implements the subset of the hosted mem0 platform API that the ``mem0ai``
``MemoryClient`` (pinned ``mem0ai==2.0.2``) actually calls, so
``MemoryClient(host="https://memory.pratikn.com", api_key=...)`` works against
this self-hosted server. Each platform route maps onto the existing
``mem0_manager`` singleton.

Contract notes (verified against the installed SDK source):
- Auth header is ``Authorization: Token <key>`` (handled by
  ``get_current_user_platform``, which also accepts Bearer / X-API-Key).
- Core ops use ``/v3/memories/*``; item ops use ``/v1/memories/*``; all paths
  carry a trailing slash and are registered here at the exact path the SDK calls
  (so FastAPI matches exactly and never issues a slash redirect).
- ``GET /v1/ping/`` runs at client construction and MUST return non-empty
  ``org_id`` and ``project_id`` or the SDK's ``Project(...)`` raises.
- Scoping (user_id/agent_id/run_id) is carried in ``filters`` (search/get_all),
  the top-level body (add), or the query string (delete_all). It defaults to the
  authenticated user; a mismatch is rejected with 403 (same model as the native
  endpoints).
"""

from typing import Any, Dict, List, Optional, Tuple

from fastapi import APIRouter, Depends, HTTPException, Request
import structlog

from auth import get_current_user_platform
from mem0_manager import mem0_manager

logger = structlog.get_logger(__name__)

# NOTE(review): none of the routes below carry a rate-limit decorator, whereas
# the native /memories endpoints in main.py are wrapped with @limiter.limit
# (e.g. 10/minute on the bulk delete). Confirm whether that is intended.
router = APIRouter(tags=["mem0-platform-compat"])

_DEFAULT_TOP_K = 10
_DEFAULT_PAGE_SIZE = 100
_MAX_PAGE_SIZE = 1000


def _require_self(requested: Optional[str], authed: str) -> str:
    """Return the user_id to operate on.

    Defaults to the authenticated user when no user is requested.

    Raises:
        HTTPException: 403 when ``requested`` names a different user
            (consistent with the native endpoints).
    """
    if not requested:
        return authed
    if requested != authed:
        raise HTTPException(
            status_code=403,
            detail=f"Access denied: you can only access your own memories (authenticated as '{authed}')",
        )
    return authed


async def _json_body(request: Request) -> Dict[str, Any]:
    """Parse the JSON body defensively (the SDK sends varied shapes).

    Returns an empty dict when the body is missing, unparsable, or not a JSON
    object, so callers can use ``.get`` unconditionally.
    """
    try:
        body = await request.json()
    except Exception as exc:  # deliberate best-effort: bad body == empty body
        logger.debug("Ignoring unparsable request body", error=str(exc))
        body = None
    return body if isinstance(body, dict) else {}


def _positive_int(value: Any, *, default: int, field: str) -> int:
    """Coerce ``value`` to a positive int, or return ``default`` when it is None.

    Raises:
        HTTPException: 422 when the value is not an integer >= 1, instead of
            letting ``int()`` raise and surface as a 500.
    """
    if value is None:
        return default
    try:
        number = int(value)
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=422, detail=f"`{field}` must be a positive integer"
        ) from None
    if number < 1:
        raise HTTPException(
            status_code=422, detail=f"`{field}` must be a positive integer"
        )
    return number


def _split_filters(
    body: Dict[str, Any], authed: str
) -> Tuple[str, Optional[str], Optional[str], Optional[Dict[str, Any]]]:
    """Pull the scoping IDs out of the SDK's ``filters`` object.

    Returns:
        ``(target_user, agent_id, run_id, remaining_filters)`` where
        ``remaining_filters`` is ``None`` when nothing else was supplied.

    Raises:
        HTTPException: 422 when ``filters`` is present but not a JSON object;
            403 when ``filters.user_id`` names another user.
    """
    raw_filters = body.get("filters")
    if not raw_filters:
        filters: Dict[str, Any] = {}
    elif isinstance(raw_filters, dict):
        filters = dict(raw_filters)  # copy: the pops below must not mutate the body
    else:
        raise HTTPException(status_code=422, detail="`filters` must be an object")

    # NOTE(review): only top-level keys are inspected. If the SDK nests scoping
    # IDs (e.g. under AND/OR), they pass through in the remaining filters
    # untouched -- confirm against the SDK's filter format.
    target = _require_self(filters.pop("user_id", None), authed)
    agent_id = filters.pop("agent_id", None)
    run_id = filters.pop("run_id", None)
    filters.pop("app_id", None)  # accepted by the SDK; unused here
    return target, agent_id, run_id, (filters or None)


async def _owned_or_404(memory_id: str, user: str) -> None:
    """Raise 404 unless ``mem0_manager`` confirms ``user`` owns ``memory_id``."""
    if not await mem0_manager.verify_memory_ownership(memory_id, user):
        raise HTTPException(status_code=404, detail=f"Memory '{memory_id}' not found")


@router.get("/v1/ping/")
async def ping(user: str = Depends(get_current_user_platform)) -> Dict[str, Any]:
    """Client construction validation.

    ``org_id``/``project_id`` MUST be non-empty or the SDK's ``Project(...)``
    raises 'org_id and project_id must be set'.
    """
    return {
        "status": "ok",
        "user_email": user,
        "org_id": "default-org",
        "project_id": "default-project",
    }


@router.post("/v3/memories/add/")
async def add_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Add memories from ``messages`` for the authenticated user.

    Raises:
        HTTPException: 422 when ``messages`` is missing or empty; 403 when the
            body's ``user_id`` names another user.
    """
    body = await _json_body(request)
    messages = body.get("messages")
    if not messages:
        raise HTTPException(status_code=422, detail="`messages` is required")
    if isinstance(messages, str):
        # A bare string is wrapped as a single user message.
        messages = [{"role": "user", "content": messages}]
    target = _require_self(body.get("user_id"), user)

    raw = await mem0_manager.add_memories(
        messages=messages,
        user_id=target,
        agent_id=body.get("agent_id"),
        run_id=body.get("run_id"),
        metadata=body.get("metadata"),
    )
    # mem0_manager wraps the mem0 result as {"added_memories": [<mem0 dict>], ...};
    # the mem0 dict is already {"results": [...]} (the platform shape).
    added = raw.get("added_memories") or []
    if added and isinstance(added[0], dict) and "results" in added[0]:
        return added[0]
    return {"results": added}


@router.post("/v3/memories/search/")
async def search_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Search the authenticated user's memories.

    The result count comes from ``top_k`` (or ``limit``), defaulting to 10.

    Raises:
        HTTPException: 422 when ``query`` is missing, ``filters`` is malformed,
            or ``top_k``/``limit`` is not a positive integer; 403 on a user
            mismatch.
    """
    body = await _json_body(request)
    query = body.get("query")
    if not query:
        raise HTTPException(status_code=422, detail="`query` is required")
    target, agent_id, run_id, extra = _split_filters(body, user)
    top_k = _positive_int(
        body.get("top_k") or body.get("limit"),
        default=_DEFAULT_TOP_K,
        field="top_k",
    )

    result = await mem0_manager.search_memories(
        query=query,
        user_id=target,
        limit=top_k,
        filters=extra,
        agent_id=agent_id,
        run_id=run_id,
    )
    return {"results": result.get("memories", [])}


@router.post("/v3/memories/")
async def get_all_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Return one page of the authenticated user's memories.

    ``page``/``page_size`` are read from the query string; invalid values fall
    back to page 1 with 100 items. ``count`` is the number of memories fetched
    so far, i.e. a lower bound on the total whenever ``next`` is set.
    """
    body = await _json_body(request)
    target, agent_id, run_id, extra = _split_filters(body, user)

    # The SDK sends page/page_size as query params. mem0's get_all has no offset,
    # so we fetch everything up to the end of the requested page and slice.
    try:
        page = max(int(request.query_params.get("page", 1)), 1)
        page_size = min(
            max(int(request.query_params.get("page_size", _DEFAULT_PAGE_SIZE)), 1),
            _MAX_PAGE_SIZE,
        )
    except ValueError:
        page, page_size = 1, _DEFAULT_PAGE_SIZE

    start = (page - 1) * page_size
    end = start + page_size

    items = await mem0_manager.get_user_memories(
        user_id=target,
        # One look-ahead item: with limit == end the result could never exceed
        # `end`, so `next` would always come back as None.
        limit=end + 1,
        agent_id=agent_id,
        run_id=run_id,
        filters=extra,
    )
    return {
        "count": len(items),
        "next": page + 1 if len(items) > end else None,
        "previous": page - 1 if page > 1 else None,
        "results": items[start:end],
    }


@router.get("/v1/memories/{memory_id}/")
async def get_memory(
    memory_id: str, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Fetch a single memory; 404 when it is missing or not owned by the caller."""
    await _owned_or_404(memory_id, user)
    mem = await mem0_manager.get_memory(memory_id)
    if mem is None:
        raise HTTPException(status_code=404, detail=f"Memory '{memory_id}' not found")
    return mem


@router.put("/v1/memories/{memory_id}/")
async def update_memory(
    memory_id: str, request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Replace a memory's text, read from ``text``, ``memory`` or ``data``.

    Raises:
        HTTPException: 404 when the memory is not owned by the caller; 422 when
            no replacement text is supplied.
    """
    await _owned_or_404(memory_id, user)
    body = await _json_body(request)
    text = body.get("text") or body.get("memory") or body.get("data")
    if not text:
        raise HTTPException(
            status_code=422, detail="`text` is required to update a memory"
        )
    return await mem0_manager.update_memory(memory_id=memory_id, content=text)


@router.delete("/v1/memories/{memory_id}/")
async def delete_memory(
    memory_id: str, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Delete a single memory; 404 when it is not owned by the caller."""
    await _owned_or_404(memory_id, user)
    return await mem0_manager.delete_memory(memory_id=memory_id)


@router.get("/v1/memories/{memory_id}/history/")
async def memory_history(
    memory_id: str, user: str = Depends(get_current_user_platform)
) -> List[Dict[str, Any]]:
    """Return the ``history`` list for a memory (empty when the key is absent)."""
    await _owned_or_404(memory_id, user)
    result = await mem0_manager.get_memory_history(memory_id)
    return result.get("history", [])


@router.delete("/v1/memories/")
async def delete_all_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Delete all memories of the authenticated user.

    A ``user_id`` query parameter naming another user is rejected with 403.
    """
    target = _require_self(request.query_params.get("user_id"), user)
    return await mem0_manager.delete_user_memories(user_id=target)


@router.get("/v1/entities/")
async def list_entities(
    user: str = Depends(get_current_user_platform),
) -> Dict[str, Any]:
    """List entities visible to the caller."""
    # This server is single-user-per-key; report the authenticated user as the
    # only entity (the platform returns all users/agents/runs with memories).
    return {"results": [{"id": user, "name": user, "type": "user"}], "count": 1}
"""End-to-end compatibility test: drive the real mem0ai MemoryClient against
this server. Run INSIDE the backend container (which has mem0ai==2.0.2):

    ssh beast 'cd ~/aistuff/mem0 && docker compose exec -T backend python' < test_sdk_compat.py

Requires MEM0_API_KEY in the environment (mapped to user 'pratik'). Exercises the
full SDK surface and cleans up the memories it creates (scoped to a throwaway
agent_id so it never touches real memories).

All network work lives in ``main()`` behind a ``__main__`` guard: the file name
matches pytest's ``test_*.py`` pattern, so it must be importable without
side effects (no ``sys.exit`` and no ``mem0`` import at import time).
"""

import os
import sys
import time

KEY = os.environ.get("MEM0_API_KEY")
HOST = os.environ.get("MEM0_HOST", "https://memory.pratikn.com")
USER = os.environ.get("MEM0_TEST_USER", "pratik")
AGENT = "sdk_compat_test"  # isolates test data from real memories

# One bool per check, in execution order; drives the final GREEN/RED verdict.
results = []


def check(name, cond, info=""):
    """Record one check outcome and print a PASS/FAIL line.

    ``info`` is stringified and truncated to 240 characters for display.
    """
    results.append(bool(cond))
    print(("PASS " if cond else "FAIL "), name, "--", str(info)[:240])


def main() -> int:
    """Run the SDK round-trip and return the process exit code (0 == GREEN)."""
    if not KEY:
        sys.exit("set MEM0_API_KEY (mapped to user 'pratik')")

    # Imported lazily so importing this module does not require mem0ai.
    from mem0 import MemoryClient

    # --- construct (hits GET /v1/ping/, validates Token auth + org/project) ---
    try:
        c = MemoryClient(host=HOST, api_key=KEY)
        check("construct", True, f"user_email={c.user_email} org={c.org_id} proj={c.project_id}")
    except Exception as e:
        check("construct", False, repr(e))
        print("\nRESULT: RED (cannot construct client)")
        return 1

    # --- add (POST /v3/memories/add/) ---
    probe = f"SDK compat probe {int(time.time())}: Pratik is validating the mem0 SDK compatibility layer and uses FastAPI."
    try:
        r = c.add(probe, user_id=USER, agent_id=AGENT)
        check("add", isinstance(r, dict), r)
    except Exception as e:
        check("add", False, repr(e))

    time.sleep(3)  # allow async extraction/indexing to settle

    # --- search (POST /v3/memories/search/) ---
    try:
        s = c.search("mem0 SDK compatibility layer", filters={"user_id": USER, "agent_id": AGENT})
        check("search.shape", isinstance(s, dict) and "results" in s, s)
    except Exception as e:
        check("search.shape", False, repr(e))

    # --- get_all (POST /v3/memories/) ---
    ids = []
    try:
        g = c.get_all(filters={"user_id": USER, "agent_id": AGENT})
        ok = isinstance(g, dict) and "results" in g and "count" in g
        check("get_all.shape", ok, g)
        ids = [m.get("id") for m in (g.get("results") or []) if m.get("id")]
    except Exception as e:
        check("get_all.shape", False, repr(e))

    mid = ids[0] if ids else None
    print(f"  (found {len(ids)} memory id(s) under agent={AGENT})")

    # --- item ops (best-effort; depend on extraction producing >=1 fact) ---
    if mid:
        try:
            one = c.get(mid)
            check("get", isinstance(one, dict) and one.get("id") == mid, one)
        except Exception as e:
            check("get", False, repr(e))
        try:
            h = c.history(mid)
            check("history.is_list", isinstance(h, list), h)
        except Exception as e:
            check("history.is_list", False, repr(e))
        try:
            u = c.update(mid, text="SDK compat probe (updated)")
            check("update", isinstance(u, dict), u)
        except Exception as e:
            check("update", False, repr(e))

    # --- cleanup: delete every id listed under the throwaway agent (this also
    # sweeps leftovers from earlier runs, never real memories) ---
    deleted = 0
    for memory_id in ids:
        try:
            c.delete(memory_id)
            deleted += 1
        except Exception as e:
            print("  delete error", memory_id, repr(e))
    print(f"  cleanup: deleted {deleted}/{len(ids)}")

    green = bool(results) and all(results)
    print("\nRESULT:", "GREEN" if green else "RED", f"({sum(results)}/{len(results)} checks passed)")
    return 0 if green else 1


if __name__ == "__main__":
    sys.exit(main())