feat: mem0 platform SDK (MemoryClient) compatibility + proxy-header redirect fix

Implements the subset of the hosted mem0 platform API that mem0ai==2.0.2
MemoryClient calls, so MemoryClient(host=..., api_key=...) works against this
server. Verified end-to-end (construct/add/search/get_all/get/history/update/delete).

- platform_compat.py: GET /v1/ping/ (returns non-empty org_id/project_id, which
  the SDK's Project init requires), POST /v3/memories/{add,search}/,
  POST /v3/memories/ (paginated get_all), /v1/memories/{id}/ item ops, and
  GET /v1/entities/ -- all mapped onto the existing mem0_manager.
- auth.get_current_user_platform: accepts Authorization: Token (mem0 SDK),
  Bearer, or X-API-Key.
- main.py: include the platform router; remove the /v1/memories* aliases added
  in ea07a82 (the SDK uses /v3 and trailing-slash /v1/memories/{id}/, not those
  paths); keep /v1/chat/completions and the native /memories* routes.
- docker-compose: run uvicorn with --proxy-headers --forwarded-allow-ips=* so the
  proxy's https scheme is honoured. This stops trailing-slash 307 redirects from
  downgrading https->http and dropping the Authorization header -- the actual
  cause of the reported "POST auth broken" symptom (auth was never broken).
- test_sdk_compat.py: end-to-end MemoryClient round-trip against the server.
This commit is contained in:
Pratik Narola 2026-05-26 00:09:22 +05:30
parent ea07a82bd7
commit ed11a00ab3
5 changed files with 366 additions and 7 deletions

View file

@ -130,6 +130,39 @@ async def get_current_user_openai(
return auth_service.verify_api_key(api_key)
async def get_current_user_platform(
authorization: Optional[str] = Header(None),
x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> str:
"""FastAPI dependency for the mem0 platform SDK (``MemoryClient``).
``MemoryClient`` authenticates with ``Authorization: Token <key>``. We also
accept ``Bearer <key>`` and ``X-API-Key`` so the same routes work from curl
and OpenAI-style clients. Resolves to the mapped user_id (raises 401 if the
key is missing or invalid).
"""
api_key = None
if authorization:
scheme, _, rest = authorization.partition(" ")
if rest and scheme.lower() in ("token", "bearer"):
api_key = rest.strip()
else:
# Bare value with no recognised scheme — treat the whole header as the key.
api_key = authorization.strip()
if not api_key and x_api_key:
api_key = x_api_key
if not api_key:
logger.warning("No API key provided in Authorization or X-API-Key headers")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing API key. Provide 'Authorization: Token <key>' (mem0 SDK), 'Bearer <key>', or 'X-API-Key'.",
)
return auth_service.verify_api_key(api_key)
async def verify_user_access(
api_key: str = Security(api_key_header), user_id: Optional[str] = None
) -> str:

View file

@ -47,6 +47,7 @@ from models import (
)
from mem0_manager import mem0_manager
from auth import get_current_user, get_current_user_openai, auth_service
from platform_compat import router as platform_router
# Configure structured logging
structlog.configure(
@ -401,7 +402,6 @@ async def openai_chat_completions(
# Memory management endpoints - pure Mem0 passthroughs
@app.post("/memories")
@app.post("/v1/memories")
@limiter.limit("60/minute") # Memory operations - 60/min
async def add_memories(
request: Request,
@ -440,7 +440,6 @@ async def add_memories(
@app.post("/memories/search")
@app.post("/v1/memories/search")
@limiter.limit("120/minute") # Search is lighter - 120/min
async def search_memories(
request: Request,
@ -481,7 +480,6 @@ async def search_memories(
detail="An internal error occurred. Please try again later.",
)
@app.get("/v1/memories/{user_id}")
@app.get("/memories/{user_id}")
@limiter.limit("120/minute")
async def get_user_memories(
@ -520,7 +518,6 @@ async def get_user_memories(
@app.put("/memories")
@app.put("/v1/memories")
@limiter.limit("60/minute")
async def update_memory(
request: Request,
@ -565,7 +562,6 @@ async def update_memory(
detail="An internal error occurred. Please try again later.",
)
@app.delete("/v1/memories/{memory_id}")
@app.delete("/memories/{memory_id}")
@limiter.limit("60/minute")
async def delete_memory(
@ -600,7 +596,6 @@ async def delete_memory(
)
@app.delete("/v1/memories/user/{user_id}")
@app.delete("/memories/user/{user_id}")
@limiter.limit("10/minute") # Dangerous bulk delete - heavily rate limited
async def delete_user_memories(
@ -837,6 +832,11 @@ async def get_active_users(
# mem0 platform SDK (MemoryClient) compatibility routes:
# /v1/ping/, /v3/memories/{add,search}/, /v3/memories/, /v1/memories/*, /v1/entities/
app.include_router(platform_router)
# Mount MCP server at /mcp endpoint
try:
from mcp_server import create_mcp_app

224
backend/platform_compat.py Normal file
View file

@ -0,0 +1,224 @@
"""mem0 platform API compatibility layer.
Implements the subset of the hosted mem0 platform API that the ``mem0ai``
``MemoryClient`` (pinned ``mem0ai==2.0.2``) actually calls, so
``MemoryClient(host="https://memory.pratikn.com", api_key=...)`` works against
this self-hosted server. Each platform route maps onto the existing
``mem0_manager`` singleton.
Contract notes (verified against the installed SDK source):
- Auth header is ``Authorization: Token <key>`` (handled by
``get_current_user_platform``, which also accepts Bearer / X-API-Key).
- Core ops use ``/v3/memories/*``; item ops use ``/v1/memories/*``; all paths
carry a trailing slash and are registered here at the exact path the SDK calls
(so FastAPI matches exactly and never issues a slash redirect).
- ``GET /v1/ping/`` runs at client construction and MUST return non-empty
``org_id`` and ``project_id`` or the SDK's ``Project(...)`` raises.
- Scoping (user_id/agent_id/run_id) is carried in ``filters`` (search/get_all),
the top-level body (add), or the query string (delete_all). It defaults to the
authenticated user; a mismatch is rejected with 403 (same model as the native
endpoints).
"""
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request
import structlog
from auth import get_current_user_platform
from mem0_manager import mem0_manager
logger = structlog.get_logger(__name__)
router = APIRouter(tags=["mem0-platform-compat"])
def _require_self(requested: Optional[str], authed: str) -> str:
"""Return the user_id to operate on: default to the authenticated user,
reject a mismatch with 403 (consistent with the native endpoints)."""
if not requested:
return authed
if requested != authed:
raise HTTPException(
status_code=403,
detail=f"Access denied: you can only access your own memories (authenticated as '{authed}')",
)
return authed
async def _json_body(request: Request) -> Dict[str, Any]:
"""Parse the JSON body defensively (the SDK sends varied shapes)."""
try:
body = await request.json()
except Exception:
body = None
return body if isinstance(body, dict) else {}
def _split_filters(body: Dict[str, Any], authed: str):
"""Pull the scoping IDs out of the SDK's ``filters`` object and return
(target_user, agent_id, run_id, remaining_filters)."""
filters = dict(body.get("filters") or {})
target = _require_self(filters.pop("user_id", None), authed)
agent_id = filters.pop("agent_id", None)
run_id = filters.pop("run_id", None)
filters.pop("app_id", None) # accepted by the SDK; unused here
return target, agent_id, run_id, (filters or None)
async def _owned_or_404(memory_id: str, user: str) -> None:
if not await mem0_manager.verify_memory_ownership(memory_id, user):
raise HTTPException(status_code=404, detail=f"Memory '{memory_id}' not found")
@router.get("/v1/ping/")
async def ping(user: str = Depends(get_current_user_platform)) -> Dict[str, Any]:
"""Client construction validation. ``org_id``/``project_id`` MUST be
non-empty or the SDK's ``Project(...)`` raises 'org_id and project_id must
be set'."""
return {
"status": "ok",
"user_email": user,
"org_id": "default-org",
"project_id": "default-project",
}
@router.post("/v3/memories/add/")
async def add_memories(
request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
body = await _json_body(request)
messages = body.get("messages")
if not messages:
raise HTTPException(status_code=422, detail="`messages` is required")
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
target = _require_self(body.get("user_id"), user)
raw = await mem0_manager.add_memories(
messages=messages,
user_id=target,
agent_id=body.get("agent_id"),
run_id=body.get("run_id"),
metadata=body.get("metadata"),
)
# mem0_manager wraps the mem0 result as {"added_memories": [<mem0 dict>], ...};
# the mem0 dict is already {"results": [...]} (the platform shape).
added = raw.get("added_memories") or []
if added and isinstance(added[0], dict) and "results" in added[0]:
return added[0]
return {"results": added}
@router.post("/v3/memories/search/")
async def search_memories(
request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
body = await _json_body(request)
query = body.get("query")
if not query:
raise HTTPException(status_code=422, detail="`query` is required")
target, agent_id, run_id, extra = _split_filters(body, user)
top_k = body.get("top_k") or body.get("limit") or 10
result = await mem0_manager.search_memories(
query=query,
user_id=target,
limit=int(top_k),
filters=extra,
agent_id=agent_id,
run_id=run_id,
)
return {"results": result.get("memories", [])}
@router.post("/v3/memories/")
async def get_all_memories(
request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
body = await _json_body(request)
target, agent_id, run_id, extra = _split_filters(body, user)
# The SDK sends page/page_size as query params. mem0's get_all has no offset,
# so we fetch up to page*page_size and slice the requested page.
try:
page = max(int(request.query_params.get("page", 1)), 1)
page_size = min(max(int(request.query_params.get("page_size", 100)), 1), 1000)
except ValueError:
page, page_size = 1, 100
items = await mem0_manager.get_user_memories(
user_id=target,
limit=page * page_size,
agent_id=agent_id,
run_id=run_id,
filters=extra,
)
total = len(items)
start = (page - 1) * page_size
return {
"count": total,
"next": page + 1 if start + page_size < total else None,
"previous": page - 1 if page > 1 else None,
"results": items[start : start + page_size],
}
@router.get("/v1/memories/{memory_id}/")
async def get_memory(
memory_id: str, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
await _owned_or_404(memory_id, user)
mem = await mem0_manager.get_memory(memory_id)
if mem is None:
raise HTTPException(status_code=404, detail=f"Memory '{memory_id}' not found")
return mem
@router.put("/v1/memories/{memory_id}/")
async def update_memory(
memory_id: str, request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
await _owned_or_404(memory_id, user)
body = await _json_body(request)
text = body.get("text") or body.get("memory") or body.get("data")
if not text:
raise HTTPException(
status_code=422, detail="`text` is required to update a memory"
)
return await mem0_manager.update_memory(memory_id=memory_id, content=text)
@router.delete("/v1/memories/{memory_id}/")
async def delete_memory(
memory_id: str, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
await _owned_or_404(memory_id, user)
return await mem0_manager.delete_memory(memory_id=memory_id)
@router.get("/v1/memories/{memory_id}/history/")
async def memory_history(
memory_id: str, user: str = Depends(get_current_user_platform)
) -> List[Dict[str, Any]]:
await _owned_or_404(memory_id, user)
result = await mem0_manager.get_memory_history(memory_id)
return result.get("history", [])
@router.delete("/v1/memories/")
async def delete_all_memories(
request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
target = _require_self(request.query_params.get("user_id"), user)
return await mem0_manager.delete_user_memories(user_id=target)
@router.get("/v1/entities/")
async def list_entities(
user: str = Depends(get_current_user_platform),
) -> Dict[str, Any]:
# This server is single-user-per-key; report the authenticated user as the
# only entity (the platform returns all users/agents/runs with memories).
return {"results": [{"id": user, "name": user, "type": "user"}], "count": 1}

View file

@ -61,7 +61,7 @@ services:
volumes:
- ./backend:/app
- ./frontend:/app/frontend
command: ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
command: ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4", "--proxy-headers", "--forwarded-allow-ips", "*"]
volumes:
qdrant_data:

102
test_sdk_compat.py Normal file
View file

@ -0,0 +1,102 @@
"""End-to-end compatibility test: drive the real mem0ai MemoryClient against
this server. Run INSIDE the backend container (which has mem0ai==2.0.2):
ssh beast 'cd ~/aistuff/mem0 && docker compose exec -T backend python' < test_sdk_compat.py
Requires MEM0_API_KEY in the environment (mapped to user 'pratik'). Exercises the
full SDK surface and cleans up the memories it creates (scoped to a throwaway
agent_id so it never touches real memories).
"""
import os
import sys
import time
from mem0 import MemoryClient
KEY = os.environ.get("MEM0_API_KEY")
if not KEY:
sys.exit("set MEM0_API_KEY (mapped to user 'pratik')")
HOST = os.environ.get("MEM0_HOST", "https://memory.pratikn.com")
USER = os.environ.get("MEM0_TEST_USER", "pratik")
AGENT = "sdk_compat_test" # isolates test data from real memories
results = []
def check(name, cond, info=""):
results.append(bool(cond))
print(("PASS " if cond else "FAIL "), name, "--", str(info)[:240])
# --- construct (hits GET /v1/ping/, validates Token auth + org/project) ---
try:
c = MemoryClient(host=HOST, api_key=KEY)
check("construct", True, f"user_email={c.user_email} org={c.org_id} proj={c.project_id}")
except Exception as e:
check("construct", False, repr(e))
print("\nRESULT: RED (cannot construct client)")
sys.exit(1)
# --- add (POST /v3/memories/add/) ---
probe = f"SDK compat probe {int(time.time())}: Pratik is validating the mem0 SDK compatibility layer and uses FastAPI."
try:
r = c.add(probe, user_id=USER, agent_id=AGENT)
check("add", isinstance(r, dict), r)
except Exception as e:
check("add", False, repr(e))
time.sleep(3) # allow async extraction/indexing to settle
# --- search (POST /v3/memories/search/) ---
try:
s = c.search("mem0 SDK compatibility layer", filters={"user_id": USER, "agent_id": AGENT})
check("search.shape", isinstance(s, dict) and "results" in s, s)
except Exception as e:
check("search.shape", False, repr(e))
# --- get_all (POST /v3/memories/) ---
ids = []
try:
g = c.get_all(filters={"user_id": USER, "agent_id": AGENT})
ok = isinstance(g, dict) and "results" in g and "count" in g
check("get_all.shape", ok, g)
ids = [m.get("id") for m in (g.get("results") or []) if m.get("id")]
except Exception as e:
check("get_all.shape", False, repr(e))
mid = ids[0] if ids else None
print(f" (created {len(ids)} memory id(s) under agent={AGENT})")
# --- item ops (best-effort; depend on extraction producing >=1 fact) ---
if mid:
try:
one = c.get(mid)
check("get", isinstance(one, dict) and one.get("id") == mid, one)
except Exception as e:
check("get", False, repr(e))
try:
h = c.history(mid)
check("history.is_list", isinstance(h, list), h)
except Exception as e:
check("history.is_list", False, repr(e))
try:
u = c.update(mid, text="SDK compat probe (updated)")
check("update", isinstance(u, dict), u)
except Exception as e:
check("update", False, repr(e))
# --- cleanup: delete only the ids we created ---
deleted = 0
for i in ids:
try:
c.delete(i)
deleted += 1
except Exception as e:
print(" delete error", i, repr(e))
print(f" cleanup: deleted {deleted}/{len(ids)}")
green = results and all(results)
print("\nRESULT:", "GREEN" if green else "RED", f"({sum(results)}/{len(results)} checks passed)")
sys.exit(0 if green else 1)