Implements the subset of the hosted mem0 platform API that mem0ai==2.0.2
MemoryClient calls, so MemoryClient(host=..., api_key=...) works against this
server. Verified end-to-end (construct/add/search/get_all/get/history/update/delete).
- platform_compat.py: GET /v1/ping/ (returns non-empty org_id/project_id, which
the SDK's Project init requires), POST /v3/memories/{add,search}/,
POST /v3/memories/ (paginated get_all), /v1/memories/{id}/ item ops, and
GET /v1/entities/ -- all mapped onto the existing mem0_manager.
- auth.get_current_user_platform: accepts Authorization: Token (mem0 SDK),
Bearer, or X-API-Key.
- main.py: include the platform router; remove the /v1/memories* aliases added
in ea07a82 (the SDK uses /v3 and trailing-slash /v1/memories/{id}/, not those
paths); keep /v1/chat/completions and the native /memories* routes.
- docker-compose: run uvicorn with --proxy-headers --forwarded-allow-ips=* so the
proxy's https scheme is honoured. This stops trailing-slash 307 redirects from
downgrading https->http and dropping the Authorization header -- the actual
cause of the reported "POST auth broken" symptom (auth was never broken).
- test_sdk_compat.py: end-to-end MemoryClient round-trip against the server.
224 lines
8.1 KiB
Python
224 lines
8.1 KiB
Python
"""mem0 platform API compatibility layer.

Implements the subset of the hosted mem0 platform API that the ``mem0ai``
``MemoryClient`` (pinned ``mem0ai==2.0.2``) actually calls, so
``MemoryClient(host="https://memory.pratikn.com", api_key=...)`` works against
this self-hosted server. Each platform route maps onto the existing
``mem0_manager`` singleton.

Contract notes (verified against the installed SDK source):

- Auth header is ``Authorization: Token <key>`` (handled by
  ``get_current_user_platform``, which also accepts Bearer / X-API-Key).
- Core ops use ``/v3/memories/*``; item ops use ``/v1/memories/*``. All paths
  carry a trailing slash and are registered here at the exact path the SDK
  calls, so FastAPI matches them directly and never issues a slash redirect.
- ``GET /v1/ping/`` runs at client construction and MUST return non-empty
  ``org_id`` and ``project_id``, or the SDK's ``Project(...)`` raises.
- Scoping (user_id/agent_id/run_id) is carried in ``filters`` (search/get_all),
  the top-level body (add), or the query string (delete_all). It defaults to
  the authenticated user; a mismatched user_id is rejected with 403 (same
  model as the native endpoints).
"""
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple

import structlog
from fastapi import APIRouter, Depends, HTTPException, Request

from auth import get_current_user_platform
from mem0_manager import mem0_manager
|
|
|
|
# Module-level structured logger (structlog), named after this module.
logger = structlog.get_logger(__name__)

# Router that every platform-compat route below is registered on; main.py
# includes it in the application.
router = APIRouter(tags=["mem0-platform-compat"])
|
|
|
|
|
|
def _require_self(requested: Optional[str], authed: str) -> str:
|
|
"""Return the user_id to operate on: default to the authenticated user,
|
|
reject a mismatch with 403 (consistent with the native endpoints)."""
|
|
if not requested:
|
|
return authed
|
|
if requested != authed:
|
|
raise HTTPException(
|
|
status_code=403,
|
|
detail=f"Access denied: you can only access your own memories (authenticated as '{authed}')",
|
|
)
|
|
return authed
|
|
|
|
|
|
async def _json_body(request: Request) -> Dict[str, Any]:
    """Parse the request body as a JSON object, defensively.

    The SDK sends varied shapes, so an empty or malformed body, or a JSON
    value that is not an object, all yield ``{}``. Each endpoint then applies
    its own required-field checks.

    Only decoding failures are absorbed:
    - ``json.JSONDecodeError`` and ``UnicodeDecodeError`` are both
      ``ValueError`` subclasses.
    - ``RecursionError`` covers pathologically nested JSON.

    Anything else (e.g. the client disconnecting mid-body) propagates rather
    than being silently treated as an empty request.
    """
    try:
        body = await request.json()
    except (ValueError, RecursionError):
        return {}
    return body if isinstance(body, dict) else {}
|
|
|
|
|
|
def _split_filters(
    body: Dict[str, Any], authed: str
) -> Tuple[str, Any, Any, Optional[Dict[str, Any]]]:
    """Pull the scoping IDs out of the SDK's ``filters`` object.

    ``filters`` is copied, so the request body is not mutated. The copy is
    stripped of:
    - ``user_id``, validated via ``_require_self``;
    - ``agent_id`` and ``run_id``;
    - ``app_id``, which the SDK accepts but this server does not use.

    A missing or falsy ``filters`` is treated as empty. Any other value that
    is not a JSON object is rejected with 422; previously ``dict(...)`` raised
    on it and the client got a 500.

    Returns:
        ``(target_user, agent_id, run_id, remaining_filters)``, where
        ``remaining_filters`` is ``None`` when nothing else was supplied.

    Raises:
        HTTPException: 422 for a non-object ``filters``; 403 (from
        ``_require_self``) when ``filters.user_id`` names another user.
    """
    raw_filters = body.get("filters") or {}
    if not isinstance(raw_filters, dict):
        raise HTTPException(
            status_code=422, detail="`filters` must be a JSON object"
        )
    filters = dict(raw_filters)
    target = _require_self(filters.pop("user_id", None), authed)
    agent_id = filters.pop("agent_id", None)
    run_id = filters.pop("run_id", None)
    filters.pop("app_id", None)  # accepted by the SDK; unused here
    return target, agent_id, run_id, (filters or None)
|
|
|
|
|
|
async def _owned_or_404(memory_id: str, user: str) -> None:
    """Raise 404 unless ``user`` owns ``memory_id``.

    A memory owned by someone else produces exactly the same 404 as a memory
    that does not exist, so the two cases are indistinguishable to callers.
    """
    owned = await mem0_manager.verify_memory_ownership(memory_id, user)
    if owned:
        return
    raise HTTPException(status_code=404, detail=f"Memory '{memory_id}' not found")
|
|
|
|
|
|
@router.get("/v1/ping/")
async def ping(user: str = Depends(get_current_user_platform)) -> Dict[str, Any]:
    """Answer the SDK's construction-time ``/v1/ping/`` probe.

    ``MemoryClient`` validates this response when it is built. ``org_id`` and
    ``project_id`` must be non-empty or the SDK's ``Project(...)`` raises
    'org_id and project_id must be set', so fixed placeholder values are
    returned alongside the authenticated identity.
    """
    return dict(
        status="ok",
        user_email=user,
        org_id="default-org",
        project_id="default-project",
    )
|
|
|
|
|
|
@router.post("/v3/memories/add/")
async def add_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Store memories extracted from a conversation (``MemoryClient.add``).

    Body fields:
    - ``messages`` (required); a bare string is wrapped as a single user
      message.
    - ``user_id`` (optional); must match the caller and defaults to it.
    - ``agent_id``, ``run_id`` and ``metadata`` (optional).

    Responds with the platform ``{"results": [...]}`` shape.
    """
    payload = await _json_body(request)

    messages = payload.get("messages")
    if not messages:
        raise HTTPException(status_code=422, detail="`messages` is required")
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    outcome = await mem0_manager.add_memories(
        messages=messages,
        user_id=_require_self(payload.get("user_id"), user),
        agent_id=payload.get("agent_id"),
        run_id=payload.get("run_id"),
        metadata=payload.get("metadata"),
    )

    # mem0_manager puts mem0's own result as the first entry of
    # "added_memories". That result is already the platform
    # {"results": [...]} shape, so pass it through untouched when present.
    added = outcome.get("added_memories") or []
    first = added[0] if added else None
    if isinstance(first, dict) and "results" in first:
        return first
    return {"results": added}
|
|
|
|
|
|
@router.post("/v3/memories/search/")
async def search_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Semantic search over the caller's memories (``MemoryClient.search``).

    Body fields:
    - ``query`` (required, non-empty string).
    - ``filters`` (optional scoping, see ``_split_filters``).
    - ``top_k`` / ``limit`` (optional, default 10).

    The limit is clamped to [1, 1000], the same bounds ``get_all`` applies to
    ``page_size``. A value that cannot be converted to an int falls back to
    the default rather than surfacing as a 500.

    Raises:
        HTTPException: 422 when ``query`` is missing or not a string.
    """
    body = await _json_body(request)
    query = body.get("query")
    if not query:
        raise HTTPException(status_code=422, detail="`query` is required")
    if not isinstance(query, str):
        raise HTTPException(status_code=422, detail="`query` must be a string")
    target, agent_id, run_id, extra = _split_filters(body, user)

    requested_limit = body.get("top_k") or body.get("limit") or 10
    try:
        limit = min(max(int(requested_limit), 1), 1000)
    except (TypeError, ValueError, OverflowError):
        # e.g. "ten", a list/object, NaN, or Infinity (Python's json module
        # accepts the latter two), none of which int() can turn into a count.
        limit = 10

    result = await mem0_manager.search_memories(
        query=query,
        user_id=target,
        limit=limit,
        filters=extra,
        agent_id=agent_id,
        run_id=run_id,
    )
    return {"results": result.get("memories", [])}
|
|
|
|
|
|
@router.post("/v3/memories/")
async def get_all_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Return one page of the caller's memories (``MemoryClient.get_all``).

    Scoping comes from ``filters`` in the JSON body (see ``_split_filters``).
    ``page`` / ``page_size`` come from the query string:
    - defaults are 1 / 100;
    - ``page`` is at least 1;
    - ``page_size`` is clamped to [1, 1000];
    - an unparseable value resets both to the defaults.

    mem0's get_all has no offset, so the first ``page * page_size`` memories
    are fetched and the requested page is sliced out locally. One extra row is
    requested so that ``next`` can detect a further page. Without it
    ``len(items)`` could never exceed the end of the current page, so ``next``
    was always ``None``.

    Response: ``{"count", "next", "previous", "results"}``.
    - ``next`` / ``previous`` are page numbers or ``None``.
    - ``count`` is the number of memories fetched to serve this page. It is a
      lower bound on the total, and exact whenever ``next`` is ``None``
      (assuming ``get_user_memories`` returns every match when fewer than
      ``limit`` exist).
    """
    body = await _json_body(request)
    target, agent_id, run_id, extra = _split_filters(body, user)

    try:
        page = max(int(request.query_params.get("page", 1)), 1)
        page_size = min(max(int(request.query_params.get("page_size", 100)), 1), 1000)
    except ValueError:
        page, page_size = 1, 100

    start = (page - 1) * page_size
    end = start + page_size

    # One row past the requested page lets us tell "exactly full" apart from
    # "more pages exist".
    items = await mem0_manager.get_user_memories(
        user_id=target,
        limit=end + 1,
        agent_id=agent_id,
        run_id=run_id,
        filters=extra,
    )
    has_next = len(items) > end
    return {
        "count": len(items),
        "next": page + 1 if has_next else None,
        "previous": page - 1 if page > 1 else None,
        "results": items[start:end],
    }
|
|
|
|
|
|
@router.get("/v1/memories/{memory_id}/")
async def get_memory(
    memory_id: str, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Fetch a single memory the caller owns (``MemoryClient.get``).

    Raises:
        HTTPException: 404 when the memory is not owned by the caller or
        ``mem0_manager.get_memory`` returns ``None``.
    """
    await _owned_or_404(memory_id, user)
    memory = await mem0_manager.get_memory(memory_id)
    if memory is not None:
        return memory
    raise HTTPException(status_code=404, detail=f"Memory '{memory_id}' not found")
|
|
|
|
|
|
@router.put("/v1/memories/{memory_id}/")
async def update_memory(
    memory_id: str, request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Replace the content of a memory the caller owns (``MemoryClient.update``).

    The new content is read from ``text``, ``memory`` or ``data``, in that
    order; the first truthy value wins.

    Raises:
        HTTPException: 404 when the caller does not own the memory; 422 when
        none of the content fields carries a value.
    """
    await _owned_or_404(memory_id, user)
    payload = await _json_body(request)
    new_text = next(
        (payload[field] for field in ("text", "memory", "data") if payload.get(field)),
        None,
    )
    if not new_text:
        raise HTTPException(
            status_code=422, detail="`text` is required to update a memory"
        )
    return await mem0_manager.update_memory(memory_id=memory_id, content=new_text)
|
|
|
|
|
|
@router.delete("/v1/memories/{memory_id}/")
async def delete_memory(
    memory_id: str, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Delete one memory the caller owns (``MemoryClient.delete``).

    Raises:
        HTTPException: 404 when the caller does not own the memory.
    """
    await _owned_or_404(memory_id, user)
    outcome = await mem0_manager.delete_memory(memory_id=memory_id)
    return outcome
|
|
|
|
|
|
@router.get("/v1/memories/{memory_id}/history/")
async def memory_history(
    memory_id: str, user: str = Depends(get_current_user_platform)
) -> List[Dict[str, Any]]:
    """Return the change history of a memory the caller owns.

    The platform responds with a bare list, so only the ``history`` entry of
    the manager's result is returned (empty list when it is absent).

    Raises:
        HTTPException: 404 when the caller does not own the memory.
    """
    await _owned_or_404(memory_id, user)
    history = (await mem0_manager.get_memory_history(memory_id)).get("history", [])
    return history
|
|
|
|
|
|
@router.delete("/v1/memories/")
async def delete_all_memories(
    request: Request, user: str = Depends(get_current_user_platform)
) -> Dict[str, Any]:
    """Delete every memory of the authenticated user (``MemoryClient.delete_all``).

    Scoping arrives in the query string:
    - ``user_id`` defaults to the caller; a mismatch is rejected with 403.
    - ``agent_id`` / ``run_id`` cannot be honoured here, because
      ``delete_user_memories`` is only given ``user_id``. A request carrying
      either is refused with 400. Previously such a request was silently
      widened into deleting *all* of the user's memories.
    - Any other query parameters are ignored.

    Raises:
        HTTPException: 403 for another user's ``user_id``; 400 when an
        ``agent_id`` / ``run_id`` scope is requested.
    """
    params = request.query_params
    target = _require_self(params.get("user_id"), user)

    unsupported = [key for key in ("agent_id", "run_id") if params.get(key)]
    if unsupported:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Scoped delete by {', '.join(unsupported)} is not supported; "
                "omit it to delete all of your memories"
            ),
        )
    return await mem0_manager.delete_user_memories(user_id=target)
|
|
|
|
|
|
@router.get("/v1/entities/")
async def list_entities(
    user: str = Depends(get_current_user_platform),
) -> Dict[str, Any]:
    """List the entities that own memories (``MemoryClient.users``).

    The hosted platform reports every user/agent/run that has memories. This
    server scopes each API key to a single user, so the authenticated user is
    reported as the sole entity.
    """
    sole_entity = {"id": user, "name": user, "type": "user"}
    return {"results": [sole_entity], "count": 1}
|