feat: rewrite /v1/chat/completions as a real OpenAI-compat proxy

Part 1: strip leading newlines from chat_with_memory's LLM response
(minimax-m2 reasoning output leaks blank lines; .lstrip() at the source
covers /chat, /v1/chat/completions, and the MCP chat tool).

Part 2: replace the /v1/chat/completions handler with an httpx-based
pass-through proxy that preserves every upstream field (tool_calls,
reasoning_tokens, system_fingerprint, finish_reason, etc.) and supports
end-to-end MCP-style tool calling.

What changed:
- models.py: OpenAIChatCompletionRequest is now permissive — typed for
  the common fields (tools, tool_choice, parallel_tool_calls,
  response_format, max_completion_tokens, seed, stream_options,
  reasoning_effort, modalities, etc.) and extra='allow' for forward-
  compat. The typed response models (OpenAIChatCompletionResponse and
  friends) are deleted — the handler returns upstream's JSON dict
  directly so unknown fields aren't silently dropped.
- mem0_manager.py: adds httpx.AsyncClient + an openai_proxy_completion()
  method that injects a "Relevant memories" system message only when
  the last role is 'user' AND no tool flow is in progress, then forwards
  to the upstream LLM. Non-stream returns upstream JSON; stream returns
  an async iterator that yields raw upstream SSE bytes verbatim while
  side-channel-parsing for the post-stream mem0.add. Codifies the
  Memori #434 lessons: never mutates existing messages (only prepends
  system), never touches tool_call_id, runs post-add even on mid-stream
  error via try/finally.
- main.py: handler is now ~50 lines — model_dump(exclude_unset) the
  request, hand off to openai_proxy_completion, return dict OR wrap in
  StreamingResponse. response_model=None so FastAPI doesn't validate.
  Deleted stream_openai_response (post-hoc word-chunking is gone).
  Lifespan shutdown closes mem0_manager.async_http.

Research confirmed mem0 itself does not ship an HTTP /v1/chat/completions
(only the in-process mem0.proxy.main.Mem0 SDK pattern), so we replicate
the pattern without adding a litellm dependency. SSE/tool_calls patterns
are modeled after microsoft/agent-lightning's llm_proxy.

Verified locally: ast.parse OK on all three files. End-to-end smoke tests
will run on beast.
This commit is contained in:
Pratik Narola 2026-05-23 21:08:31 +05:30
parent e99b382b16
commit e5a4d1c7c2
3 changed files with 395 additions and 194 deletions

View file

@ -11,7 +11,6 @@ from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security,
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
import structlog
import asyncio
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
@ -29,6 +28,8 @@ def get_rate_limit_key(request: Request) -> str:
limiter = Limiter(key_func=get_rate_limit_key)
import httpx
from models import (
ChatRequest,
MemoryAddRequest,
@ -43,10 +44,6 @@ from models import (
GlobalStatsResponse,
UserStatsResponse,
OpenAIChatCompletionRequest,
OpenAIChatCompletionResponse,
OpenAIChoice,
OpenAIChoiceMessage,
OpenAIUsage,
)
from mem0_manager import mem0_manager
from auth import get_current_user, get_current_user_openai, auth_service
@ -109,6 +106,12 @@ async def lifespan(app: FastAPI):
except Exception as e:
logger.error(f"Error stopping MCP session manager: {e}")
# Close the async HTTP client used by the /v1/chat/completions proxy.
try:
await mem0_manager.aclose()
except Exception as e:
logger.warning(f"mem0_manager aclose failed: {e}")
logger.info("Shutting down Mem0 Interface POC")
@ -336,135 +339,59 @@ async def chat_with_memory(
)
async def stream_openai_response(
completion_id: str, model: str, content: str, created: int
):
"""Generate SSE stream for OpenAI-compatible streaming by chunking the response."""
import uuid
# First chunk with role
chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": ""},
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(chunk)}\n\n"
# Stream content in chunks (3 words at a time for smooth effect)
words = content.split()
chunk_size = 3
for i in range(0, len(words), chunk_size):
word_chunk = " ".join(words[i : i + chunk_size])
if i + chunk_size < len(words):
word_chunk += " "
chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{"index": 0, "delta": {"content": word_chunk}, "finish_reason": None}
],
}
yield f"data: {json.dumps(chunk)}\n\n"
await asyncio.sleep(0.05)
# Final chunk
chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
yield f"data: {json.dumps(chunk)}\n\n"
yield "data: [DONE]\n\n"
@app.post("/v1/chat/completions")
@app.post("/chat/completions")
@app.post("/v1/chat/completions", response_model=None)
@app.post("/chat/completions", response_model=None)
@limiter.limit("30/minute")
async def openai_chat_completions(
request: Request,
completion_request: OpenAIChatCompletionRequest,
authenticated_user: str = Depends(get_current_user_openai),
):
"""OpenAI-compatible chat completions endpoint with mem0 memory integration."""
"""OpenAI-compatible chat completions — pass-through proxy with memory injection.
Forwards the request to the upstream LLM verbatim (preserving tool_calls,
reasoning_tokens, system_fingerprint, finish_reason, etc.) and optionally
prepends a system message with relevant memories when the last role is
`user` and no tool flow is in progress. See
`mem0_manager.openai_proxy_completion` for the full injection rule and
tool-call safety contract.
"""
try:
import uuid
request_kwargs = completion_request.model_dump(exclude_unset=True)
# `user` is the OpenAI client-supplied identifier; we always derive the
# real user from the API key. Strip it before forwarding so it can't
# confuse upstream user-tracking.
request_kwargs.pop("user", None)
user_id = authenticated_user
logger.info(
f"OpenAI chat completion for user: {user_id} (streaming={completion_request.stream})"
"openai_chat_completions",
user_id=authenticated_user,
stream=bool(request_kwargs.get("stream")),
has_tools=bool(request_kwargs.get("tools")),
model=request_kwargs.get("model"),
)
# Extract last user message
user_messages = [
m for m in completion_request.messages if m.get("role") == "user"
]
if not user_messages:
raise HTTPException(
status_code=400,
detail="No user messages provided. Include at least one message with role='user'.",
result = await mem0_manager.openai_proxy_completion(
request_kwargs=request_kwargs,
user_id=authenticated_user,
)
last_message = user_messages[-1].get("content", "")
context = (
completion_request.messages[:-1]
if len(completion_request.messages) > 1
else None
)
# Call chat_with_memory
result = await mem0_manager.chat_with_memory(
message=last_message,
user_id=user_id,
context=context,
)
completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created_time = int(time.time())
assistant_content = result.get("response", "")
if completion_request.stream:
if request_kwargs.get("stream"):
return StreamingResponse(
stream_openai_response(
completion_id=completion_id,
model=settings.default_model,
content=assistant_content,
created=created_time,
),
result,
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
else:
return OpenAIChatCompletionResponse(
id=completion_id,
object="chat.completion",
created=created_time,
model=settings.default_model,
choices=[
OpenAIChoice(
index=0,
message=OpenAIChoiceMessage(
role="assistant", content=assistant_content
),
finish_reason="stop",
)
],
usage=OpenAIUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
)
return result
except httpx.HTTPStatusError as e:
# Forward upstream error body to the client with its original status
try:
detail = e.response.json()
except Exception:
detail = {"error": {"message": e.response.text[:500]}}
raise HTTPException(status_code=e.response.status_code, detail=detail)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:

View file

@ -1,8 +1,12 @@
"""Ultra-minimal Mem0 Manager - Pure Mem0 + Custom OpenAI Endpoint Only."""
import asyncio
import json
import logging
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, AsyncIterator, Union
from datetime import datetime
import httpx
from mem0 import Memory
from openai import OpenAI
from tenacity import (
@ -75,6 +79,28 @@ def _build_filters(
return merged
def _extract_text(content: Any) -> str:
"""Extract text from an OpenAI message `content` field.
`content` can be a plain string OR a list of multi-part objects
(e.g., [{"type": "text", "text": "..."}, {"type": "image_url", ...}]).
Returns concatenated text parts, or "" if no text is present.
"""
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for p in content:
if isinstance(p, dict) and p.get("type") == "text":
t = p.get("text")
if isinstance(t, str):
parts.append(t)
return "\n".join(parts)
return ""
# Appended as the "## Custom Instructions" section of the additive-extraction
# prompt (mem0/configs/prompts.py::generate_additive_extraction_prompt). The
# default few-shot bias is consumer-organizer ("favourite movies", "SF restaurants"),
@ -176,8 +202,23 @@ class Mem0Manager:
timeout=60.0, # 60 second timeout for LLM calls
max_retries=2, # Retry failed requests up to 2 times
)
# Async HTTP client used by openai_proxy_completion to forward
# /v1/chat/completions to the upstream endpoint verbatim (preserves
# tool_calls, reasoning_tokens, system_fingerprint, etc., which the
# openai-python SDK can drop). Closed in the FastAPI lifespan
# shutdown via aclose().
self.async_http = httpx.AsyncClient(
timeout=httpx.Timeout(120.0, connect=10.0)
)
logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")
async def aclose(self) -> None:
"""Release async resources. Called from the FastAPI lifespan shutdown."""
try:
await self.async_http.aclose()
except Exception as e:
logger.warning("async_http aclose failed", error=str(e))
# Pure passthrough methods - no custom logic
@db_retry
@timed("add_memories")
@ -450,7 +491,11 @@ class Mem0Manager:
response = self.openai_client.chat.completions.create(
model=settings.default_model, messages=messages
)
assistant_response = response.choices[0].message.content
# Strip leading whitespace — reasoning models (minimax-m2) leak
# blank lines from their reasoning output into visible content.
# lstrip(), not strip(), preserves intentional trailing whitespace
# inside content (e.g., inside a code block).
assistant_response = (response.choices[0].message.content or "").lstrip()
llm_time = time.time() - llm_start_time
logger.debug(
"LLM call completed",
@ -504,6 +549,261 @@ class Mem0Manager:
"model_used": None,
}
# ------------------------------------------------------------------
# OpenAI-compatible /v1/chat/completions proxy
# ------------------------------------------------------------------
# Design: forward client requests to the upstream endpoint verbatim,
# injecting a "Relevant memories" system message ONLY when the last
# message is from the user AND no tool flow is in progress. The
# opinionated chat_with_memory above is NOT used on this path because:
# (a) clients pick their own model, tools, response_format etc.,
# (b) we must preserve every upstream field (tool_calls, reasoning_tokens,
# system_fingerprint, finish_reason), which the openai-python SDK and
# typed Pydantic responses strip silently,
# (c) streaming must be real (SSE pass-through), not post-hoc word-chunking.
#
# Tool-call safety contract (Memori issue #434 — codified here):
# - We only PREPEND a system message; never mutate, reorder, or delete
# existing messages.
# - We never touch tool messages or tool_call_id values.
# - We skip injection on tool-result follow-ups (last role != "user").
# - We always run the post-stream mem0.add even if upstream errors
# mid-flight (via try/finally).
async def openai_proxy_completion(
self,
request_kwargs: Dict[str, Any],
user_id: str,
) -> Union[Dict[str, Any], AsyncIterator[bytes]]:
"""Proxy a chat-completions request to the upstream LLM.
Returns:
- dict (upstream JSON) when stream is falsy
- async iterator of SSE bytes when stream is truthy
"""
messages = request_kwargs.get("messages") or []
if not messages:
raise ValueError("messages array is empty")
# Injection rule: last role == 'user' AND no prior tool message.
last_msg = messages[-1] if isinstance(messages[-1], dict) else {}
last_role = last_msg.get("role")
has_tool_message = any(
isinstance(m, dict) and m.get("role") == "tool" for m in messages
)
inject_memory = last_role == "user" and not has_tool_message
if inject_memory:
last_user_text = _extract_text(last_msg.get("content"))
if last_user_text:
try:
search_result = self.memory.search(
query=last_user_text,
filters=_build_filters(user_id),
top_k=30,
threshold=0.3,
rerank=True,
)
relevant = (search_result.get("results") or [])[:10]
except Exception as e:
logger.warning(
"memory search failed; proceeding without injection",
error=str(e),
user_id=user_id,
)
relevant = []
if relevant:
memories_str = "\n".join(
f"- {m.get('memory', '')}" for m in relevant
)
mem_block = (
"Relevant memories about the user "
"(use when helpful, ignore otherwise):\n" + memories_str
)
# Merge into existing leading system message if present;
# otherwise prepend a new one. Either way we never mutate
# any other message in the list.
if (
messages
and isinstance(messages[0], dict)
and messages[0].get("role") == "system"
):
existing = _extract_text(messages[0].get("content")) or ""
merged = (existing + "\n\n" + mem_block) if existing else mem_block
messages = [
{"role": "system", "content": merged},
*messages[1:],
]
else:
messages = [
{"role": "system", "content": mem_block},
*messages,
]
request_kwargs["messages"] = messages
logger.info(
"memory injected",
user_id=user_id,
memory_count=len(relevant),
)
url = settings.openai_base_url.rstrip("/") + "/v1/chat/completions"
headers = {
"Authorization": f"Bearer {settings.openai_api_key}",
"Content-Type": "application/json",
}
is_stream = bool(request_kwargs.get("stream"))
if not is_stream:
resp = await self.async_http.post(
url, json=request_kwargs, headers=headers
)
if resp.status_code >= 400:
# Surface upstream error body to the client unchanged
try:
err_body = resp.json()
except Exception:
err_body = {"error": {"message": resp.text[:500]}}
raise httpx.HTTPStatusError(
f"Upstream {resp.status_code}", request=resp.request, response=resp
)
data = resp.json()
# Fire-and-forget mem0.add in the background. Skips tool-only
# responses inside _post_completion_add.
asyncio.create_task(self._post_completion_add(messages, data, user_id))
return data
return self._stream_proxy(url, request_kwargs, headers, messages, user_id)
async def _stream_proxy(
self,
url: str,
request_kwargs: Dict[str, Any],
headers: Dict[str, str],
messages: List[Dict[str, Any]],
user_id: str,
) -> AsyncIterator[bytes]:
"""Stream upstream SSE bytes verbatim; accumulate content for post-stream add."""
accumulated_content: List[str] = []
saw_tool_calls = False
buffer = b""
try:
async with self.async_http.stream(
"POST", url, json=request_kwargs, headers=headers
) as upstream:
if upstream.status_code >= 400:
body_bytes = await upstream.aread()
err = {
"error": {
"message": (
f"Upstream {upstream.status_code}: "
f"{body_bytes.decode('utf-8', errors='replace')[:500]}"
)
}
}
yield f"data: {json.dumps(err)}\n\n".encode("utf-8")
yield b"data: [DONE]\n\n"
return
async for chunk in upstream.aiter_bytes():
if not chunk:
continue
# Forward bytes verbatim — preserves the exact SSE wire format
yield chunk
# Side-channel parse for content/tool_calls accumulation.
# Events are \n\n-separated; data lines start with "data: ".
buffer += chunk
while b"\n\n" in buffer:
event_bytes, buffer = buffer.split(b"\n\n", 1)
for raw_line in event_bytes.split(b"\n"):
if not raw_line.startswith(b"data: "):
continue
payload = raw_line[6:].strip()
if not payload or payload == b"[DONE]":
continue
try:
obj = json.loads(payload)
delta = (
(obj.get("choices") or [{}])[0].get("delta") or {}
)
content_piece = delta.get("content")
if isinstance(content_piece, str):
accumulated_content.append(content_piece)
if delta.get("tool_calls"):
saw_tool_calls = True
except (json.JSONDecodeError, KeyError, IndexError, TypeError):
pass
except httpx.HTTPError as e:
logger.warning("upstream stream error", error=str(e), user_id=user_id)
err = {"error": {"message": f"Upstream stream error: {e}"}}
yield f"data: {json.dumps(err)}\n\n".encode("utf-8")
yield b"data: [DONE]\n\n"
finally:
# Per Memori #434: post-stream mem0.add must run even on mid-stream
# error so partial content is captured. Skipped for tool-only or
# empty responses.
full = "".join(accumulated_content).lstrip()
if full and not saw_tool_calls:
last_user_text = (
_extract_text(messages[-1].get("content"))
if messages
and isinstance(messages[-1], dict)
and messages[-1].get("role") == "user"
else None
)
if last_user_text:
try:
# Synchronous mem0.add in a thread to avoid blocking
# the response loop after StreamingResponse closes.
await asyncio.to_thread(
self.memory.add,
[
{"role": "user", "content": last_user_text},
{"role": "assistant", "content": full},
],
user_id=user_id,
)
except Exception as e:
logger.warning(
"post-stream mem0.add failed",
error=str(e),
user_id=user_id,
)
async def _post_completion_add(
self,
messages: List[Dict[str, Any]],
response_data: Dict[str, Any],
user_id: str,
) -> None:
"""Background mem0.add after a non-stream completion."""
try:
choice = (response_data.get("choices") or [{}])[0]
msg = choice.get("message") or {}
content = msg.get("content")
if not content or not isinstance(content, str):
return # tool-only or non-text completion — skip
content = content.lstrip()
last_user_text = (
_extract_text(messages[-1].get("content"))
if messages
and isinstance(messages[-1], dict)
and messages[-1].get("role") == "user"
else None
)
if not last_user_text:
return
await asyncio.to_thread(
self.memory.add,
[
{"role": "user", "content": last_user_text},
{"role": "assistant", "content": content},
],
user_id=user_id,
)
except Exception as e:
logger.warning(
"post-completion mem0.add failed", error=str(e), user_id=user_id
)
async def health_check(self) -> Dict[str, str]:
"""Basic health check - just connectivity."""
status = {}

View file

@ -1,7 +1,7 @@
"""Ultra-minimal Pydantic models for pure Mem0 API."""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any, Union
from pydantic import BaseModel, ConfigDict, Field
import re
@ -230,82 +230,56 @@ class UserStatsResponse(BaseModel):
# OpenAI-Compatible API Models
class OpenAIMessage(BaseModel):
"""OpenAI message format."""
role: str = Field(..., description="Message role (system, user, assistant)")
content: str = Field(..., description="Message content")
#
# The /v1/chat/completions handler is a pass-through proxy: requests are
# forwarded to the upstream LLM and the upstream response is relayed verbatim
# (dict for non-stream, raw SSE bytes for stream). Hence only the REQUEST
# model lives here; the response is never re-typed (typed Pydantic response
# models silently drop unknown fields like tool_calls / refusal /
# reasoning_tokens — see the audit in the plan file).
class OpenAIChatCompletionRequest(BaseModel):
"""OpenAI chat completion request format."""
"""OpenAI chat completion request — permissive schema, forwarded as-is.
model: str = Field(..., description="Model to use (will use configured default)")
messages: List[Dict[str, str]] = Field(..., description="List of messages")
temperature: Optional[float] = Field(0.7, description="Sampling temperature")
max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
stream: Optional[bool] = Field(False, description="Whether to stream responses")
top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
n: Optional[int] = Field(1, description="Number of completions to generate")
stop: Optional[List[str]] = Field(None, description="Stop sequences")
presence_penalty: Optional[float] = Field(0, description="Presence penalty")
frequency_penalty: Optional[float] = Field(0, description="Frequency penalty")
user: Optional[str] = Field(
None, description="User identifier (ignored, uses API key)"
Only `model` and `messages` are required. All other standard OpenAI
parameters are typed (for client IDE/docs benefit) but optional. Unknown
fields are accepted via `extra="allow"` and forwarded to upstream, so
new OpenAI parameters don't require a code change here.
"""
model_config = ConfigDict(extra="allow")
model: str = Field(..., description="Model to use (forwarded to upstream)")
messages: List[Dict[str, Any]] = Field(
..., description="Messages array (multi-part content supported)"
)
class OpenAIUsage(BaseModel):
"""Token usage information."""
prompt_tokens: int = Field(..., description="Tokens in the prompt")
completion_tokens: int = Field(..., description="Tokens in the completion")
total_tokens: int = Field(..., description="Total tokens used")
class OpenAIChoiceMessage(BaseModel):
"""Message in a choice."""
role: str = Field(..., description="Role of the message")
content: str = Field(..., description="Content of the message")
class OpenAIChoice(BaseModel):
"""Individual completion choice."""
index: int = Field(..., description="Choice index")
message: OpenAIChoiceMessage = Field(..., description="Message content")
finish_reason: str = Field(..., description="Reason for completion finish")
class OpenAIChatCompletionResponse(BaseModel):
"""OpenAI chat completion response format."""
id: str = Field(..., description="Unique completion ID")
object: str = Field(default="chat.completion", description="Object type")
created: int = Field(..., description="Unix timestamp of creation")
model: str = Field(..., description="Model used for completion")
choices: List[OpenAIChoice] = Field(..., description="List of completion choices")
usage: Optional[OpenAIUsage] = Field(None, description="Token usage information")
# Streaming-specific models
class OpenAIStreamDelta(BaseModel):
"""Delta content in a streaming chunk."""
role: Optional[str] = Field(None, description="Role (only in first chunk)")
content: Optional[str] = Field(None, description="Incremental content")
class OpenAIStreamChoice(BaseModel):
"""Individual streaming choice."""
index: int = Field(..., description="Choice index")
delta: OpenAIStreamDelta = Field(..., description="Delta content")
finish_reason: Optional[str] = Field(
None, description="Reason for completion finish"
)
# Common params (typed for IDE/docs; all optional)
temperature: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = None
stream: Optional[bool] = None
stream_options: Optional[Dict[str, Any]] = None
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = None # deprecated; still accepted
max_completion_tokens: Optional[int] = None # replaces max_tokens
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
seed: Optional[int] = None
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
response_format: Optional[Dict[str, Any]] = None
tools: Optional[List[Dict[str, Any]]] = None
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
parallel_tool_calls: Optional[bool] = None
reasoning_effort: Optional[str] = None # o-series / reasoning models
modalities: Optional[List[str]] = None
audio: Optional[Dict[str, Any]] = None
prediction: Optional[Dict[str, Any]] = None
metadata: Optional[Dict[str, Any]] = None
store: Optional[bool] = None
service_tier: Optional[str] = None
logit_bias: Optional[Dict[str, float]] = None
# `user` is ignored — the authenticated user is derived from the API key.
user: Optional[str] = None