feat: rewrite /v1/chat/completions as a real OpenAI-compat proxy
Part 1: strip leading newlines from chat_with_memory's LLM response (minimax-m2 reasoning output leaks blank lines; .lstrip() at the source covers /chat, /v1/chat/completions, and the MCP chat tool). Part 2: replace the /v1/chat/completions handler with an httpx-based pass-through proxy that preserves every upstream field (tool_calls, reasoning_tokens, system_fingerprint, finish_reason, etc.) and supports end-to-end MCP-style tool calling. What changed: - models.py: OpenAIChatCompletionRequest is now permissive — typed for the common fields (tools, tool_choice, parallel_tool_calls, response_format, max_completion_tokens, seed, stream_options, reasoning_effort, modalities, etc.) and extra='allow' for forward-compat. The typed response models (OpenAIChatCompletionResponse and friends) are deleted — the handler returns upstream's JSON dict directly so unknown fields aren't silently dropped. - mem0_manager.py: adds httpx.AsyncClient + an openai_proxy_completion() method that injects a "Relevant memories" system message only when the last role is 'user' AND no tool flow is in progress, then forwards to the upstream LLM. Non-stream returns upstream JSON; stream returns an async iterator that yields raw upstream SSE bytes verbatim while side-channel-parsing for the post-stream mem0.add. Codifies the Memori #434 lessons: never mutates existing non-system messages (prepends a system message, or appends the memory block to an existing leading system message), never touches tool_call_id, runs post-add even on mid-stream error via try/finally. - main.py: handler is now ~50 lines — model_dump(exclude_unset) the request, hand off to openai_proxy_completion, return dict OR wrap in StreamingResponse. response_model=None so FastAPI doesn't validate. Deleted stream_openai_response (post-hoc word-chunking is gone). Lifespan shutdown closes mem0_manager.async_http. Research confirmed mem0 itself does not ship an HTTP /v1/chat/completions (only the in-process mem0.proxy.main.Mem0 SDK pattern), so we replicate the pattern without adding a litellm dependency. 
SSE/tool_calls patterns are modeled after microsoft/agent-lightning's llm_proxy. Verified locally: ast.parse OK on all three files. End-to-end smoke tests will run on beast.
This commit is contained in:
parent
e99b382b16
commit
e5a4d1c7c2
3 changed files with 395 additions and 194 deletions
161
backend/main.py
161
backend/main.py
|
|
@ -11,7 +11,6 @@ from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security,
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
import structlog
|
||||
import asyncio
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
|
@ -29,6 +28,8 @@ def get_rate_limit_key(request: Request) -> str:
|
|||
|
||||
|
||||
limiter = Limiter(key_func=get_rate_limit_key)
|
||||
import httpx
|
||||
|
||||
from models import (
|
||||
ChatRequest,
|
||||
MemoryAddRequest,
|
||||
|
|
@ -43,10 +44,6 @@ from models import (
|
|||
GlobalStatsResponse,
|
||||
UserStatsResponse,
|
||||
OpenAIChatCompletionRequest,
|
||||
OpenAIChatCompletionResponse,
|
||||
OpenAIChoice,
|
||||
OpenAIChoiceMessage,
|
||||
OpenAIUsage,
|
||||
)
|
||||
from mem0_manager import mem0_manager
|
||||
from auth import get_current_user, get_current_user_openai, auth_service
|
||||
|
|
@ -109,6 +106,12 @@ async def lifespan(app: FastAPI):
|
|||
except Exception as e:
|
||||
logger.error(f"Error stopping MCP session manager: {e}")
|
||||
|
||||
# Close the async HTTP client used by the /v1/chat/completions proxy.
|
||||
try:
|
||||
await mem0_manager.aclose()
|
||||
except Exception as e:
|
||||
logger.warning(f"mem0_manager aclose failed: {e}")
|
||||
|
||||
logger.info("Shutting down Mem0 Interface POC")
|
||||
|
||||
|
||||
|
|
@ -336,135 +339,59 @@ async def chat_with_memory(
|
|||
)
|
||||
|
||||
|
||||
async def stream_openai_response(
    completion_id: str, model: str, content: str, created: int
):
    """Replay an already-complete response as an OpenAI-style SSE stream.

    Yields ``data: <json>\\n\\n`` strings in this order: one chunk announcing
    the assistant role, one chunk per group of three tokens of ``content``,
    one chunk with ``finish_reason="stop"``, and finally ``data: [DONE]``.

    Tokens keep their original whitespace, so concatenating every
    ``delta.content`` reproduces ``content`` exactly (newlines and
    indentation included). Splitting on whitespace and re-joining with single
    spaces would flatten multi-line answers and code blocks.

    Args:
        completion_id: Value emitted as ``id`` in every chunk.
        model: Value emitted as ``model`` in every chunk.
        content: Full assistant text to replay; ``None`` is treated as empty.
        created: Unix timestamp emitted as ``created`` in every chunk.
    """
    import re

    def _event(delta, finish_reason=None):
        # Every chunk shares the same envelope; only delta/finish_reason vary.
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [
                {"index": 0, "delta": delta, "finish_reason": finish_reason}
            ],
        }
        return f"data: {json.dumps(payload)}\n\n"

    # First chunk with role
    yield _event({"role": "assistant", "content": ""})

    # Each token is a word together with the whitespace that precedes it; the
    # second alternative captures whitespace left over at the very end.
    tokens = re.findall(r"\s*\S+|\s+$", content or "")
    chunk_size = 3  # 3 words at a time for a smooth effect

    for start in range(0, len(tokens), chunk_size):
        yield _event({"content": "".join(tokens[start : start + chunk_size])})
        await asyncio.sleep(0.05)

    # Final chunk
    yield _event({}, "stop")
    yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
@app.post("/chat/completions")
|
||||
@app.post("/v1/chat/completions", response_model=None)
|
||||
@app.post("/chat/completions", response_model=None)
|
||||
@limiter.limit("30/minute")
|
||||
async def openai_chat_completions(
|
||||
request: Request,
|
||||
completion_request: OpenAIChatCompletionRequest,
|
||||
authenticated_user: str = Depends(get_current_user_openai),
|
||||
):
|
||||
"""OpenAI-compatible chat completions endpoint with mem0 memory integration."""
|
||||
"""OpenAI-compatible chat completions — pass-through proxy with memory injection.
|
||||
|
||||
Forwards the request to the upstream LLM verbatim (preserving tool_calls,
|
||||
reasoning_tokens, system_fingerprint, finish_reason, etc.) and optionally
|
||||
prepends a system message with relevant memories when the last role is
|
||||
`user` and no tool flow is in progress. See
|
||||
`mem0_manager.openai_proxy_completion` for the full injection rule and
|
||||
tool-call safety contract.
|
||||
"""
|
||||
try:
|
||||
import uuid
|
||||
request_kwargs = completion_request.model_dump(exclude_unset=True)
|
||||
# `user` is the OpenAI client-supplied identifier; we always derive the
|
||||
# real user from the API key. Strip it before forwarding so it can't
|
||||
# confuse upstream user-tracking.
|
||||
request_kwargs.pop("user", None)
|
||||
|
||||
user_id = authenticated_user
|
||||
logger.info(
|
||||
f"OpenAI chat completion for user: {user_id} (streaming={completion_request.stream})"
|
||||
"openai_chat_completions",
|
||||
user_id=authenticated_user,
|
||||
stream=bool(request_kwargs.get("stream")),
|
||||
has_tools=bool(request_kwargs.get("tools")),
|
||||
model=request_kwargs.get("model"),
|
||||
)
|
||||
|
||||
# Extract last user message
|
||||
user_messages = [
|
||||
m for m in completion_request.messages if m.get("role") == "user"
|
||||
]
|
||||
if not user_messages:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No user messages provided. Include at least one message with role='user'.",
|
||||
result = await mem0_manager.openai_proxy_completion(
|
||||
request_kwargs=request_kwargs,
|
||||
user_id=authenticated_user,
|
||||
)
|
||||
|
||||
last_message = user_messages[-1].get("content", "")
|
||||
context = (
|
||||
completion_request.messages[:-1]
|
||||
if len(completion_request.messages) > 1
|
||||
else None
|
||||
)
|
||||
|
||||
# Call chat_with_memory
|
||||
result = await mem0_manager.chat_with_memory(
|
||||
message=last_message,
|
||||
user_id=user_id,
|
||||
context=context,
|
||||
)
|
||||
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
|
||||
created_time = int(time.time())
|
||||
assistant_content = result.get("response", "")
|
||||
|
||||
if completion_request.stream:
|
||||
if request_kwargs.get("stream"):
|
||||
return StreamingResponse(
|
||||
stream_openai_response(
|
||||
completion_id=completion_id,
|
||||
model=settings.default_model,
|
||||
content=assistant_content,
|
||||
created=created_time,
|
||||
),
|
||||
result,
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
else:
|
||||
return OpenAIChatCompletionResponse(
|
||||
id=completion_id,
|
||||
object="chat.completion",
|
||||
created=created_time,
|
||||
model=settings.default_model,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
index=0,
|
||||
message=OpenAIChoiceMessage(
|
||||
role="assistant", content=assistant_content
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=OpenAIUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
|
||||
)
|
||||
|
||||
return result
|
||||
except httpx.HTTPStatusError as e:
|
||||
# Forward upstream error body to the client with its original status
|
||||
try:
|
||||
detail = e.response.json()
|
||||
except Exception:
|
||||
detail = {"error": {"message": e.response.text[:500]}}
|
||||
raise HTTPException(status_code=e.response.status_code, detail=detail)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
"""Ultra-minimal Mem0 Manager - Pure Mem0 + Custom OpenAI Endpoint Only."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Dict, List, Optional, Any, AsyncIterator, Union
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
from mem0 import Memory
|
||||
from openai import OpenAI
|
||||
from tenacity import (
|
||||
|
|
@ -75,6 +79,28 @@ def _build_filters(
|
|||
return merged
|
||||
|
||||
|
||||
def _extract_text(content: Any) -> str:
|
||||
"""Extract text from an OpenAI message `content` field.
|
||||
|
||||
`content` can be a plain string OR a list of multi-part objects
|
||||
(e.g., [{"type": "text", "text": "..."}, {"type": "image_url", ...}]).
|
||||
Returns concatenated text parts, or "" if no text is present.
|
||||
"""
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for p in content:
|
||||
if isinstance(p, dict) and p.get("type") == "text":
|
||||
t = p.get("text")
|
||||
if isinstance(t, str):
|
||||
parts.append(t)
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
|
||||
# Appended as the "## Custom Instructions" section of the additive-extraction
|
||||
# prompt (mem0/configs/prompts.py::generate_additive_extraction_prompt). The
|
||||
# default few-shot bias is consumer-organizer ("favourite movies", "SF restaurants"),
|
||||
|
|
@ -176,8 +202,23 @@ class Mem0Manager:
|
|||
timeout=60.0, # 60 second timeout for LLM calls
|
||||
max_retries=2, # Retry failed requests up to 2 times
|
||||
)
|
||||
# Async HTTP client used by openai_proxy_completion to forward
|
||||
# /v1/chat/completions to the upstream endpoint verbatim (preserves
|
||||
# tool_calls, reasoning_tokens, system_fingerprint, etc., which the
|
||||
# openai-python SDK can drop). Closed in the FastAPI lifespan
|
||||
# shutdown via aclose().
|
||||
self.async_http = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(120.0, connect=10.0)
|
||||
)
|
||||
logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")
|
||||
|
||||
async def aclose(self) -> None:
    """Close the async HTTP client held in ``self.async_http``.

    Intended to be awaited from the FastAPI lifespan shutdown. Best-effort:
    any exception raised while closing is logged as a warning and swallowed.
    """
    try:
        await self.async_http.aclose()
    except Exception as exc:
        # Closing is the last thing we do with the client; report the failure
        # and carry on.
        logger.warning("async_http aclose failed", error=str(exc))
|
||||
|
||||
# Pure passthrough methods - no custom logic
|
||||
@db_retry
|
||||
@timed("add_memories")
|
||||
|
|
@ -450,7 +491,11 @@ class Mem0Manager:
|
|||
response = self.openai_client.chat.completions.create(
|
||||
model=settings.default_model, messages=messages
|
||||
)
|
||||
assistant_response = response.choices[0].message.content
|
||||
# Strip leading whitespace — reasoning models (minimax-m2) leak
|
||||
# blank lines from their reasoning output into visible content.
|
||||
# lstrip(), not strip(), preserves intentional trailing whitespace
|
||||
# inside content (e.g., inside a code block).
|
||||
assistant_response = (response.choices[0].message.content or "").lstrip()
|
||||
llm_time = time.time() - llm_start_time
|
||||
logger.debug(
|
||||
"LLM call completed",
|
||||
|
|
@ -504,6 +549,261 @@ class Mem0Manager:
|
|||
"model_used": None,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# OpenAI-compatible /v1/chat/completions proxy
|
||||
# ------------------------------------------------------------------
|
||||
# Design: forward client requests to the upstream endpoint verbatim,
|
||||
# injecting a "Relevant memories" system message ONLY when the last
|
||||
# message is from the user AND no tool flow is in progress. The
|
||||
# opinionated chat_with_memory above is NOT used on this path because:
|
||||
# (a) clients pick their own model, tools, response_format etc.,
|
||||
# (b) we must preserve every upstream field (tool_calls, reasoning_tokens,
|
||||
# system_fingerprint, finish_reason), which the openai-python SDK and
|
||||
# typed Pydantic responses strip silently,
|
||||
# (c) streaming must be real (SSE pass-through), not post-hoc word-chunking.
|
||||
#
|
||||
# Tool-call safety contract (Memori issue #434 — codified here):
|
||||
# - We only add a system message (prepended, or merged into an existing
|
||||
# leading system message); other messages are never mutated, reordered, or deleted.
|
||||
# - We never touch tool messages or tool_call_id values.
|
||||
# - We skip injection on tool-result follow-ups (last role != "user").
|
||||
# - We always run the post-stream mem0.add even if upstream errors
|
||||
# mid-flight (via try/finally).
|
||||
async def openai_proxy_completion(
    self,
    request_kwargs: Dict[str, Any],
    user_id: str,
) -> Union[Dict[str, Any], AsyncIterator[bytes]]:
    """Proxy a chat-completions request to the upstream LLM.

    When the last message has role ``user`` and no message has role ``tool``,
    the user's text is used to search memory. If results come back, a
    "Relevant memories" block is added: appended to an existing leading
    system message, or prepended as a new one. All other messages are
    forwarded untouched. The caller's ``request_kwargs`` dict is not
    modified; a shallow copy carries the new ``messages`` list.

    Args:
        request_kwargs: The client request body (must contain ``messages``).
        user_id: Authenticated user whose memories are searched and updated.

    Returns:
        - dict (upstream JSON) when ``stream`` is falsy
        - async iterator of SSE bytes when ``stream`` is truthy

    Raises:
        ValueError: ``messages`` is missing or empty.
        httpx.HTTPStatusError: upstream answered a non-stream request with a
            status >= 400 (the upstream response is attached).
        RuntimeError: upstream answered a non-stream request successfully but
            the body was not valid JSON.
    """
    messages = request_kwargs.get("messages") or []
    if not messages:
        raise ValueError("messages array is empty")

    # Injection rule: last role == 'user' AND no prior tool message.
    last_msg = messages[-1] if isinstance(messages[-1], dict) else {}
    has_tool_message = any(
        isinstance(m, dict) and m.get("role") == "tool" for m in messages
    )
    inject_memory = last_msg.get("role") == "user" and not has_tool_message

    relevant: List[Dict[str, Any]] = []
    last_user_text = _extract_text(last_msg.get("content")) if inject_memory else ""
    if last_user_text:
        try:
            # memory.search is a blocking call; run it in a worker thread so
            # the event loop keeps serving other requests meanwhile.
            search_result = await asyncio.to_thread(
                self.memory.search,
                query=last_user_text,
                filters=_build_filters(user_id),
                top_k=30,
                threshold=0.3,
                rerank=True,
            )
            relevant = (search_result.get("results") or [])[:10]
        except Exception as e:
            # Memory is an enhancement, never a prerequisite for the proxy.
            logger.warning(
                "memory search failed; proceeding without injection",
                error=str(e),
                user_id=user_id,
            )
            relevant = []

    if relevant:
        memories_str = "\n".join(f"- {m.get('memory', '')}" for m in relevant)
        mem_block = (
            "Relevant memories about the user "
            "(use when helpful, ignore otherwise):\n" + memories_str
        )
        first = messages[0]
        if isinstance(first, dict) and first.get("role") == "system":
            # Merge into the existing leading system message, keeping any
            # extra keys it carries (e.g. `name`).
            existing = _extract_text(first.get("content"))
            merged = f"{existing}\n\n{mem_block}" if existing else mem_block
            messages = [{**first, "content": merged}, *messages[1:]]
        else:
            messages = [{"role": "system", "content": mem_block}, *messages]
        # Copy instead of assigning into the caller's dict.
        request_kwargs = {**request_kwargs, "messages": messages}
        logger.info(
            "memory injected",
            user_id=user_id,
            memory_count=len(relevant),
        )

    # Accept a base URL with or without a trailing "/v1" so the path segment
    # is never duplicated.
    base_url = settings.openai_base_url.rstrip("/")
    if not base_url.endswith("/v1"):
        base_url += "/v1"
    url = base_url + "/chat/completions"
    headers = {
        "Authorization": f"Bearer {settings.openai_api_key}",
        "Content-Type": "application/json",
    }

    if request_kwargs.get("stream"):
        return self._stream_proxy(url, request_kwargs, headers, messages, user_id)

    resp = await self.async_http.post(url, json=request_kwargs, headers=headers)
    if resp.status_code >= 400:
        # The response travels with the exception so the caller can relay the
        # upstream status and body.
        raise httpx.HTTPStatusError(
            f"Upstream {resp.status_code}", request=resp.request, response=resp
        )
    try:
        data = resp.json()
    except ValueError as exc:
        # JSONDecodeError subclasses ValueError; re-raise as a different type
        # so an upstream fault is not confused with the "bad request"
        # ValueError raised above.
        raise RuntimeError("Upstream returned a non-JSON response body") from exc

    # Fire-and-forget mem0.add in the background. Skips tool-only responses
    # inside _post_completion_add. The event loop only keeps weak references
    # to tasks, so hold a strong one until the task finishes.
    task = asyncio.create_task(self._post_completion_add(messages, data, user_id))
    pending = getattr(self, "_background_tasks", None)
    if pending is None:
        pending = self._background_tasks = set()
    pending.add(task)
    task.add_done_callback(pending.discard)
    return data
|
||||
|
||||
async def _stream_proxy(
|
||||
self,
|
||||
url: str,
|
||||
request_kwargs: Dict[str, Any],
|
||||
headers: Dict[str, str],
|
||||
messages: List[Dict[str, Any]],
|
||||
user_id: str,
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""Stream upstream SSE bytes verbatim; accumulate content for post-stream add."""
|
||||
accumulated_content: List[str] = []
|
||||
saw_tool_calls = False
|
||||
buffer = b""
|
||||
try:
|
||||
async with self.async_http.stream(
|
||||
"POST", url, json=request_kwargs, headers=headers
|
||||
) as upstream:
|
||||
if upstream.status_code >= 400:
|
||||
body_bytes = await upstream.aread()
|
||||
err = {
|
||||
"error": {
|
||||
"message": (
|
||||
f"Upstream {upstream.status_code}: "
|
||||
f"{body_bytes.decode('utf-8', errors='replace')[:500]}"
|
||||
)
|
||||
}
|
||||
}
|
||||
yield f"data: {json.dumps(err)}\n\n".encode("utf-8")
|
||||
yield b"data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
async for chunk in upstream.aiter_bytes():
|
||||
if not chunk:
|
||||
continue
|
||||
# Forward bytes verbatim — preserves the exact SSE wire format
|
||||
yield chunk
|
||||
# Side-channel parse for content/tool_calls accumulation.
|
||||
# Events are \n\n-separated; data lines start with "data: ".
|
||||
buffer += chunk
|
||||
while b"\n\n" in buffer:
|
||||
event_bytes, buffer = buffer.split(b"\n\n", 1)
|
||||
for raw_line in event_bytes.split(b"\n"):
|
||||
if not raw_line.startswith(b"data: "):
|
||||
continue
|
||||
payload = raw_line[6:].strip()
|
||||
if not payload or payload == b"[DONE]":
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(payload)
|
||||
delta = (
|
||||
(obj.get("choices") or [{}])[0].get("delta") or {}
|
||||
)
|
||||
content_piece = delta.get("content")
|
||||
if isinstance(content_piece, str):
|
||||
accumulated_content.append(content_piece)
|
||||
if delta.get("tool_calls"):
|
||||
saw_tool_calls = True
|
||||
except (json.JSONDecodeError, KeyError, IndexError, TypeError):
|
||||
pass
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning("upstream stream error", error=str(e), user_id=user_id)
|
||||
err = {"error": {"message": f"Upstream stream error: {e}"}}
|
||||
yield f"data: {json.dumps(err)}\n\n".encode("utf-8")
|
||||
yield b"data: [DONE]\n\n"
|
||||
finally:
|
||||
# Per Memori #434: post-stream mem0.add must run even on mid-stream
|
||||
# error so partial content is captured. Skipped for tool-only or
|
||||
# empty responses.
|
||||
full = "".join(accumulated_content).lstrip()
|
||||
if full and not saw_tool_calls:
|
||||
last_user_text = (
|
||||
_extract_text(messages[-1].get("content"))
|
||||
if messages
|
||||
and isinstance(messages[-1], dict)
|
||||
and messages[-1].get("role") == "user"
|
||||
else None
|
||||
)
|
||||
if last_user_text:
|
||||
try:
|
||||
# Synchronous mem0.add in a thread to avoid blocking
|
||||
# the response loop after StreamingResponse closes.
|
||||
await asyncio.to_thread(
|
||||
self.memory.add,
|
||||
[
|
||||
{"role": "user", "content": last_user_text},
|
||||
{"role": "assistant", "content": full},
|
||||
],
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"post-stream mem0.add failed",
|
||||
error=str(e),
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
async def _post_completion_add(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
response_data: Dict[str, Any],
|
||||
user_id: str,
|
||||
) -> None:
|
||||
"""Background mem0.add after a non-stream completion."""
|
||||
try:
|
||||
choice = (response_data.get("choices") or [{}])[0]
|
||||
msg = choice.get("message") or {}
|
||||
content = msg.get("content")
|
||||
if not content or not isinstance(content, str):
|
||||
return # tool-only or non-text completion — skip
|
||||
content = content.lstrip()
|
||||
last_user_text = (
|
||||
_extract_text(messages[-1].get("content"))
|
||||
if messages
|
||||
and isinstance(messages[-1], dict)
|
||||
and messages[-1].get("role") == "user"
|
||||
else None
|
||||
)
|
||||
if not last_user_text:
|
||||
return
|
||||
await asyncio.to_thread(
|
||||
self.memory.add,
|
||||
[
|
||||
{"role": "user", "content": last_user_text},
|
||||
{"role": "assistant", "content": content},
|
||||
],
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"post-completion mem0.add failed", error=str(e), user_id=user_id
|
||||
)
|
||||
|
||||
async def health_check(self) -> Dict[str, str]:
|
||||
"""Basic health check - just connectivity."""
|
||||
status = {}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Ultra-minimal Pydantic models for pure Mem0 API."""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
import re
|
||||
|
||||
|
||||
|
|
@ -230,82 +230,56 @@ class UserStatsResponse(BaseModel):
|
|||
|
||||
|
||||
# OpenAI-Compatible API Models
|
||||
|
||||
|
||||
class OpenAIMessage(BaseModel):
|
||||
"""OpenAI message format."""
|
||||
|
||||
role: str = Field(..., description="Message role (system, user, assistant)")
|
||||
content: str = Field(..., description="Message content")
|
||||
#
|
||||
# The /v1/chat/completions handler is a pass-through proxy: requests are
|
||||
# forwarded to the upstream LLM and the upstream response is relayed verbatim
|
||||
# (dict for non-stream, raw SSE bytes for stream). Hence only the REQUEST
|
||||
# model lives here; the response is never re-typed (typed Pydantic response
|
||||
# models silently drop unknown fields like tool_calls / refusal /
|
||||
# reasoning_tokens — see the audit in the plan file).
|
||||
|
||||
|
||||
class OpenAIChatCompletionRequest(BaseModel):
|
||||
"""OpenAI chat completion request format."""
|
||||
"""OpenAI chat completion request — permissive schema, forwarded as-is.
|
||||
|
||||
model: str = Field(..., description="Model to use (will use configured default)")
|
||||
messages: List[Dict[str, str]] = Field(..., description="List of messages")
|
||||
temperature: Optional[float] = Field(0.7, description="Sampling temperature")
|
||||
max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
|
||||
stream: Optional[bool] = Field(False, description="Whether to stream responses")
|
||||
top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
|
||||
n: Optional[int] = Field(1, description="Number of completions to generate")
|
||||
stop: Optional[List[str]] = Field(None, description="Stop sequences")
|
||||
presence_penalty: Optional[float] = Field(0, description="Presence penalty")
|
||||
frequency_penalty: Optional[float] = Field(0, description="Frequency penalty")
|
||||
user: Optional[str] = Field(
|
||||
None, description="User identifier (ignored, uses API key)"
|
||||
Only `model` and `messages` are required. All other standard OpenAI
|
||||
parameters are typed (for client IDE/docs benefit) but optional. Unknown
|
||||
fields are accepted via `extra="allow"` and forwarded to upstream, so
|
||||
new OpenAI parameters don't require a code change here.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
model: str = Field(..., description="Model to use (forwarded to upstream)")
|
||||
messages: List[Dict[str, Any]] = Field(
|
||||
..., description="Messages array (multi-part content supported)"
|
||||
)
|
||||
|
||||
|
||||
class OpenAIUsage(BaseModel):
|
||||
"""Token usage information."""
|
||||
|
||||
prompt_tokens: int = Field(..., description="Tokens in the prompt")
|
||||
completion_tokens: int = Field(..., description="Tokens in the completion")
|
||||
total_tokens: int = Field(..., description="Total tokens used")
|
||||
|
||||
|
||||
class OpenAIChoiceMessage(BaseModel):
|
||||
"""Message in a choice."""
|
||||
|
||||
role: str = Field(..., description="Role of the message")
|
||||
content: str = Field(..., description="Content of the message")
|
||||
|
||||
|
||||
class OpenAIChoice(BaseModel):
|
||||
"""Individual completion choice."""
|
||||
|
||||
index: int = Field(..., description="Choice index")
|
||||
message: OpenAIChoiceMessage = Field(..., description="Message content")
|
||||
finish_reason: str = Field(..., description="Reason for completion finish")
|
||||
|
||||
|
||||
class OpenAIChatCompletionResponse(BaseModel):
|
||||
"""OpenAI chat completion response format."""
|
||||
|
||||
id: str = Field(..., description="Unique completion ID")
|
||||
object: str = Field(default="chat.completion", description="Object type")
|
||||
created: int = Field(..., description="Unix timestamp of creation")
|
||||
model: str = Field(..., description="Model used for completion")
|
||||
choices: List[OpenAIChoice] = Field(..., description="List of completion choices")
|
||||
usage: Optional[OpenAIUsage] = Field(None, description="Token usage information")
|
||||
|
||||
|
||||
# Streaming-specific models
|
||||
|
||||
|
||||
class OpenAIStreamDelta(BaseModel):
|
||||
"""Delta content in a streaming chunk."""
|
||||
|
||||
role: Optional[str] = Field(None, description="Role (only in first chunk)")
|
||||
content: Optional[str] = Field(None, description="Incremental content")
|
||||
|
||||
|
||||
class OpenAIStreamChoice(BaseModel):
|
||||
"""Individual streaming choice."""
|
||||
|
||||
index: int = Field(..., description="Choice index")
|
||||
delta: OpenAIStreamDelta = Field(..., description="Delta content")
|
||||
finish_reason: Optional[str] = Field(
|
||||
None, description="Reason for completion finish"
|
||||
)
|
||||
# Common params (typed for IDE/docs; all optional)
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stream_options: Optional[Dict[str, Any]] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
max_tokens: Optional[int] = None # deprecated; still accepted
|
||||
max_completion_tokens: Optional[int] = None # replaces max_tokens
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
seed: Optional[int] = None
|
||||
logprobs: Optional[bool] = None
|
||||
top_logprobs: Optional[int] = None
|
||||
response_format: Optional[Dict[str, Any]] = None
|
||||
tools: Optional[List[Dict[str, Any]]] = None
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
|
||||
parallel_tool_calls: Optional[bool] = None
|
||||
reasoning_effort: Optional[str] = None # o-series / reasoning models
|
||||
modalities: Optional[List[str]] = None
|
||||
audio: Optional[Dict[str, Any]] = None
|
||||
prediction: Optional[Dict[str, Any]] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
store: Optional[bool] = None
|
||||
service_tier: Optional[str] = None
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
# `user` is ignored — the authenticated user is derived from the API key.
|
||||
user: Optional[str] = None
|
||||
|
|
|
|||
Loading…
Reference in a new issue