knowledge-base/backend/platform_compat.py
Pratik Narola ed11a00ab3 feat: mem0 platform SDK (MemoryClient) compatibility + proxy-header redirect fix
Implements the subset of the hosted mem0 platform API that mem0ai==2.0.2
MemoryClient calls, so MemoryClient(host=..., api_key=...) works against this
server. Verified end-to-end (construct/add/search/get_all/get/history/update/delete).

- platform_compat.py: GET /v1/ping/ (returns non-empty org_id/project_id, which
  the SDK's Project init requires), POST /v3/memories/{add,search}/,
  POST /v3/memories/ (paginated get_all), /v1/memories/{id}/ item ops, and
  GET /v1/entities/ -- all mapped onto the existing mem0_manager.
- auth.get_current_user_platform: accepts Authorization: Token (mem0 SDK),
  Bearer, or X-API-Key.
- main.py: include the platform router; remove the /v1/memories* aliases added
  in ea07a82 (the SDK uses /v3 and trailing-slash /v1/memories/{id}/, not those
  paths); keep /v1/chat/completions and the native /memories* routes.
- docker-compose: run uvicorn with --proxy-headers --forwarded-allow-ips=* so the
  proxy's https scheme is honoured. This stops trailing-slash 307 redirects from
  downgrading https->http and dropping the Authorization header -- the actual
  cause of the reported "POST auth broken" symptom (auth was never broken).
- test_sdk_compat.py: end-to-end MemoryClient round-trip against the server.
2026-05-26 00:09:22 +05:30

224 lines
8.1 KiB
Python

"""mem0 platform API compatibility layer.
Implements the subset of the hosted mem0 platform API that the ``mem0ai``
``MemoryClient`` (pinned ``mem0ai==2.0.2``) actually calls, so
``MemoryClient(host="https://memory.pratikn.com", api_key=...)`` works against
this self-hosted server. Each platform route maps onto the existing
``mem0_manager`` singleton.
Contract notes (verified against the installed SDK source):
- Auth header is ``Authorization: Token <key>`` (handled by
``get_current_user_platform``, which also accepts Bearer / X-API-Key).
- Core ops use ``/v3/memories/*``; item ops use ``/v1/memories/*``; all paths
carry a trailing slash and are registered here at the exact path the SDK calls
(so FastAPI matches exactly and never issues a slash redirect).
- ``GET /v1/ping/`` runs at client construction and MUST return non-empty
``org_id`` and ``project_id`` or the SDK's ``Project(...)`` raises.
- Scoping (user_id/agent_id/run_id) is carried in ``filters`` (search/get_all),
the top-level body (add), or the query string (delete_all). It defaults to the
authenticated user; a mismatch is rejected with 403 (same model as the native
endpoints).
"""
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request
import structlog
from auth import get_current_user_platform
from mem0_manager import mem0_manager
# Module-level structlog logger, created with this module's name.
# NOTE(review): not referenced by any route visible in this file.
logger = structlog.get_logger(__name__)
# Router that every platform-compat endpoint below is registered on; the tag
# groups these routes under "mem0-platform-compat" in the generated API docs.
router = APIRouter(tags=["mem0-platform-compat"])
def _require_self(requested: Optional[str], authed: str) -> str:
"""Return the user_id to operate on: default to the authenticated user,
reject a mismatch with 403 (consistent with the native endpoints)."""
if not requested:
return authed
if requested != authed:
raise HTTPException(
status_code=403,
detail=f"Access denied: you can only access your own memories (authenticated as '{authed}')",
)
return authed
async def _json_body(request: Request) -> Dict[str, Any]:
    """Read the request body as a JSON object, tolerating bad input.

    Returns the decoded body when it is a dict. Any failure while reading or
    decoding, and any non-dict JSON value, yields an empty dict instead (the
    SDK sends varied shapes, so this is deliberately best-effort).
    """
    try:
        parsed = await request.json()
    except Exception:
        # Best-effort by design: an unreadable body is treated as "no body".
        return {}
    if isinstance(parsed, dict):
        return parsed
    return {}
def _split_filters(body: Dict[str, Any], authed: str):
    """Pull the scoping IDs out of the SDK's ``filters`` object.

    Args:
        body: Parsed JSON request body; ``body["filters"]`` is optional.
        authed: The authenticated user's id.

    Returns:
        ``(target_user, agent_id, run_id, remaining_filters)`` where
        ``remaining_filters`` is ``None`` when nothing is left after the
        scoping keys have been removed.

    Raises:
        HTTPException: 422 when ``filters`` is present (truthy) but is not a
            JSON object; 403 (via ``_require_self``) when ``filters.user_id``
            names a user other than ``authed``.
    """
    raw_filters = body.get("filters")
    # A truthy non-object (string, number, list) used to reach dict(...) and
    # could blow up as an unhandled 500; reject it as a client error instead.
    # Falsy values (None, {}, [], "") keep meaning "no filters".
    if raw_filters and not isinstance(raw_filters, dict):
        raise HTTPException(status_code=422, detail="`filters` must be an object")
    # Copy so the pops below never mutate the caller's body.
    filters = dict(raw_filters) if raw_filters else {}
    # NOTE(review): only top-level keys are inspected; scoping IDs nested
    # inside compound filters (e.g. AND/OR lists) pass through untouched in
    # the remaining filters -- confirm mem0_manager handles them safely.
    target = _require_self(filters.pop("user_id", None), authed)
    agent_id = filters.pop("agent_id", None)
    run_id = filters.pop("run_id", None)
    filters.pop("app_id", None)  # accepted by the SDK; unused here
    return target, agent_id, run_id, (filters or None)
async def _owned_or_404(memory_id: str, user: str) -> None:
    """Raise HTTP 404 unless ``mem0_manager.verify_memory_ownership`` reports
    that ``memory_id`` belongs to ``user``; return ``None`` otherwise."""
    is_owner = await mem0_manager.verify_memory_ownership(memory_id, user)
    if is_owner:
        return
    raise HTTPException(status_code=404, detail=f"Memory '{memory_id}' not found")
@router.get("/v1/ping/")
async def ping(user: str = Depends(get_current_user_platform)) -> Dict[str, Any]:
    """Answer the SDK's construction-time validation call.

    ``org_id`` and ``project_id`` MUST be non-empty, otherwise the SDK's
    ``Project(...)`` raises 'org_id and project_id must be set'. The
    authenticated user is echoed back as ``user_email``.
    """
    identity: Dict[str, Any] = {
        "status": "ok",
        "user_email": user,
        "org_id": "default-org",
        "project_id": "default-project",
    }
    return identity
@router.post("/v3/memories/add/")
async def add_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Store memories extracted from ``messages`` for the authenticated user.

    ``messages`` is required (422 when missing or empty); a bare string is
    wrapped as a single user message. ``user_id`` in the body must be absent
    or equal to the authenticated user (403 otherwise). ``agent_id``,
    ``run_id`` and ``metadata`` are forwarded to ``mem0_manager`` unchanged.
    """
    payload = await _json_body(request)

    messages = payload.get("messages")
    if not messages:
        raise HTTPException(status_code=422, detail="`messages` is required")
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    owner = _require_self(payload.get("user_id"), user)
    outcome = await mem0_manager.add_memories(
        messages=messages,
        user_id=owner,
        agent_id=payload.get("agent_id"),
        run_id=payload.get("run_id"),
        metadata=payload.get("metadata"),
    )

    # mem0_manager wraps the mem0 result as {"added_memories": [<mem0 dict>], ...};
    # when that inner dict already has the platform's {"results": [...]} shape
    # it is returned directly, otherwise the list is wrapped here.
    added = outcome.get("added_memories") or []
    first = added[0] if added else None
    if isinstance(first, dict) and "results" in first:
        return first
    return {"results": added}
@router.post("/v3/memories/search/")
async def search_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Search the authenticated user's memories.

    Body fields:
        query: Required search text.
        filters: Optional object; ``user_id`` / ``agent_id`` / ``run_id`` are
            extracted as scoping and the rest is forwarded as extra filters.
        top_k / limit: Optional result cap; the first truthy one wins and the
            default is 10.

    Returns:
        ``{"results": [...]}`` built from the ``memories`` key of the
        ``mem0_manager.search_memories`` result.

    Raises:
        HTTPException: 422 when ``query`` is missing/empty or the result cap
            is not a positive integer; 403 when ``filters.user_id`` names
            another user.
    """
    body = await _json_body(request)
    query = body.get("query")
    if not query:
        raise HTTPException(status_code=422, detail="`query` is required")
    target, agent_id, run_id, extra = _split_filters(body, user)

    requested_limit = body.get("top_k") or body.get("limit") or 10
    # A malformed cap (e.g. "abc", a list, Infinity) used to escape as an
    # unhandled exception and a 500; report it as a client error instead.
    try:
        limit = int(requested_limit)
    except (TypeError, ValueError, OverflowError) as exc:
        raise HTTPException(
            status_code=422, detail="`top_k` must be a positive integer"
        ) from exc
    if limit < 1:
        raise HTTPException(
            status_code=422, detail="`top_k` must be a positive integer"
        )

    result = await mem0_manager.search_memories(
        query=query,
        user_id=target,
        limit=limit,
        filters=extra,
        agent_id=agent_id,
        run_id=run_id,
    )
    return {"results": result.get("memories", [])}
@router.post("/v3/memories/")
async def get_all_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Return one page of the authenticated user's memories (SDK ``get_all``).

    Scoping and extra filters come from the body's ``filters`` object. The
    SDK sends ``page`` / ``page_size`` as query parameters: ``page`` is
    clamped to >= 1 and ``page_size`` to 1..1000; if either fails to parse as
    an integer both fall back to the defaults (1 and 100).

    Returns:
        A dict with ``count`` (number of memories fetched for this request --
        a lower bound on the real total whenever ``next`` is set), ``next`` /
        ``previous`` (page numbers or ``None``) and ``results`` (the
        requested slice).

    Raises:
        HTTPException: 403 when ``filters.user_id`` names another user.
    """
    body = await _json_body(request)
    target, agent_id, run_id, extra = _split_filters(body, user)

    try:
        page = max(int(request.query_params.get("page", 1)), 1)
        page_size = min(max(int(request.query_params.get("page_size", 100)), 1), 1000)
    except ValueError:
        page, page_size = 1, 100

    start = (page - 1) * page_size
    end = start + page_size

    # mem0's get_all has no offset, so everything up to the end of the
    # requested page is fetched and then sliced. One extra item is requested
    # as a look-ahead: with a cap of exactly page * page_size the result can
    # never exceed `end`, so `next` would always be None.
    items = await mem0_manager.get_user_memories(
        user_id=target,
        limit=end + 1,
        agent_id=agent_id,
        run_id=run_id,
        filters=extra,
    )
    fetched = len(items)
    return {
        "count": fetched,
        "next": page + 1 if fetched > end else None,
        "previous": page - 1 if page > 1 else None,
        "results": items[start:end],
    }
@router.get("/v1/memories/{memory_id}/")
async def get_memory(
    memory_id: str, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Fetch a single memory by id.

    Responds 404 when the ownership check fails or when ``mem0_manager``
    returns ``None`` for the id.
    """
    await _owned_or_404(memory_id, user)
    record = await mem0_manager.get_memory(memory_id)
    if record is not None:
        return record
    raise HTTPException(status_code=404, detail=f"Memory '{memory_id}' not found")
@router.put("/v1/memories/{memory_id}/")
async def update_memory(
    memory_id: str, request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Replace the content of a memory owned by the authenticated user.

    The new content is taken from the first non-empty of ``text``,
    ``memory`` or ``data`` in the body. Responds 404 when the ownership
    check fails and 422 when no content is supplied.
    """
    await _owned_or_404(memory_id, user)
    payload = await _json_body(request)

    # First truthy candidate wins, in this fixed order of precedence.
    new_content = None
    for field in ("text", "memory", "data"):
        candidate = payload.get(field)
        if candidate:
            new_content = candidate
            break

    if not new_content:
        raise HTTPException(
            status_code=422, detail="`text` is required to update a memory"
        )
    return await mem0_manager.update_memory(memory_id=memory_id, content=new_content)
@router.delete("/v1/memories/{memory_id}/")
async def delete_memory(
    memory_id: str, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Delete one memory after the ownership check (404 on failure) and
    return whatever ``mem0_manager.delete_memory`` reports."""
    await _owned_or_404(memory_id, user)
    outcome = await mem0_manager.delete_memory(memory_id=memory_id)
    return outcome
@router.get("/v1/memories/{memory_id}/history/")
async def memory_history(
    memory_id: str, user: str = Depends(get_current_user_platform)
) -> List[Dict[str, Any]]:
    """Return the change history of a memory owned by the authenticated user.

    Responds 404 when the ownership check fails. The list is read from the
    ``history`` key of the manager's result, defaulting to an empty list.
    """
    await _owned_or_404(memory_id, user)
    history_payload = await mem0_manager.get_memory_history(memory_id)
    entries = history_payload.get("history", [])
    return entries
@router.delete("/v1/memories/")
async def delete_all_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Delete every memory of the authenticated user.

    An optional ``user_id`` query parameter must match the authenticated
    user; a mismatch is rejected with 403.
    """
    requested_user = request.query_params.get("user_id")
    owner = _require_self(requested_user, user)
    outcome = await mem0_manager.delete_user_memories(user_id=owner)
    return outcome
@router.get("/v1/entities/")
async def list_entities(
    user: str = Depends(get_current_user_platform),
) -> Dict[str, Any]:
    """List the entities visible to the caller.

    This server is single-user-per-key, so the authenticated user is reported
    as the only entity (the platform returns all users/agents/runs with
    memories).
    """
    own_entity = {"id": user, "name": user, "type": "user"}
    return {"results": [own_entity], "count": 1}