feat: mem0 platform SDK (MemoryClient) compatibility + proxy-header redirect fix
Implements the subset of the hosted mem0 platform API that mem0ai==2.0.2
MemoryClient calls, so MemoryClient(host=..., api_key=...) works against this
server. Verified end-to-end (construct/add/search/get_all/get/history/update/delete).
- platform_compat.py: GET /v1/ping/ (returns non-empty org_id/project_id, which
the SDK's Project init requires), POST /v3/memories/{add,search}/,
POST /v3/memories/ (paginated get_all), /v1/memories/{id}/ item ops, and
GET /v1/entities/ -- all mapped onto the existing mem0_manager.
- auth.get_current_user_platform: accepts Authorization: Token (mem0 SDK),
Bearer, or X-API-Key.
- main.py: include the platform router; remove the /v1/memories* aliases added
in ea07a82 (the SDK uses /v3 and trailing-slash /v1/memories/{id}/, not those
paths); keep /v1/chat/completions and the native /memories* routes.
- docker-compose: run uvicorn with --proxy-headers --forwarded-allow-ips=* so the
proxy's https scheme is honoured. This stops trailing-slash 307 redirects from
downgrading https->http and dropping the Authorization header -- the actual
cause of the reported "POST auth broken" symptom (auth was never broken).
- test_sdk_compat.py: end-to-end MemoryClient round-trip against the server.
This commit is contained in:
parent
ea07a82bd7
commit
ed11a00ab3
5 changed files with 366 additions and 7 deletions
|
|
@ -130,6 +130,39 @@ async def get_current_user_openai(
|
|||
return auth_service.verify_api_key(api_key)
|
||||
|
||||
|
||||
async def get_current_user_platform(
    authorization: Optional[str] = Header(None),
    x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> str:
    """FastAPI dependency for the mem0 platform SDK (``MemoryClient``).

    ``MemoryClient`` authenticates with ``Authorization: Token <key>``. We also
    accept ``Bearer <key>`` and ``X-API-Key`` so the same routes work from curl
    and OpenAI-style clients.

    Args:
        authorization: Raw ``Authorization`` header value, if present.
        x_api_key: Raw ``X-API-Key`` header value, if present. Only consulted
            when ``Authorization`` yields no key.

    Returns:
        The result of ``auth_service.verify_api_key`` for the extracted key
        (the mapped user_id).

    Raises:
        HTTPException: 401 when neither header carries a key. Rejection of an
            invalid key is left to ``auth_service.verify_api_key``.
    """
    api_key = None

    if authorization:
        header_value = authorization.strip()
        scheme, _, rest = header_value.partition(" ")
        if scheme.lower() in ("token", "bearer"):
            # Recognised scheme: the key is whatever follows it. A scheme sent
            # without credentials ("Bearer") yields no key, so the X-API-Key
            # fallback below still gets a chance instead of the scheme word
            # itself being verified as a key.
            api_key = rest.strip() or None
        else:
            # Bare value with no recognised scheme — treat the whole header as the key.
            api_key = header_value

    if not api_key and x_api_key:
        api_key = x_api_key

    if not api_key:
        logger.warning("No API key provided in Authorization or X-API-Key headers")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API key. Provide 'Authorization: Token <key>' (mem0 SDK), 'Bearer <key>', or 'X-API-Key'.",
        )

    return auth_service.verify_api_key(api_key)
|
||||
|
||||
|
||||
async def verify_user_access(
|
||||
api_key: str = Security(api_key_header), user_id: Optional[str] = None
|
||||
) -> str:
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ from models import (
|
|||
)
|
||||
from mem0_manager import mem0_manager
|
||||
from auth import get_current_user, get_current_user_openai, auth_service
|
||||
from platform_compat import router as platform_router
|
||||
|
||||
# Configure structured logging
|
||||
structlog.configure(
|
||||
|
|
@ -401,7 +402,6 @@ async def openai_chat_completions(
|
|||
|
||||
# Memory management endpoints - pure Mem0 passthroughs
|
||||
@app.post("/memories")
|
||||
@app.post("/v1/memories")
|
||||
@limiter.limit("60/minute") # Memory operations - 60/min
|
||||
async def add_memories(
|
||||
request: Request,
|
||||
|
|
@ -440,7 +440,6 @@ async def add_memories(
|
|||
|
||||
|
||||
@app.post("/memories/search")
|
||||
@app.post("/v1/memories/search")
|
||||
@limiter.limit("120/minute") # Search is lighter - 120/min
|
||||
async def search_memories(
|
||||
request: Request,
|
||||
|
|
@ -481,7 +480,6 @@ async def search_memories(
|
|||
detail="An internal error occurred. Please try again later.",
|
||||
)
|
||||
|
||||
@app.get("/v1/memories/{user_id}")
|
||||
@app.get("/memories/{user_id}")
|
||||
@limiter.limit("120/minute")
|
||||
async def get_user_memories(
|
||||
|
|
@ -520,7 +518,6 @@ async def get_user_memories(
|
|||
|
||||
|
||||
@app.put("/memories")
|
||||
@app.put("/v1/memories")
|
||||
@limiter.limit("60/minute")
|
||||
async def update_memory(
|
||||
request: Request,
|
||||
|
|
@ -565,7 +562,6 @@ async def update_memory(
|
|||
detail="An internal error occurred. Please try again later.",
|
||||
)
|
||||
|
||||
@app.delete("/v1/memories/{memory_id}")
|
||||
@app.delete("/memories/{memory_id}")
|
||||
@limiter.limit("60/minute")
|
||||
async def delete_memory(
|
||||
|
|
@ -600,7 +596,6 @@ async def delete_memory(
|
|||
)
|
||||
|
||||
|
||||
@app.delete("/v1/memories/user/{user_id}")
|
||||
@app.delete("/memories/user/{user_id}")
|
||||
@limiter.limit("10/minute") # Dangerous bulk delete - heavily rate limited
|
||||
async def delete_user_memories(
|
||||
|
|
@ -837,6 +832,11 @@ async def get_active_users(
|
|||
|
||||
|
||||
|
||||
# mem0 platform SDK (MemoryClient) compatibility routes:
|
||||
# /v1/ping/, /v3/memories/{add,search}/, /v3/memories/, /v1/memories/*, /v1/entities/
|
||||
app.include_router(platform_router)
|
||||
|
||||
|
||||
# Mount MCP server at /mcp endpoint
|
||||
try:
|
||||
from mcp_server import create_mcp_app
|
||||
|
|
|
|||
224
backend/platform_compat.py
Normal file
224
backend/platform_compat.py
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
"""mem0 platform API compatibility layer.
|
||||
|
||||
Implements the subset of the hosted mem0 platform API that the ``mem0ai``
|
||||
``MemoryClient`` (pinned ``mem0ai==2.0.2``) actually calls, so
|
||||
``MemoryClient(host="https://memory.pratikn.com", api_key=...)`` works against
|
||||
this self-hosted server. Each platform route maps onto the existing
|
||||
``mem0_manager`` singleton.
|
||||
|
||||
Contract notes (verified against the installed SDK source):
|
||||
- Auth header is ``Authorization: Token <key>`` (handled by
|
||||
``get_current_user_platform``, which also accepts Bearer / X-API-Key).
|
||||
- Core ops use ``/v3/memories/*``; item ops use ``/v1/memories/*``; all paths
|
||||
carry a trailing slash and are registered here at the exact path the SDK calls
|
||||
(so FastAPI matches exactly and never issues a slash redirect).
|
||||
- ``GET /v1/ping/`` runs at client construction and MUST return non-empty
|
||||
``org_id`` and ``project_id`` or the SDK's ``Project(...)`` raises.
|
||||
- Scoping (user_id/agent_id/run_id) is carried in ``filters`` (search/get_all),
|
||||
the top-level body (add), or the query string (delete_all). It defaults to the
|
||||
authenticated user; a mismatch is rejected with 403 (same model as the native
|
||||
endpoints).
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
import structlog
|
||||
|
||||
from auth import get_current_user_platform
|
||||
from mem0_manager import mem0_manager
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
router = APIRouter(tags=["mem0-platform-compat"])
|
||||
|
||||
|
||||
def _require_self(requested: Optional[str], authed: str) -> str:
    """Resolve which user_id a request may operate on.

    A missing/empty ``requested`` id falls back to the authenticated user; a
    ``requested`` id that differs from ``authed`` is refused with HTTP 403.
    In every non-raising case the authenticated user_id is returned.
    """
    if requested and requested != authed:
        raise HTTPException(
            status_code=403,
            detail=f"Access denied: you can only access your own memories (authenticated as '{authed}')",
        )
    return authed
|
||||
|
||||
|
||||
async def _json_body(request: Request) -> Dict[str, Any]:
    """Return the request's JSON body as a dict, or ``{}`` when unusable.

    Any failure while reading/parsing the body, and any payload that is not a
    JSON object, results in an empty dict rather than an error.
    """
    try:
        parsed = await request.json()
    except Exception:
        # Deliberately broad: a body that cannot be parsed is treated as empty.
        return {}
    if isinstance(parsed, dict):
        return parsed
    return {}
|
||||
|
||||
|
||||
def _split_filters(body: Dict[str, Any], authed: str):
    """Pull the scoping IDs out of the SDK's ``filters`` object.

    Args:
        body: Parsed JSON request body; ``body["filters"]`` is optional.
        authed: user_id of the authenticated caller.

    Returns:
        ``(target_user, agent_id, run_id, remaining_filters)`` where
        ``remaining_filters`` is ``None`` when nothing is left after the
        scoping keys have been removed.

    Raises:
        HTTPException: 422 when ``filters`` is present but not a JSON object;
            403 (from ``_require_self``) when ``filters.user_id`` names a user
            other than the authenticated one.
    """
    raw_filters = body.get("filters")
    if raw_filters and not isinstance(raw_filters, dict):
        # dict() on a string/list/number would raise and surface as a 500.
        raise HTTPException(status_code=422, detail="`filters` must be an object")

    # Copy so the caller's body is not mutated by the pops below.
    filters = dict(raw_filters or {})
    target = _require_self(filters.pop("user_id", None), authed)
    agent_id = filters.pop("agent_id", None)
    run_id = filters.pop("run_id", None)
    filters.pop("app_id", None)  # accepted by the SDK; unused here
    return target, agent_id, run_id, (filters or None)
|
||||
|
||||
|
||||
async def _owned_or_404(memory_id: str, user: str) -> None:
    """Raise HTTP 404 unless ``mem0_manager`` confirms ``user`` owns ``memory_id``."""
    is_owner = await mem0_manager.verify_memory_ownership(memory_id, user)
    if is_owner:
        return
    raise HTTPException(status_code=404, detail=f"Memory '{memory_id}' not found")
|
||||
|
||||
|
||||
@router.get("/v1/ping/")
async def ping(user: str = Depends(get_current_user_platform)) -> Dict[str, Any]:
    """Answer the SDK's construction-time ping.

    ``org_id`` and ``project_id`` are fixed non-empty placeholders; the SDK's
    ``Project(...)`` raises 'org_id and project_id must be set' otherwise. The
    authenticated user_id is echoed back as ``user_email``.
    """
    payload: Dict[str, Any] = {"status": "ok", "user_email": user}
    payload["org_id"] = "default-org"
    payload["project_id"] = "default-project"
    return payload
|
||||
|
||||
|
||||
@router.post("/v3/memories/add/")
async def add_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Store memories for the caller (platform ``POST /v3/memories/add/``).

    ``messages`` is required; a plain string is wrapped as a single user
    message. ``user_id`` in the body must be absent or equal to the
    authenticated user. The response is ``{"results": [...]}``.
    """
    payload = await _json_body(request)

    messages = payload.get("messages")
    if not messages:
        raise HTTPException(status_code=422, detail="`messages` is required")
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    owner = _require_self(payload.get("user_id"), user)

    outcome = await mem0_manager.add_memories(
        messages=messages,
        user_id=owner,
        agent_id=payload.get("agent_id"),
        run_id=payload.get("run_id"),
        metadata=payload.get("metadata"),
    )

    # mem0_manager wraps the mem0 result as {"added_memories": [<mem0 dict>], ...};
    # when that inner dict already has the platform shape {"results": [...]},
    # hand it back unchanged.
    added = outcome.get("added_memories") or []
    first = added[0] if added else None
    if isinstance(first, dict) and "results" in first:
        return first
    return {"results": added}
|
||||
|
||||
|
||||
@router.post("/v3/memories/search/")
async def search_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Search the caller's memories (platform ``POST /v3/memories/search/``).

    Body fields read: ``query`` (required), ``filters`` (optional scoping and
    extra filters) and a result cap taken from ``top_k``, then ``limit``,
    defaulting to 10.

    Returns:
        ``{"results": [...]}`` built from the ``"memories"`` entry of the
        manager's search result (empty list when that entry is absent).

    Raises:
        HTTPException: 422 when ``query`` is missing/empty or the result cap
            is not a positive integer; 403 when the filters name another user.
    """
    body = await _json_body(request)
    query = body.get("query")
    if not query:
        raise HTTPException(status_code=422, detail="`query` is required")
    target, agent_id, run_id, extra = _split_filters(body, user)

    raw_top_k = body.get("top_k") or body.get("limit") or 10
    try:
        top_k = int(raw_top_k)
    except (TypeError, ValueError, OverflowError):
        # Non-numeric input would otherwise escape as an unhandled 500.
        top_k = 0
    if top_k < 1:
        raise HTTPException(
            status_code=422, detail="`top_k` must be a positive integer"
        )

    result = await mem0_manager.search_memories(
        query=query,
        user_id=target,
        limit=top_k,
        filters=extra,
        agent_id=agent_id,
        run_id=run_id,
    )
    return {"results": result.get("memories", [])}
|
||||
|
||||
|
||||
@router.post("/v3/memories/")
async def get_all_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """List the caller's memories, paginated (platform ``POST /v3/memories/``).

    Scoping and extra filters come from the body's ``filters``; ``page`` and
    ``page_size`` come from the query string (defaults 1 and 100, with
    ``page_size`` clamped to 1..1000; unparsable values fall back to the
    defaults).

    Returns:
        A dict with ``results`` (the requested page), ``next`` / ``previous``
        page numbers (``None`` when there is no such page) and ``count``.
        ``count`` is the number of memories fetched so far, including one
        look-ahead row, so it is a lower bound rather than a true total.

    Raises:
        HTTPException: 403 when the filters name another user.
    """
    body = await _json_body(request)
    target, agent_id, run_id, extra = _split_filters(body, user)

    # The SDK sends page/page_size as query params.
    try:
        page = max(int(request.query_params.get("page", 1)), 1)
        page_size = min(max(int(request.query_params.get("page_size", 100)), 1), 1000)
    except ValueError:
        page, page_size = 1, 100

    # mem0's get_all has no offset, so fetch everything up to the end of the
    # requested page and slice. One extra row is requested as a look-ahead:
    # without it the fetched count can never exceed the page end and `next`
    # would always be None.
    page_end = page * page_size
    page_start = page_end - page_size
    items = await mem0_manager.get_user_memories(
        user_id=target,
        limit=page_end + 1,
        agent_id=agent_id,
        run_id=run_id,
        filters=extra,
    )
    has_more = len(items) > page_end
    return {
        "count": len(items),
        "next": page + 1 if has_more else None,
        "previous": page - 1 if page > 1 else None,
        "results": items[page_start:page_end],
    }
|
||||
|
||||
|
||||
@router.get("/v1/memories/{memory_id}/")
async def get_memory(
    memory_id: str, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Return one memory by id; 404 when it is not the caller's or is missing."""
    await _owned_or_404(memory_id, user)
    record = await mem0_manager.get_memory(memory_id)
    if record is not None:
        return record
    raise HTTPException(status_code=404, detail=f"Memory '{memory_id}' not found")
|
||||
|
||||
|
||||
@router.put("/v1/memories/{memory_id}/")
async def update_memory(
    memory_id: str, request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Replace the text of one of the caller's memories.

    The new text is read from the first non-empty of the body keys ``text``,
    ``memory`` and ``data``. Ownership is checked before the body is read, so
    an unknown/foreign id yields 404 even for a malformed body; a body without
    usable text yields 422.
    """
    await _owned_or_404(memory_id, user)
    payload = await _json_body(request)

    new_text = None
    for key in ("text", "memory", "data"):
        candidate = payload.get(key)
        if candidate:
            new_text = candidate
            break

    if not new_text:
        raise HTTPException(
            status_code=422, detail="`text` is required to update a memory"
        )
    return await mem0_manager.update_memory(memory_id=memory_id, content=new_text)
|
||||
|
||||
|
||||
@router.delete("/v1/memories/{memory_id}/")
async def delete_memory(
    memory_id: str, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Delete one of the caller's memories; 404 when the caller does not own it."""
    await _owned_or_404(memory_id, user)
    outcome = await mem0_manager.delete_memory(memory_id=memory_id)
    return outcome
|
||||
|
||||
|
||||
@router.get("/v1/memories/{memory_id}/history/")
async def memory_history(
    memory_id: str, user: str = Depends(get_current_user_platform)
) -> List[Dict[str, Any]]:
    """Return the change history of one of the caller's memories.

    The list is taken from the ``"history"`` entry of the manager's result and
    is empty when that entry is absent. 404 when the caller does not own the
    memory.
    """
    await _owned_or_404(memory_id, user)
    history_result = await mem0_manager.get_memory_history(memory_id)
    entries = history_result.get("history", [])
    return entries
|
||||
|
||||
|
||||
@router.delete("/v1/memories/")
async def delete_all_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Delete every memory of the caller (platform ``DELETE /v1/memories/``).

    Scoping arrives in the query string. Only user-level deletion is
    performed here (``mem0_manager.delete_user_memories``), so a request that
    asks for a narrower scope is refused instead of silently wiping all of
    the user's memories.

    Returns:
        Whatever ``mem0_manager.delete_user_memories`` returns.

    Raises:
        HTTPException: 403 when ``user_id`` names another user; 400 when
            ``agent_id``, ``run_id`` or ``app_id`` is supplied.
    """
    params = request.query_params
    target = _require_self(params.get("user_id"), user)

    narrower_scopes = [
        name for name in ("agent_id", "run_id", "app_id") if params.get(name)
    ]
    if narrower_scopes:
        # Ignoring these would turn "delete this agent's/run's memories" into
        # "delete everything for the user".
        raise HTTPException(
            status_code=400,
            detail=(
                "Scoped delete_all is not supported by this server "
                f"(got {', '.join(narrower_scopes)}); only user-level deletion is available"
            ),
        )

    return await mem0_manager.delete_user_memories(user_id=target)
|
||||
|
||||
|
||||
@router.get("/v1/entities/")
async def list_entities(
    user: str = Depends(get_current_user_platform),
) -> Dict[str, Any]:
    """Report the entities visible to the caller.

    This server is single-user-per-key, so the authenticated user is returned
    as the only entity (the platform returns all users/agents/runs with
    memories).
    """
    me = {"id": user, "name": user, "type": "user"}
    return {"results": [me], "count": 1}
|
||||
|
|
@ -61,7 +61,7 @@ services:
|
|||
volumes:
|
||||
- ./backend:/app
|
||||
- ./frontend:/app/frontend
|
||||
command: ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
|
||||
command: ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4", "--proxy-headers", "--forwarded-allow-ips", "*"]
|
||||
|
||||
volumes:
|
||||
qdrant_data:
|
||||
|
|
|
|||
102
test_sdk_compat.py
Normal file
102
test_sdk_compat.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""End-to-end compatibility test: drive the real mem0ai MemoryClient against
|
||||
this server. Run INSIDE the backend container (which has mem0ai==2.0.2):
|
||||
|
||||
ssh beast 'cd ~/aistuff/mem0 && docker compose exec -T backend python' < test_sdk_compat.py
|
||||
|
||||
Requires MEM0_API_KEY in the environment (mapped to user 'pratik'). Exercises the
|
||||
full SDK surface and cleans up the memories it creates (scoped to a throwaway
|
||||
agent_id so it never touches real memories).
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
from mem0 import MemoryClient
|
||||
|
||||
# Credentials and target are taken from the environment; the script refuses to
# run without an API key.
KEY = os.getenv("MEM0_API_KEY")
if not KEY:
    sys.exit("set MEM0_API_KEY (mapped to user 'pratik')")

HOST = os.getenv("MEM0_HOST", "https://memory.pratikn.com")
USER = os.getenv("MEM0_TEST_USER", "pratik")
AGENT = "sdk_compat_test"  # isolates test data from real memories

# One boolean per recorded check, in execution order.
results = []
|
||||
|
||||
|
||||
def check(name, cond, info=""):
    """Record one check outcome in ``results`` and print a PASS/FAIL line.

    ``info`` is stringified and truncated to 240 characters for the log line.
    """
    passed = bool(cond)
    results.append(passed)
    label = "PASS " if cond else "FAIL "
    print(label, name, "--", str(info)[:240])
|
||||
|
||||
|
||||
# --- construct (hits GET /v1/ping/, validates Token auth + org/project) ---
# Nothing else can run without a client, so a failure here ends the script.
try:
    client = MemoryClient(host=HOST, api_key=KEY)
    check(
        "construct",
        True,
        f"user_email={client.user_email} org={client.org_id} proj={client.project_id}",
    )
except Exception as exc:
    check("construct", False, repr(exc))
    print("\nRESULT: RED (cannot construct client)")
    sys.exit(1)

# --- add (POST /v3/memories/add/) ---
probe = f"SDK compat probe {int(time.time())}: Pratik is validating the mem0 SDK compatibility layer and uses FastAPI."
try:
    add_result = client.add(probe, user_id=USER, agent_id=AGENT)
    check("add", isinstance(add_result, dict), add_result)
except Exception as exc:
    check("add", False, repr(exc))

time.sleep(3)  # allow async extraction/indexing to settle

# Every read below is scoped to the throwaway agent.
scope = {"user_id": USER, "agent_id": AGENT}

# --- search (POST /v3/memories/search/) ---
try:
    search_result = client.search("mem0 SDK compatibility layer", filters=dict(scope))
    check(
        "search.shape",
        isinstance(search_result, dict) and "results" in search_result,
        search_result,
    )
except Exception as exc:
    check("search.shape", False, repr(exc))

# --- get_all (POST /v3/memories/) ---
ids = []
try:
    listing = client.get_all(filters=dict(scope))
    shape_ok = isinstance(listing, dict) and "results" in listing and "count" in listing
    check("get_all.shape", shape_ok, listing)
    found = []
    for entry in listing.get("results") or []:
        if entry.get("id"):
            found.append(entry.get("id"))
    ids = found
except Exception as exc:
    check("get_all.shape", False, repr(exc))

first_id = ids[0] if ids else None
print(f" (created {len(ids)} memory id(s) under agent={AGENT})")

# --- item ops (best-effort; depend on extraction producing >=1 fact) ---
if first_id:
    # (check name, SDK call, predicate on the SDK's return value)
    item_ops = (
        (
            "get",
            lambda: client.get(first_id),
            lambda out: isinstance(out, dict) and out.get("id") == first_id,
        ),
        (
            "history.is_list",
            lambda: client.history(first_id),
            lambda out: isinstance(out, list),
        ),
        (
            "update",
            lambda: client.update(first_id, text="SDK compat probe (updated)"),
            lambda out: isinstance(out, dict),
        ),
    )
    for label, call, looks_ok in item_ops:
        try:
            out = call()
            check(label, looks_ok(out), out)
        except Exception as exc:
            check(label, False, repr(exc))

# --- cleanup: delete only the ids we created ---
deleted = 0
for memory_id in ids:
    try:
        client.delete(memory_id)
    except Exception as exc:
        print(" delete error", memory_id, repr(exc))
    else:
        deleted += 1
print(f" cleanup: deleted {deleted}/{len(ids)}")

# GREEN requires at least one recorded check and no failures.
green = bool(results) and all(results)
print("\nRESULT:", "GREEN" if green else "RED", f"({sum(results)}/{len(results)} checks passed)")
sys.exit(0 if green else 1)
|
||||
Loading…
Reference in a new issue