diff --git a/backend/main.py b/backend/main.py index d76de1c..d775706 100644 --- a/backend/main.py +++ b/backend/main.py @@ -11,7 +11,6 @@ from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security, from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse import structlog -import asyncio from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded @@ -29,6 +28,8 @@ def get_rate_limit_key(request: Request) -> str: limiter = Limiter(key_func=get_rate_limit_key) +import httpx + from models import ( ChatRequest, MemoryAddRequest, @@ -43,10 +44,6 @@ from models import ( GlobalStatsResponse, UserStatsResponse, OpenAIChatCompletionRequest, - OpenAIChatCompletionResponse, - OpenAIChoice, - OpenAIChoiceMessage, - OpenAIUsage, ) from mem0_manager import mem0_manager from auth import get_current_user, get_current_user_openai, auth_service @@ -109,6 +106,12 @@ async def lifespan(app: FastAPI): except Exception as e: logger.error(f"Error stopping MCP session manager: {e}") + # Close the async HTTP client used by the /v1/chat/completions proxy. 
+ try: + await mem0_manager.aclose() + except Exception as e: + logger.warning(f"mem0_manager aclose failed: {e}") + logger.info("Shutting down Mem0 Interface POC") @@ -336,135 +339,59 @@ async def chat_with_memory( ) -async def stream_openai_response( - completion_id: str, model: str, content: str, created: int -): - """Generate SSE stream for OpenAI-compatible streaming by chunking the response.""" - import uuid - - # First chunk with role - chunk = { - "id": completion_id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [ - { - "index": 0, - "delta": {"role": "assistant", "content": ""}, - "finish_reason": None, - } - ], - } - yield f"data: {json.dumps(chunk)}\n\n" - - # Stream content in chunks (3 words at a time for smooth effect) - words = content.split() - chunk_size = 3 - - for i in range(0, len(words), chunk_size): - word_chunk = " ".join(words[i : i + chunk_size]) - if i + chunk_size < len(words): - word_chunk += " " - - chunk = { - "id": completion_id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [ - {"index": 0, "delta": {"content": word_chunk}, "finish_reason": None} - ], - } - yield f"data: {json.dumps(chunk)}\n\n" - await asyncio.sleep(0.05) - - # Final chunk - chunk = { - "id": completion_id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], - } - yield f"data: {json.dumps(chunk)}\n\n" - yield "data: [DONE]\n\n" - - -@app.post("/v1/chat/completions") -@app.post("/chat/completions") +@app.post("/v1/chat/completions", response_model=None) +@app.post("/chat/completions", response_model=None) @limiter.limit("30/minute") async def openai_chat_completions( request: Request, completion_request: OpenAIChatCompletionRequest, authenticated_user: str = Depends(get_current_user_openai), ): - """OpenAI-compatible chat completions endpoint with mem0 memory integration.""" + 
"""OpenAI-compatible chat completions — pass-through proxy with memory injection. + + Forwards the request to the upstream LLM verbatim (preserving tool_calls, + reasoning_tokens, system_fingerprint, finish_reason, etc.) and optionally + prepends a system message with relevant memories when the last role is + `user` and no tool flow is in progress. See + `mem0_manager.openai_proxy_completion` for the full injection rule and + tool-call safety contract. + """ try: - import uuid + request_kwargs = completion_request.model_dump(exclude_unset=True) + # `user` is the OpenAI client-supplied identifier; we always derive the + # real user from the API key. Strip it before forwarding so it can't + # confuse upstream user-tracking. + request_kwargs.pop("user", None) - user_id = authenticated_user logger.info( - f"OpenAI chat completion for user: {user_id} (streaming={completion_request.stream})" + "openai_chat_completions", + user_id=authenticated_user, + stream=bool(request_kwargs.get("stream")), + has_tools=bool(request_kwargs.get("tools")), + model=request_kwargs.get("model"), ) - # Extract last user message - user_messages = [ - m for m in completion_request.messages if m.get("role") == "user" - ] - if not user_messages: - raise HTTPException( - status_code=400, - detail="No user messages provided. 
Include at least one message with role='user'.", - ) - - last_message = user_messages[-1].get("content", "") - context = ( - completion_request.messages[:-1] - if len(completion_request.messages) > 1 - else None + result = await mem0_manager.openai_proxy_completion( + request_kwargs=request_kwargs, + user_id=authenticated_user, ) - # Call chat_with_memory - result = await mem0_manager.chat_with_memory( - message=last_message, - user_id=user_id, - context=context, - ) - - completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" - created_time = int(time.time()) - assistant_content = result.get("response", "") - - if completion_request.stream: + if request_kwargs.get("stream"): return StreamingResponse( - stream_openai_response( - completion_id=completion_id, - model=settings.default_model, - content=assistant_content, - created=created_time, - ), + result, media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, ) - else: - return OpenAIChatCompletionResponse( - id=completion_id, - object="chat.completion", - created=created_time, - model=settings.default_model, - choices=[ - OpenAIChoice( - index=0, - message=OpenAIChoiceMessage( - role="assistant", content=assistant_content - ), - finish_reason="stop", - ) - ], - usage=OpenAIUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), - ) - + return result + except httpx.HTTPStatusError as e: + # Forward upstream error body to the client with its original status + try: + detail = e.response.json() + except Exception: + detail = {"error": {"message": e.response.text[:500]}} + raise HTTPException(status_code=e.response.status_code, detail=detail) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) except HTTPException: raise except Exception as e: diff --git a/backend/mem0_manager.py b/backend/mem0_manager.py index 7575d86..7e08401 100644 --- a/backend/mem0_manager.py +++ b/backend/mem0_manager.py @@ -1,8 +1,12 @@ """Ultra-minimal Mem0 Manager - Pure 
def _extract_text(content: Any) -> str:
    """Return the plain-text portion of an OpenAI message ``content`` value.

    A chat message's ``content`` is either a plain string or a list of typed
    parts such as ``{"type": "text", "text": "..."}`` and
    ``{"type": "image_url", ...}``.

    Args:
        content: The raw ``content`` value taken from a message dict.

    Returns:
        The string itself when ``content`` is a string; the ``text`` values of
        all ``"text"`` parts joined with newlines when it is a list; ``""``
        for ``None``, for any other type, or when no text part is present.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        # Covers None as well as unexpected scalar / mapping values.
        return ""

    # Only well-formed text parts contribute; non-dict entries, non-text
    # parts, and text parts whose "text" is not a string are skipped.
    text_parts = [
        part["text"]
        for part in content
        if isinstance(part, dict)
        and part.get("type") == "text"
        and isinstance(part.get("text"), str)
    ]
    return "\n".join(text_parts)
Called from the FastAPI lifespan shutdown.""" + try: + await self.async_http.aclose() + except Exception as e: + logger.warning("async_http aclose failed", error=str(e)) + # Pure passthrough methods - no custom logic @db_retry @timed("add_memories") @@ -450,7 +491,11 @@ class Mem0Manager: response = self.openai_client.chat.completions.create( model=settings.default_model, messages=messages ) - assistant_response = response.choices[0].message.content + # Strip leading whitespace — reasoning models (minimax-m2) leak + # blank lines from their reasoning output into visible content. + # lstrip(), not strip(), preserves intentional trailing whitespace + # inside content (e.g., inside a code block). + assistant_response = (response.choices[0].message.content or "").lstrip() llm_time = time.time() - llm_start_time logger.debug( "LLM call completed", @@ -504,6 +549,261 @@ class Mem0Manager: "model_used": None, } + # ------------------------------------------------------------------ + # OpenAI-compatible /v1/chat/completions proxy + # ------------------------------------------------------------------ + # Design: forward client requests to the upstream endpoint verbatim, + # injecting a "Relevant memories" system message ONLY when the last + # message is from the user AND no tool flow is in progress. The + # opinionated chat_with_memory above is NOT used on this path because: + # (a) clients pick their own model, tools, response_format etc., + # (b) we must preserve every upstream field (tool_calls, reasoning_tokens, + # system_fingerprint, finish_reason), which the openai-python SDK and + # typed Pydantic responses strip silently, + # (c) streaming must be real (SSE pass-through), not post-hoc word-chunking. + # + # Tool-call safety contract (Memori issue #434 — codified here): + # - We only PREPEND a system message; never mutate, reorder, or delete + # existing messages. + # - We never touch tool messages or tool_call_id values. 
# - We skip injection on tool-result follow-ups (last role != "user").
# - We always schedule the post-stream mem0.add even if upstream errors
#   mid-flight (via try/finally).
async def openai_proxy_completion(
    self,
    request_kwargs: Dict[str, Any],
    user_id: str,
) -> Union[Dict[str, Any], AsyncIterator[bytes]]:
    """Proxy a chat-completions request to the upstream LLM.

    When the last message has role ``user`` and no ``tool`` message is
    present, memories matching the last user text are searched and, if any
    are found, injected as (or merged into) a leading system message.

    Args:
        request_kwargs: The JSON body to forward. It is not mutated; a
            shallow copy carrying the injected messages is forwarded instead.
        user_id: Identifier used to scope the memory search and the
            post-completion memory write.

    Returns:
        The upstream JSON as a dict when ``stream`` is falsy, otherwise an
        async iterator of raw SSE bytes.

    Raises:
        ValueError: If ``messages`` is missing or empty.
        httpx.HTTPStatusError: If the non-stream upstream call returns a
            status >= 400 (the upstream response is attached).
        RuntimeError: If the non-stream upstream call succeeds but its body
            is not valid JSON.
    """
    messages = request_kwargs.get("messages") or []
    if not messages:
        raise ValueError("messages array is empty")

    # Injection rule: last role == 'user' AND no prior tool message.
    last_msg = messages[-1] if isinstance(messages[-1], dict) else {}
    has_tool_message = any(
        isinstance(m, dict) and m.get("role") == "tool" for m in messages
    )
    if last_msg.get("role") == "user" and not has_tool_message:
        relevant = await self._search_relevant_memories(
            _extract_text(last_msg.get("content")), user_id
        )
        if relevant:
            memories_str = "\n".join(f"- {m.get('memory', '')}" for m in relevant)
            mem_block = (
                "Relevant memories about the user "
                "(use when helpful, ignore otherwise):\n" + memories_str
            )
            messages = self._prepend_memory_block(messages, mem_block)
            # Copy rather than mutate so the caller's dict stays untouched.
            request_kwargs = {**request_kwargs, "messages": messages}
            logger.info(
                "memory injected",
                user_id=user_id,
                memory_count=len(relevant),
            )

    # NOTE(review): if settings.openai_base_url already ends in "/v1" this
    # yields ".../v1/v1/chat/completions" — confirm the configured value.
    url = settings.openai_base_url.rstrip("/") + "/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {settings.openai_api_key}",
        "Content-Type": "application/json",
    }

    if request_kwargs.get("stream"):
        return self._stream_proxy(url, request_kwargs, headers, messages, user_id)

    resp = await self.async_http.post(url, json=request_kwargs, headers=headers)
    if resp.status_code >= 400:
        # The response travels with the exception so the caller can relay
        # the upstream error body and status.
        raise httpx.HTTPStatusError(
            f"Upstream {resp.status_code}", request=resp.request, response=resp
        )
    try:
        data = resp.json()
    except ValueError as exc:
        # JSONDecodeError subclasses ValueError, which the HTTP handler maps
        # to a 400; a malformed upstream body is not the client's fault.
        raise RuntimeError("Upstream returned a non-JSON response body") from exc

    # Fire-and-forget mem0.add; the task reference is retained so it cannot
    # be garbage-collected before it finishes.
    self._spawn_background(self._post_completion_add(messages, data, user_id))
    return data


async def _search_relevant_memories(
    self, query: str, user_id: str
) -> List[Dict[str, Any]]:
    """Return up to 10 memories matching *query*; ``[]`` on empty query or failure."""
    if not query:
        return []
    try:
        # memory.search is a synchronous call; run it in a worker thread so
        # the event loop keeps serving other requests meanwhile.
        search_result = await asyncio.to_thread(
            self.memory.search,
            query=query,
            filters=_build_filters(user_id),
            top_k=30,
            threshold=0.3,
            rerank=True,
        )
        return (search_result.get("results") or [])[:10]
    except Exception as e:
        # Best effort: a failed search must not fail the completion.
        logger.warning(
            "memory search failed; proceeding without injection",
            error=str(e),
            user_id=user_id,
        )
        return []


def _prepend_memory_block(
    self, messages: List[Dict[str, Any]], mem_block: str
) -> List[Dict[str, Any]]:
    """Return a new message list carrying *mem_block* in a leading system message.

    An existing leading system message is extended (its other keys and any
    multi-part content are preserved); otherwise a new system message is
    prepended. No other message is touched and the input list is not mutated.
    """
    head = messages[0]
    if not (isinstance(head, dict) and head.get("role") == "system"):
        return [{"role": "system", "content": mem_block}, *messages]

    existing = head.get("content")
    if isinstance(existing, list):
        # Keep every original part and add the memories as one more text part.
        merged: Any = [*existing, {"type": "text", "text": mem_block}]
    elif isinstance(existing, str) and existing:
        merged = existing + "\n\n" + mem_block
    else:
        merged = mem_block
    return [{**head, "content": merged}, *messages[1:]]


def _spawn_background(self, coro: Any) -> "asyncio.Task[Any]":
    """Schedule *coro* as a task and hold a strong reference until it is done.

    The event loop keeps only weak references to tasks, so an otherwise
    unreferenced fire-and-forget task may be collected before it completes.
    """
    tasks = getattr(self, "_background_tasks", None)
    if tasks is None:
        tasks = self._background_tasks = set()
    try:
        task = asyncio.create_task(coro)
    except RuntimeError:
        # No running loop: close the coroutine to avoid a "never awaited"
        # warning, then let the caller decide how to report it.
        coro.close()
        raise
    tasks.add(task)
    task.add_done_callback(tasks.discard)
    return task


async def _store_exchange(
    self, user_text: str, assistant_text: str, user_id: str, failure_event: str
) -> None:
    """Run ``memory.add`` for one user/assistant exchange in a worker thread.

    Failures are logged under *failure_event* and never raised.
    """
    try:
        await asyncio.to_thread(
            self.memory.add,
            [
                {"role": "user", "content": user_text},
                {"role": "assistant", "content": assistant_text},
            ],
            user_id=user_id,
        )
    except Exception as e:
        logger.warning(failure_event, error=str(e), user_id=user_id)


async def _stream_proxy(
    self,
    url: str,
    request_kwargs: Dict[str, Any],
    headers: Dict[str, str],
    messages: List[Dict[str, Any]],
    user_id: str,
) -> AsyncIterator[bytes]:
    """Stream upstream SSE bytes verbatim; accumulate content for post-stream add.

    Upstream chunks are yielded unchanged. A side-channel parser collects
    ``delta.content`` strings and notes whether any ``delta.tool_calls``
    appeared; parsing problems never affect the forwarded bytes. An upstream
    status >= 400 or an ``httpx.HTTPError`` is reported to the client as a
    single SSE ``error`` event followed by ``[DONE]``.
    """
    accumulated_content: List[str] = []
    saw_tool_calls = False
    buffer = b""

    def scan_event(event_bytes: bytes) -> None:
        """Record content / tool_calls from one SSE event, ignoring bad data."""
        nonlocal saw_tool_calls
        for raw_line in event_bytes.split(b"\n"):
            # The space after "data:" is optional in the SSE format.
            if not raw_line.startswith(b"data:"):
                continue
            payload = raw_line[5:].strip()
            if not payload or payload == b"[DONE]":
                continue
            try:
                obj = json.loads(payload)
            except ValueError:
                # Covers JSONDecodeError and UnicodeDecodeError.
                continue
            # Shape-check each level: a payload such as `null` or a list
            # must not raise and abort the client's stream.
            choices = obj.get("choices") if isinstance(obj, dict) else None
            first = choices[0] if isinstance(choices, list) and choices else None
            delta = first.get("delta") if isinstance(first, dict) else None
            if not isinstance(delta, dict):
                continue
            content_piece = delta.get("content")
            if isinstance(content_piece, str):
                accumulated_content.append(content_piece)
            if delta.get("tool_calls"):
                saw_tool_calls = True

    try:
        async with self.async_http.stream(
            "POST", url, json=request_kwargs, headers=headers
        ) as upstream:
            if upstream.status_code >= 400:
                body_bytes = await upstream.aread()
                err = {
                    "error": {
                        "message": (
                            f"Upstream {upstream.status_code}: "
                            f"{body_bytes.decode('utf-8', errors='replace')[:500]}"
                        )
                    }
                }
                yield f"data: {json.dumps(err)}\n\n".encode("utf-8")
                yield b"data: [DONE]\n\n"
                return

            async for chunk in upstream.aiter_bytes():
                if not chunk:
                    continue
                # Forward bytes verbatim — preserves the exact SSE wire format.
                yield chunk
                # Normalise CRLF so both "\n\n" and "\r\n\r\n" end an event;
                # a trailing lone "\r" stays buffered until its "\n" arrives.
                buffer = (buffer + chunk).replace(b"\r\n", b"\n")
                *events, buffer = buffer.split(b"\n\n")
                for event_bytes in events:
                    scan_event(event_bytes)
    except httpx.HTTPError as e:
        logger.warning("upstream stream error", error=str(e), user_id=user_id)
        err = {"error": {"message": f"Upstream stream error: {e}"}}
        yield f"data: {json.dumps(err)}\n\n".encode("utf-8")
        yield b"data: [DONE]\n\n"
    finally:
        # A final event may arrive without its terminating blank line.
        if buffer.strip():
            scan_event(buffer)
        # Per Memori #434: the post-stream mem0.add must happen even on a
        # mid-stream error so partial content is captured. It is scheduled
        # as a background task rather than awaited here: awaiting would keep
        # the response open until the add finishes, and would itself be
        # cancelled when the client disconnects. Skipped for tool-call or
        # empty responses.
        full = "".join(accumulated_content).lstrip()
        last = messages[-1] if messages else None
        if (
            full
            and not saw_tool_calls
            and isinstance(last, dict)
            and last.get("role") == "user"
        ):
            last_user_text = _extract_text(last.get("content"))
            if last_user_text:
                try:
                    self._spawn_background(
                        self._store_exchange(
                            last_user_text,
                            full,
                            user_id,
                            "post-stream mem0.add failed",
                        )
                    )
                except RuntimeError as e:
                    logger.warning(
                        "post-stream mem0.add failed",
                        error=str(e),
                        user_id=user_id,
                    )


async def _post_completion_add(
    self,
    messages: List[Dict[str, Any]],
    response_data: Dict[str, Any],
    user_id: str,
) -> None:
    """Background mem0.add after a non-stream completion.

    Stores the last user text together with the first choice's text content.
    Does nothing when that content is missing, non-string, or blank after
    stripping leading whitespace, or when the last message is not a user
    message with text. Never raises; failures are logged.
    """
    try:
        choice = (response_data.get("choices") or [{}])[0]
        msg = choice.get("message") or {}
        content = msg.get("content")
        if not isinstance(content, str):
            return  # tool-only or non-text completion — skip
        # Strip first, then test, so whitespace-only content is skipped the
        # same way the streaming path skips it.
        content = content.lstrip()
        if not content:
            return
        last = messages[-1] if messages else None
        if not (isinstance(last, dict) and last.get("role") == "user"):
            return
        last_user_text = _extract_text(last.get("content"))
        if not last_user_text:
            return
    except Exception as e:
        logger.warning(
            "post-completion mem0.add failed", error=str(e), user_id=user_id
        )
        return
    await self._store_exchange(
        last_user_text, content, user_id, "post-completion mem0.add failed"
    )
#
# The /v1/chat/completions handler is a pass-through proxy: requests are
# forwarded to the upstream LLM and the upstream response is relayed verbatim
# (dict for non-stream, raw SSE bytes for stream). Hence only the REQUEST
# model lives here; the response is never re-typed (typed Pydantic response
# models silently drop unknown fields like tool_calls / refusal /
# reasoning_tokens — see the audit in the plan file).


class OpenAIChatCompletionRequest(BaseModel):
    """OpenAI chat completion request — permissive schema, forwarded as-is.

    Only `model` and `messages` are required. All other standard OpenAI
    parameters are typed (for client IDE/docs benefit) but optional. Unknown
    fields are accepted via `extra="allow"` and forwarded to upstream, so
    new OpenAI parameters don't require a code change here.

    Every optional field defaults to ``None`` rather than to an OpenAI
    default value: the handler serialises this model with
    ``model_dump(exclude_unset=True)``, so only fields the client actually
    sent are forwarded and upstream applies its own defaults for the rest.

    NOTE(review): pydantic's default (lax) validation may coerce values to
    the annotated type (e.g. an integer ``logit_bias`` value to ``float``),
    so "as-is" presumably holds for which fields are present, not necessarily
    for the exact JSON number form — confirm upstream tolerates this.
    """

    # Accept and retain fields not declared below instead of rejecting them.
    model_config = ConfigDict(extra="allow")

    model: str = Field(..., description="Model to use (forwarded to upstream)")
    # Messages are untyped dicts so multi-part content, tool calls and tool
    # results are accepted without a schema for each message shape.
    messages: List[Dict[str, Any]] = Field(
        ..., description="Messages array (multi-part content supported)"
    )

    # Common params (typed for IDE/docs; all optional)
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stream_options: Optional[Dict[str, Any]] = None
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None  # deprecated; still accepted
    max_completion_tokens: Optional[int] = None  # replaces max_tokens
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    seed: Optional[int] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    response_format: Optional[Dict[str, Any]] = None
    tools: Optional[List[Dict[str, Any]]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
    parallel_tool_calls: Optional[bool] = None
    reasoning_effort: Optional[str] = None  # o-series / reasoning models
    modalities: Optional[List[str]] = None
    audio: Optional[Dict[str, Any]] = None
    prediction: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = None
    store: Optional[bool] = None
    service_tier: Optional[str] = None
    logit_bias: Optional[Dict[str, float]] = None
    # `user` is ignored — the authenticated user is derived from the API key.
    user: Optional[str] = None