feat: rewrite /v1/chat/completions as a real OpenAI-compat proxy
Part 1: strip leading newlines from chat_with_memory's LLM response (minimax-m2 reasoning output leaks blank lines; .lstrip() at the source covers /chat, /v1/chat/completions, and the MCP chat tool).

Part 2: replace the /v1/chat/completions handler with an httpx-based pass-through proxy that preserves every upstream field (tool_calls, reasoning_tokens, system_fingerprint, finish_reason, etc.) and supports end-to-end MCP-style tool calling.

What changed:

- models.py: OpenAIChatCompletionRequest is now permissive — typed for the common fields (tools, tool_choice, parallel_tool_calls, response_format, max_completion_tokens, seed, stream_options, reasoning_effort, modalities, etc.) and extra='allow' for forward-compat. The typed response models (OpenAIChatCompletionResponse and friends) are deleted — the handler returns upstream's JSON dict directly so unknown fields aren't silently dropped.
- mem0_manager.py: adds httpx.AsyncClient + an openai_proxy_completion() method that injects a "Relevant memories" system message only when the last role is 'user' AND no tool flow is in progress, then forwards to the upstream LLM. Non-stream returns upstream JSON; stream returns an async iterator that yields raw upstream SSE bytes verbatim while side-channel-parsing for the post-stream mem0.add. Codifies the Memori #434 lessons: never mutates existing non-system messages (only prepends a system message, or merges the memory block into an existing leading system message), never touches tool_call_id, runs post-add even on mid-stream error via try/finally.
- main.py: handler is now ~50 lines — model_dump(exclude_unset=True) the request, hand off to openai_proxy_completion, return dict OR wrap in StreamingResponse. response_model=None so FastAPI doesn't validate. Deleted stream_openai_response (post-hoc word-chunking is gone). Lifespan shutdown closes mem0_manager.async_http (via mem0_manager.aclose()).

Research confirmed mem0 itself does not ship an HTTP /v1/chat/completions (only the in-process mem0.proxy.main.Mem0 SDK pattern), so we replicate the pattern without adding a litellm dependency.
SSE/tool_calls patterns are modeled after microsoft/agent-lightning's llm_proxy. Verified locally: ast.parse OK on all three files. End-to-end smoke tests will run on beast.
This commit is contained in:
parent
e99b382b16
commit
e5a4d1c7c2
3 changed files with 395 additions and 194 deletions
161
backend/main.py
161
backend/main.py
|
|
@ -11,7 +11,6 @@ from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security,
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
import structlog
|
import structlog
|
||||||
import asyncio
|
|
||||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||||
from slowapi.util import get_remote_address
|
from slowapi.util import get_remote_address
|
||||||
from slowapi.errors import RateLimitExceeded
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
|
@ -29,6 +28,8 @@ def get_rate_limit_key(request: Request) -> str:
|
||||||
|
|
||||||
|
|
||||||
limiter = Limiter(key_func=get_rate_limit_key)
|
limiter = Limiter(key_func=get_rate_limit_key)
|
||||||
|
import httpx
|
||||||
|
|
||||||
from models import (
|
from models import (
|
||||||
ChatRequest,
|
ChatRequest,
|
||||||
MemoryAddRequest,
|
MemoryAddRequest,
|
||||||
|
|
@ -43,10 +44,6 @@ from models import (
|
||||||
GlobalStatsResponse,
|
GlobalStatsResponse,
|
||||||
UserStatsResponse,
|
UserStatsResponse,
|
||||||
OpenAIChatCompletionRequest,
|
OpenAIChatCompletionRequest,
|
||||||
OpenAIChatCompletionResponse,
|
|
||||||
OpenAIChoice,
|
|
||||||
OpenAIChoiceMessage,
|
|
||||||
OpenAIUsage,
|
|
||||||
)
|
)
|
||||||
from mem0_manager import mem0_manager
|
from mem0_manager import mem0_manager
|
||||||
from auth import get_current_user, get_current_user_openai, auth_service
|
from auth import get_current_user, get_current_user_openai, auth_service
|
||||||
|
|
@ -109,6 +106,12 @@ async def lifespan(app: FastAPI):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error stopping MCP session manager: {e}")
|
logger.error(f"Error stopping MCP session manager: {e}")
|
||||||
|
|
||||||
|
# Close the async HTTP client used by the /v1/chat/completions proxy.
|
||||||
|
try:
|
||||||
|
await mem0_manager.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"mem0_manager aclose failed: {e}")
|
||||||
|
|
||||||
logger.info("Shutting down Mem0 Interface POC")
|
logger.info("Shutting down Mem0 Interface POC")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -336,135 +339,59 @@ async def chat_with_memory(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def stream_openai_response(
|
@app.post("/v1/chat/completions", response_model=None)
|
||||||
completion_id: str, model: str, content: str, created: int
|
@app.post("/chat/completions", response_model=None)
|
||||||
):
|
|
||||||
"""Generate SSE stream for OpenAI-compatible streaming by chunking the response."""
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
# First chunk with role
|
|
||||||
chunk = {
|
|
||||||
"id": completion_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"role": "assistant", "content": ""},
|
|
||||||
"finish_reason": None,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
yield f"data: {json.dumps(chunk)}\n\n"
|
|
||||||
|
|
||||||
# Stream content in chunks (3 words at a time for smooth effect)
|
|
||||||
words = content.split()
|
|
||||||
chunk_size = 3
|
|
||||||
|
|
||||||
for i in range(0, len(words), chunk_size):
|
|
||||||
word_chunk = " ".join(words[i : i + chunk_size])
|
|
||||||
if i + chunk_size < len(words):
|
|
||||||
word_chunk += " "
|
|
||||||
|
|
||||||
chunk = {
|
|
||||||
"id": completion_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": [
|
|
||||||
{"index": 0, "delta": {"content": word_chunk}, "finish_reason": None}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
yield f"data: {json.dumps(chunk)}\n\n"
|
|
||||||
await asyncio.sleep(0.05)
|
|
||||||
|
|
||||||
# Final chunk
|
|
||||||
chunk = {
|
|
||||||
"id": completion_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
|
||||||
}
|
|
||||||
yield f"data: {json.dumps(chunk)}\n\n"
|
|
||||||
yield "data: [DONE]\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
|
||||||
@app.post("/chat/completions")
|
|
||||||
@limiter.limit("30/minute")
|
@limiter.limit("30/minute")
|
||||||
async def openai_chat_completions(
|
async def openai_chat_completions(
|
||||||
request: Request,
|
request: Request,
|
||||||
completion_request: OpenAIChatCompletionRequest,
|
completion_request: OpenAIChatCompletionRequest,
|
||||||
authenticated_user: str = Depends(get_current_user_openai),
|
authenticated_user: str = Depends(get_current_user_openai),
|
||||||
):
|
):
|
||||||
"""OpenAI-compatible chat completions endpoint with mem0 memory integration."""
|
"""OpenAI-compatible chat completions — pass-through proxy with memory injection.
|
||||||
|
|
||||||
|
Forwards the request to the upstream LLM verbatim (preserving tool_calls,
|
||||||
|
reasoning_tokens, system_fingerprint, finish_reason, etc.) and optionally
|
||||||
|
prepends a system message with relevant memories when the last role is
|
||||||
|
`user` and no tool flow is in progress. See
|
||||||
|
`mem0_manager.openai_proxy_completion` for the full injection rule and
|
||||||
|
tool-call safety contract.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
import uuid
|
request_kwargs = completion_request.model_dump(exclude_unset=True)
|
||||||
|
# `user` is the OpenAI client-supplied identifier; we always derive the
|
||||||
|
# real user from the API key. Strip it before forwarding so it can't
|
||||||
|
# confuse upstream user-tracking.
|
||||||
|
request_kwargs.pop("user", None)
|
||||||
|
|
||||||
user_id = authenticated_user
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"OpenAI chat completion for user: {user_id} (streaming={completion_request.stream})"
|
"openai_chat_completions",
|
||||||
|
user_id=authenticated_user,
|
||||||
|
stream=bool(request_kwargs.get("stream")),
|
||||||
|
has_tools=bool(request_kwargs.get("tools")),
|
||||||
|
model=request_kwargs.get("model"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract last user message
|
result = await mem0_manager.openai_proxy_completion(
|
||||||
user_messages = [
|
request_kwargs=request_kwargs,
|
||||||
m for m in completion_request.messages if m.get("role") == "user"
|
user_id=authenticated_user,
|
||||||
]
|
|
||||||
if not user_messages:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="No user messages provided. Include at least one message with role='user'.",
|
|
||||||
)
|
|
||||||
|
|
||||||
last_message = user_messages[-1].get("content", "")
|
|
||||||
context = (
|
|
||||||
completion_request.messages[:-1]
|
|
||||||
if len(completion_request.messages) > 1
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call chat_with_memory
|
if request_kwargs.get("stream"):
|
||||||
result = await mem0_manager.chat_with_memory(
|
|
||||||
message=last_message,
|
|
||||||
user_id=user_id,
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
|
|
||||||
created_time = int(time.time())
|
|
||||||
assistant_content = result.get("response", "")
|
|
||||||
|
|
||||||
if completion_request.stream:
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_openai_response(
|
result,
|
||||||
completion_id=completion_id,
|
|
||||||
model=settings.default_model,
|
|
||||||
content=assistant_content,
|
|
||||||
created=created_time,
|
|
||||||
),
|
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||||
)
|
)
|
||||||
else:
|
return result
|
||||||
return OpenAIChatCompletionResponse(
|
except httpx.HTTPStatusError as e:
|
||||||
id=completion_id,
|
# Forward upstream error body to the client with its original status
|
||||||
object="chat.completion",
|
try:
|
||||||
created=created_time,
|
detail = e.response.json()
|
||||||
model=settings.default_model,
|
except Exception:
|
||||||
choices=[
|
detail = {"error": {"message": e.response.text[:500]}}
|
||||||
OpenAIChoice(
|
raise HTTPException(status_code=e.response.status_code, detail=detail)
|
||||||
index=0,
|
except ValueError as e:
|
||||||
message=OpenAIChoiceMessage(
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
role="assistant", content=assistant_content
|
|
||||||
),
|
|
||||||
finish_reason="stop",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
usage=OpenAIUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,12 @@
|
||||||
"""Ultra-minimal Mem0 Manager - Pure Mem0 + Custom OpenAI Endpoint Only."""
|
"""Ultra-minimal Mem0 Manager - Pure Mem0 + Custom OpenAI Endpoint Only."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Any
|
from typing import Dict, List, Optional, Any, AsyncIterator, Union
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
import httpx
|
||||||
from mem0 import Memory
|
from mem0 import Memory
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
|
|
@ -75,6 +79,28 @@ def _build_filters(
|
||||||
return merged
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text(content: Any) -> str:
|
||||||
|
"""Extract text from an OpenAI message `content` field.
|
||||||
|
|
||||||
|
`content` can be a plain string OR a list of multi-part objects
|
||||||
|
(e.g., [{"type": "text", "text": "..."}, {"type": "image_url", ...}]).
|
||||||
|
Returns concatenated text parts, or "" if no text is present.
|
||||||
|
"""
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts = []
|
||||||
|
for p in content:
|
||||||
|
if isinstance(p, dict) and p.get("type") == "text":
|
||||||
|
t = p.get("text")
|
||||||
|
if isinstance(t, str):
|
||||||
|
parts.append(t)
|
||||||
|
return "\n".join(parts)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
# Appended as the "## Custom Instructions" section of the additive-extraction
|
# Appended as the "## Custom Instructions" section of the additive-extraction
|
||||||
# prompt (mem0/configs/prompts.py::generate_additive_extraction_prompt). The
|
# prompt (mem0/configs/prompts.py::generate_additive_extraction_prompt). The
|
||||||
# default few-shot bias is consumer-organizer ("favourite movies", "SF restaurants"),
|
# default few-shot bias is consumer-organizer ("favourite movies", "SF restaurants"),
|
||||||
|
|
@ -176,8 +202,23 @@ class Mem0Manager:
|
||||||
timeout=60.0, # 60 second timeout for LLM calls
|
timeout=60.0, # 60 second timeout for LLM calls
|
||||||
max_retries=2, # Retry failed requests up to 2 times
|
max_retries=2, # Retry failed requests up to 2 times
|
||||||
)
|
)
|
||||||
|
# Async HTTP client used by openai_proxy_completion to forward
|
||||||
|
# /v1/chat/completions to the upstream endpoint verbatim (preserves
|
||||||
|
# tool_calls, reasoning_tokens, system_fingerprint, etc., which the
|
||||||
|
# openai-python SDK can drop). Closed in the FastAPI lifespan
|
||||||
|
# shutdown via aclose().
|
||||||
|
self.async_http = httpx.AsyncClient(
|
||||||
|
timeout=httpx.Timeout(120.0, connect=10.0)
|
||||||
|
)
|
||||||
logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")
|
logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
"""Release async resources. Called from the FastAPI lifespan shutdown."""
|
||||||
|
try:
|
||||||
|
await self.async_http.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("async_http aclose failed", error=str(e))
|
||||||
|
|
||||||
# Pure passthrough methods - no custom logic
|
# Pure passthrough methods - no custom logic
|
||||||
@db_retry
|
@db_retry
|
||||||
@timed("add_memories")
|
@timed("add_memories")
|
||||||
|
|
@ -450,7 +491,11 @@ class Mem0Manager:
|
||||||
response = self.openai_client.chat.completions.create(
|
response = self.openai_client.chat.completions.create(
|
||||||
model=settings.default_model, messages=messages
|
model=settings.default_model, messages=messages
|
||||||
)
|
)
|
||||||
assistant_response = response.choices[0].message.content
|
# Strip leading whitespace — reasoning models (minimax-m2) leak
|
||||||
|
# blank lines from their reasoning output into visible content.
|
||||||
|
# lstrip(), not strip(), preserves intentional trailing whitespace
|
||||||
|
# inside content (e.g., inside a code block).
|
||||||
|
assistant_response = (response.choices[0].message.content or "").lstrip()
|
||||||
llm_time = time.time() - llm_start_time
|
llm_time = time.time() - llm_start_time
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"LLM call completed",
|
"LLM call completed",
|
||||||
|
|
@ -504,6 +549,261 @@ class Mem0Manager:
|
||||||
"model_used": None,
|
"model_used": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# OpenAI-compatible /v1/chat/completions proxy
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Design: forward client requests to the upstream endpoint verbatim,
|
||||||
|
# injecting a "Relevant memories" system message ONLY when the last
|
||||||
|
# message is from the user AND no tool flow is in progress. The
|
||||||
|
# opinionated chat_with_memory above is NOT used on this path because:
|
||||||
|
# (a) clients pick their own model, tools, response_format etc.,
|
||||||
|
# (b) we must preserve every upstream field (tool_calls, reasoning_tokens,
|
||||||
|
# system_fingerprint, finish_reason), which the openai-python SDK and
|
||||||
|
# typed Pydantic responses strip silently,
|
||||||
|
# (c) streaming must be real (SSE pass-through), not post-hoc word-chunking.
|
||||||
|
#
|
||||||
|
# Tool-call safety contract (Memori issue #434 — codified here):
|
||||||
|
# - We only PREPEND a system message; never mutate, reorder, or delete
|
||||||
|
# existing messages.
|
||||||
|
# - We never touch tool messages or tool_call_id values.
|
||||||
|
# - We skip injection on tool-result follow-ups (last role != "user").
|
||||||
|
# - We always run the post-stream mem0.add even if upstream errors
|
||||||
|
# mid-flight (via try/finally).
|
||||||
|
async def openai_proxy_completion(
|
||||||
|
self,
|
||||||
|
request_kwargs: Dict[str, Any],
|
||||||
|
user_id: str,
|
||||||
|
) -> Union[Dict[str, Any], AsyncIterator[bytes]]:
|
||||||
|
"""Proxy a chat-completions request to the upstream LLM.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- dict (upstream JSON) when stream is falsy
|
||||||
|
- async iterator of SSE bytes when stream is truthy
|
||||||
|
"""
|
||||||
|
messages = request_kwargs.get("messages") or []
|
||||||
|
if not messages:
|
||||||
|
raise ValueError("messages array is empty")
|
||||||
|
|
||||||
|
# Injection rule: last role == 'user' AND no prior tool message.
|
||||||
|
last_msg = messages[-1] if isinstance(messages[-1], dict) else {}
|
||||||
|
last_role = last_msg.get("role")
|
||||||
|
has_tool_message = any(
|
||||||
|
isinstance(m, dict) and m.get("role") == "tool" for m in messages
|
||||||
|
)
|
||||||
|
inject_memory = last_role == "user" and not has_tool_message
|
||||||
|
|
||||||
|
if inject_memory:
|
||||||
|
last_user_text = _extract_text(last_msg.get("content"))
|
||||||
|
if last_user_text:
|
||||||
|
try:
|
||||||
|
search_result = self.memory.search(
|
||||||
|
query=last_user_text,
|
||||||
|
filters=_build_filters(user_id),
|
||||||
|
top_k=30,
|
||||||
|
threshold=0.3,
|
||||||
|
rerank=True,
|
||||||
|
)
|
||||||
|
relevant = (search_result.get("results") or [])[:10]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"memory search failed; proceeding without injection",
|
||||||
|
error=str(e),
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
relevant = []
|
||||||
|
if relevant:
|
||||||
|
memories_str = "\n".join(
|
||||||
|
f"- {m.get('memory', '')}" for m in relevant
|
||||||
|
)
|
||||||
|
mem_block = (
|
||||||
|
"Relevant memories about the user "
|
||||||
|
"(use when helpful, ignore otherwise):\n" + memories_str
|
||||||
|
)
|
||||||
|
# Merge into existing leading system message if present;
|
||||||
|
# otherwise prepend a new one. Either way we never mutate
|
||||||
|
# any other message in the list.
|
||||||
|
if (
|
||||||
|
messages
|
||||||
|
and isinstance(messages[0], dict)
|
||||||
|
and messages[0].get("role") == "system"
|
||||||
|
):
|
||||||
|
existing = _extract_text(messages[0].get("content")) or ""
|
||||||
|
merged = (existing + "\n\n" + mem_block) if existing else mem_block
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": merged},
|
||||||
|
*messages[1:],
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": mem_block},
|
||||||
|
*messages,
|
||||||
|
]
|
||||||
|
request_kwargs["messages"] = messages
|
||||||
|
logger.info(
|
||||||
|
"memory injected",
|
||||||
|
user_id=user_id,
|
||||||
|
memory_count=len(relevant),
|
||||||
|
)
|
||||||
|
|
||||||
|
url = settings.openai_base_url.rstrip("/") + "/v1/chat/completions"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {settings.openai_api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
is_stream = bool(request_kwargs.get("stream"))
|
||||||
|
|
||||||
|
if not is_stream:
|
||||||
|
resp = await self.async_http.post(
|
||||||
|
url, json=request_kwargs, headers=headers
|
||||||
|
)
|
||||||
|
if resp.status_code >= 400:
|
||||||
|
# Surface upstream error body to the client unchanged
|
||||||
|
try:
|
||||||
|
err_body = resp.json()
|
||||||
|
except Exception:
|
||||||
|
err_body = {"error": {"message": resp.text[:500]}}
|
||||||
|
raise httpx.HTTPStatusError(
|
||||||
|
f"Upstream {resp.status_code}", request=resp.request, response=resp
|
||||||
|
)
|
||||||
|
data = resp.json()
|
||||||
|
# Fire-and-forget mem0.add in the background. Skips tool-only
|
||||||
|
# responses inside _post_completion_add.
|
||||||
|
asyncio.create_task(self._post_completion_add(messages, data, user_id))
|
||||||
|
return data
|
||||||
|
|
||||||
|
return self._stream_proxy(url, request_kwargs, headers, messages, user_id)
|
||||||
|
|
||||||
|
async def _stream_proxy(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
request_kwargs: Dict[str, Any],
|
||||||
|
headers: Dict[str, str],
|
||||||
|
messages: List[Dict[str, Any]],
|
||||||
|
user_id: str,
|
||||||
|
) -> AsyncIterator[bytes]:
|
||||||
|
"""Stream upstream SSE bytes verbatim; accumulate content for post-stream add."""
|
||||||
|
accumulated_content: List[str] = []
|
||||||
|
saw_tool_calls = False
|
||||||
|
buffer = b""
|
||||||
|
try:
|
||||||
|
async with self.async_http.stream(
|
||||||
|
"POST", url, json=request_kwargs, headers=headers
|
||||||
|
) as upstream:
|
||||||
|
if upstream.status_code >= 400:
|
||||||
|
body_bytes = await upstream.aread()
|
||||||
|
err = {
|
||||||
|
"error": {
|
||||||
|
"message": (
|
||||||
|
f"Upstream {upstream.status_code}: "
|
||||||
|
f"{body_bytes.decode('utf-8', errors='replace')[:500]}"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(err)}\n\n".encode("utf-8")
|
||||||
|
yield b"data: [DONE]\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
async for chunk in upstream.aiter_bytes():
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
|
# Forward bytes verbatim — preserves the exact SSE wire format
|
||||||
|
yield chunk
|
||||||
|
# Side-channel parse for content/tool_calls accumulation.
|
||||||
|
# Events are \n\n-separated; data lines start with "data: ".
|
||||||
|
buffer += chunk
|
||||||
|
while b"\n\n" in buffer:
|
||||||
|
event_bytes, buffer = buffer.split(b"\n\n", 1)
|
||||||
|
for raw_line in event_bytes.split(b"\n"):
|
||||||
|
if not raw_line.startswith(b"data: "):
|
||||||
|
continue
|
||||||
|
payload = raw_line[6:].strip()
|
||||||
|
if not payload or payload == b"[DONE]":
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
obj = json.loads(payload)
|
||||||
|
delta = (
|
||||||
|
(obj.get("choices") or [{}])[0].get("delta") or {}
|
||||||
|
)
|
||||||
|
content_piece = delta.get("content")
|
||||||
|
if isinstance(content_piece, str):
|
||||||
|
accumulated_content.append(content_piece)
|
||||||
|
if delta.get("tool_calls"):
|
||||||
|
saw_tool_calls = True
|
||||||
|
except (json.JSONDecodeError, KeyError, IndexError, TypeError):
|
||||||
|
pass
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.warning("upstream stream error", error=str(e), user_id=user_id)
|
||||||
|
err = {"error": {"message": f"Upstream stream error: {e}"}}
|
||||||
|
yield f"data: {json.dumps(err)}\n\n".encode("utf-8")
|
||||||
|
yield b"data: [DONE]\n\n"
|
||||||
|
finally:
|
||||||
|
# Per Memori #434: post-stream mem0.add must run even on mid-stream
|
||||||
|
# error so partial content is captured. Skipped for tool-only or
|
||||||
|
# empty responses.
|
||||||
|
full = "".join(accumulated_content).lstrip()
|
||||||
|
if full and not saw_tool_calls:
|
||||||
|
last_user_text = (
|
||||||
|
_extract_text(messages[-1].get("content"))
|
||||||
|
if messages
|
||||||
|
and isinstance(messages[-1], dict)
|
||||||
|
and messages[-1].get("role") == "user"
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if last_user_text:
|
||||||
|
try:
|
||||||
|
# Synchronous mem0.add in a thread to avoid blocking
|
||||||
|
# the response loop after StreamingResponse closes.
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self.memory.add,
|
||||||
|
[
|
||||||
|
{"role": "user", "content": last_user_text},
|
||||||
|
{"role": "assistant", "content": full},
|
||||||
|
],
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"post-stream mem0.add failed",
|
||||||
|
error=str(e),
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _post_completion_add(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, Any]],
|
||||||
|
response_data: Dict[str, Any],
|
||||||
|
user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Background mem0.add after a non-stream completion."""
|
||||||
|
try:
|
||||||
|
choice = (response_data.get("choices") or [{}])[0]
|
||||||
|
msg = choice.get("message") or {}
|
||||||
|
content = msg.get("content")
|
||||||
|
if not content or not isinstance(content, str):
|
||||||
|
return # tool-only or non-text completion — skip
|
||||||
|
content = content.lstrip()
|
||||||
|
last_user_text = (
|
||||||
|
_extract_text(messages[-1].get("content"))
|
||||||
|
if messages
|
||||||
|
and isinstance(messages[-1], dict)
|
||||||
|
and messages[-1].get("role") == "user"
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if not last_user_text:
|
||||||
|
return
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self.memory.add,
|
||||||
|
[
|
||||||
|
{"role": "user", "content": last_user_text},
|
||||||
|
{"role": "assistant", "content": content},
|
||||||
|
],
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"post-completion mem0.add failed", error=str(e), user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
async def health_check(self) -> Dict[str, str]:
|
async def health_check(self) -> Dict[str, str]:
|
||||||
"""Basic health check - just connectivity."""
|
"""Basic health check - just connectivity."""
|
||||||
status = {}
|
status = {}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""Ultra-minimal Pydantic models for pure Mem0 API."""
|
"""Ultra-minimal Pydantic models for pure Mem0 API."""
|
||||||
|
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any, Union
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -230,82 +230,56 @@ class UserStatsResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
# OpenAI-Compatible API Models
|
# OpenAI-Compatible API Models
|
||||||
|
#
|
||||||
|
# The /v1/chat/completions handler is a pass-through proxy: requests are
|
||||||
class OpenAIMessage(BaseModel):
|
# forwarded to the upstream LLM and the upstream response is relayed verbatim
|
||||||
"""OpenAI message format."""
|
# (dict for non-stream, raw SSE bytes for stream). Hence only the REQUEST
|
||||||
|
# model lives here; the response is never re-typed (typed Pydantic response
|
||||||
role: str = Field(..., description="Message role (system, user, assistant)")
|
# models silently drop unknown fields like tool_calls / refusal /
|
||||||
content: str = Field(..., description="Message content")
|
# reasoning_tokens — see the audit in the plan file).
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChatCompletionRequest(BaseModel):
|
class OpenAIChatCompletionRequest(BaseModel):
|
||||||
"""OpenAI chat completion request format."""
|
"""OpenAI chat completion request — permissive schema, forwarded as-is.
|
||||||
|
|
||||||
model: str = Field(..., description="Model to use (will use configured default)")
|
Only `model` and `messages` are required. All other standard OpenAI
|
||||||
messages: List[Dict[str, str]] = Field(..., description="List of messages")
|
parameters are typed (for client IDE/docs benefit) but optional. Unknown
|
||||||
temperature: Optional[float] = Field(0.7, description="Sampling temperature")
|
fields are accepted via `extra="allow"` and forwarded to upstream, so
|
||||||
max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
|
new OpenAI parameters don't require a code change here.
|
||||||
stream: Optional[bool] = Field(False, description="Whether to stream responses")
|
"""
|
||||||
top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
|
|
||||||
n: Optional[int] = Field(1, description="Number of completions to generate")
|
model_config = ConfigDict(extra="allow")
|
||||||
stop: Optional[List[str]] = Field(None, description="Stop sequences")
|
|
||||||
presence_penalty: Optional[float] = Field(0, description="Presence penalty")
|
model: str = Field(..., description="Model to use (forwarded to upstream)")
|
||||||
frequency_penalty: Optional[float] = Field(0, description="Frequency penalty")
|
messages: List[Dict[str, Any]] = Field(
|
||||||
user: Optional[str] = Field(
|
..., description="Messages array (multi-part content supported)"
|
||||||
None, description="User identifier (ignored, uses API key)"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Common params (typed for IDE/docs; all optional)
|
||||||
class OpenAIUsage(BaseModel):
|
temperature: Optional[float] = None
|
||||||
"""Token usage information."""
|
top_p: Optional[float] = None
|
||||||
|
n: Optional[int] = None
|
||||||
prompt_tokens: int = Field(..., description="Tokens in the prompt")
|
stream: Optional[bool] = None
|
||||||
completion_tokens: int = Field(..., description="Tokens in the completion")
|
stream_options: Optional[Dict[str, Any]] = None
|
||||||
total_tokens: int = Field(..., description="Total tokens used")
|
stop: Optional[Union[str, List[str]]] = None
|
||||||
|
max_tokens: Optional[int] = None # deprecated; still accepted
|
||||||
|
max_completion_tokens: Optional[int] = None # replaces max_tokens
|
||||||
class OpenAIChoiceMessage(BaseModel):
|
presence_penalty: Optional[float] = None
|
||||||
"""Message in a choice."""
|
frequency_penalty: Optional[float] = None
|
||||||
|
seed: Optional[int] = None
|
||||||
role: str = Field(..., description="Role of the message")
|
logprobs: Optional[bool] = None
|
||||||
content: str = Field(..., description="Content of the message")
|
top_logprobs: Optional[int] = None
|
||||||
|
response_format: Optional[Dict[str, Any]] = None
|
||||||
|
tools: Optional[List[Dict[str, Any]]] = None
|
||||||
class OpenAIChoice(BaseModel):
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
|
||||||
"""Individual completion choice."""
|
parallel_tool_calls: Optional[bool] = None
|
||||||
|
reasoning_effort: Optional[str] = None # o-series / reasoning models
|
||||||
index: int = Field(..., description="Choice index")
|
modalities: Optional[List[str]] = None
|
||||||
message: OpenAIChoiceMessage = Field(..., description="Message content")
|
audio: Optional[Dict[str, Any]] = None
|
||||||
finish_reason: str = Field(..., description="Reason for completion finish")
|
prediction: Optional[Dict[str, Any]] = None
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
store: Optional[bool] = None
|
||||||
class OpenAIChatCompletionResponse(BaseModel):
|
service_tier: Optional[str] = None
|
||||||
"""OpenAI chat completion response format."""
|
logit_bias: Optional[Dict[str, float]] = None
|
||||||
|
# `user` is ignored — the authenticated user is derived from the API key.
|
||||||
id: str = Field(..., description="Unique completion ID")
|
user: Optional[str] = None
|
||||||
object: str = Field(default="chat.completion", description="Object type")
|
|
||||||
created: int = Field(..., description="Unix timestamp of creation")
|
|
||||||
model: str = Field(..., description="Model used for completion")
|
|
||||||
choices: List[OpenAIChoice] = Field(..., description="List of completion choices")
|
|
||||||
usage: Optional[OpenAIUsage] = Field(None, description="Token usage information")
|
|
||||||
|
|
||||||
|
|
||||||
# Streaming-specific models
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIStreamDelta(BaseModel):
|
|
||||||
"""Delta content in a streaming chunk."""
|
|
||||||
|
|
||||||
role: Optional[str] = Field(None, description="Role (only in first chunk)")
|
|
||||||
content: Optional[str] = Field(None, description="Incremental content")
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIStreamChoice(BaseModel):
|
|
||||||
"""Individual streaming choice."""
|
|
||||||
|
|
||||||
index: int = Field(..., description="Choice index")
|
|
||||||
delta: OpenAIStreamDelta = Field(..., description="Delta content")
|
|
||||||
finish_reason: Optional[str] = Field(
|
|
||||||
None, description="Reason for completion finish"
|
|
||||||
)
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue