Part 1: strip leading newlines from chat_with_memory's LLM response (minimax-m2 reasoning output leaks blank lines; .lstrip() at the source covers /chat, /v1/chat/completions, and the MCP chat tool). Part 2: replace the /v1/chat/completions handler with an httpx-based pass-through proxy that preserves every upstream field (tool_calls, reasoning_tokens, system_fingerprint, finish_reason, etc.) and supports end-to-end MCP-style tool calling. What changed: - models.py: OpenAIChatCompletionRequest is now permissive — typed for the common fields (tools, tool_choice, parallel_tool_calls, response_format, max_completion_tokens, seed, stream_options, reasoning_effort, modalities, etc.) and extra='allow' for forward-compat. The typed response models (OpenAIChatCompletionResponse and friends) are deleted — the handler returns upstream's JSON dict directly so unknown fields aren't silently dropped. - mem0_manager.py: adds httpx.AsyncClient + an openai_proxy_completion() method that injects a "Relevant memories" system message only when the last role is 'user' AND no tool flow is in progress, then forwards to the upstream LLM. Non-stream returns upstream JSON; stream returns an async iterator that yields raw upstream SSE bytes verbatim while side-channel-parsing for the post-stream mem0.add. Codifies the Memori #434 lessons: never mutates existing messages (only prepends system), never touches tool_call_id, runs post-add even on mid-stream error via try/finally. - main.py: handler is now ~50 lines — model_dump(exclude_unset) the request, hand off to openai_proxy_completion, return dict OR wrap in StreamingResponse. response_model=None so FastAPI doesn't validate. Deleted stream_openai_response (post-hoc word-chunking is gone). Lifespan shutdown closes mem0_manager.async_http. Research confirmed mem0 itself does not ship an HTTP /v1/chat/completions (only the in-process mem0.proxy.main.Mem0 SDK pattern), so we replicate the pattern without adding a litellm dependency.
SSE/tool_calls patterns are modeled after microsoft/agent-lightning's llm_proxy. Verified locally: ast.parse OK on all three files. End-to-end smoke tests will run on beast.
831 lines
34 KiB
Python
831 lines
34 KiB
Python
"""Ultra-minimal Mem0 Manager - Pure Mem0 + Custom OpenAI Endpoint Only."""
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
from typing import Dict, List, Optional, Any, AsyncIterator, Union
|
||
from datetime import datetime
|
||
|
||
import httpx
|
||
from mem0 import Memory
|
||
from openai import OpenAI
|
||
from tenacity import (
|
||
retry,
|
||
stop_after_attempt,
|
||
wait_exponential,
|
||
retry_if_exception_type,
|
||
before_sleep_log,
|
||
)
|
||
import structlog
|
||
|
||
from config import settings
|
||
from monitoring import timed
|
||
|
||
logger = structlog.get_logger(__name__)
|
||
|
||
# Retry decorator for database operations (Qdrant).
# Policy as configured below: at most 3 attempts; only ConnectionError,
# TimeoutError and OSError trigger a retry; exponential backoff between
# attempts (multiplier=1, bounded to 1-10 seconds); a WARNING-level entry is
# logged before each sleep; reraise=True.
db_retry = retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    retry=retry_if_exception_type((ConnectionError, TimeoutError, OSError)),
    before_sleep=before_sleep_log(logger, logging.WARNING),
    reraise=True,
)
|
||
|
||
# Monkey-patch Mem0's OpenAI LLM to clear top_p when the configured LLM
|
||
# is Claude reached via an OpenAI-compatible endpoint: Claude rejects top_p
|
||
# whenever temperature is set, and OpenAILLM sends both unconditionally.
|
||
# (The 'store' branch is now redundant in mem0ai>=2.0.0 — upstream made it
|
||
# opt-in — but harmless; kept for safety.)
|
||
from mem0.llms.openai import OpenAILLM
|
||
|
||
_original_generate_response = OpenAILLM.generate_response
|
||
|
||
|
||
def patched_generate_response(
    self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs
):
    """Replacement for ``OpenAILLM.generate_response``.

    Resets ``store`` and ``top_p`` on the LLM config to ``None`` (only when
    the config object exposes those attributes) and then delegates to the
    original implementation with the call arguments unchanged.
    """
    # Same order as before: 'store' first, then 'top_p'.
    for option in ("store", "top_p"):
        if hasattr(self.config, option):
            setattr(self.config, option, None)

    return _original_generate_response(
        self,
        messages,
        response_format,
        tools,
        tool_choice,
        **kwargs,
    )
|
||
|
||
|
||
OpenAILLM.generate_response = patched_generate_response
|
||
logger.info("Applied Claude/OpenAI-compatible patch: cleared top_p (and store)")
|
||
|
||
|
||
def _build_filters(
|
||
user_id: Optional[str],
|
||
agent_id: Optional[str] = None,
|
||
run_id: Optional[str] = None,
|
||
extra: Optional[Dict[str, Any]] = None,
|
||
) -> Dict[str, Any]:
|
||
"""Build the filters dict required by mem0 v2 search/get_all.
|
||
|
||
In mem0 v2.x, user_id/agent_id/run_id are rejected as top-level kwargs
|
||
on Memory.search and Memory.get_all — they must live inside `filters`.
|
||
"""
|
||
merged: Dict[str, Any] = dict(extra) if extra else {}
|
||
if user_id is not None:
|
||
merged["user_id"] = user_id
|
||
if agent_id is not None:
|
||
merged["agent_id"] = agent_id
|
||
if run_id is not None:
|
||
merged["run_id"] = run_id
|
||
return merged
|
||
|
||
|
||
def _extract_text(content: Any) -> str:
|
||
"""Extract text from an OpenAI message `content` field.
|
||
|
||
`content` can be a plain string OR a list of multi-part objects
|
||
(e.g., [{"type": "text", "text": "..."}, {"type": "image_url", ...}]).
|
||
Returns concatenated text parts, or "" if no text is present.
|
||
"""
|
||
if content is None:
|
||
return ""
|
||
if isinstance(content, str):
|
||
return content
|
||
if isinstance(content, list):
|
||
parts = []
|
||
for p in content:
|
||
if isinstance(p, dict) and p.get("type") == "text":
|
||
t = p.get("text")
|
||
if isinstance(t, str):
|
||
parts.append(t)
|
||
return "\n".join(parts)
|
||
return ""
|
||
|
||
|
||
# Appended as the "## Custom Instructions" section of the additive-extraction
|
||
# prompt (mem0/configs/prompts.py::generate_additive_extraction_prompt). The
|
||
# default few-shot bias is consumer-organizer ("favourite movies", "SF restaurants"),
|
||
# which under-extracts on the work/project/relationship traffic this deployment
|
||
# actually sees. This re-prioritizes without replacing mem0's structural guidance.
|
||
CUSTOM_FACT_EXTRACTION_INSTRUCTIONS = """
|
||
This memory store serves a working assistant — engineering, product, and operational contexts plus the user's people and recurring life context. Prioritize accordingly:
|
||
|
||
HIGH-VALUE facts to capture:
|
||
- Work context: company, team, role; ongoing projects with goals/status/blockers; product or domain knowledge being built; tools/frameworks/languages in active use; technical decisions and the reasoning; recurring meetings or rituals.
|
||
- People in the user's orbit: colleagues, family, friends, mentors — names, relationships, roles, what they do, the current state of the relationship or shared context.
|
||
- Recurring personal context: home/work locations, regular schedule, standing commitments, durable preferences (food restrictions, working hours, communication style), planned events with dates.
|
||
- Acquired knowledge: concepts being studied or built, specific problems being solved, prior solutions tried and their outcomes.
|
||
|
||
LOWER-PRIORITY (extract only if they reveal a pattern or future relevance):
|
||
- Single transient states ("running 5 minutes late", "didn't sleep well") — capture only if they recur or signal a habit.
|
||
- Movies, music, restaurants, hobbies — only when noted as durable preferences or part of a recurring activity, not when mentioned in passing.
|
||
|
||
SKIP entirely:
|
||
- Generic world knowledge (timezones, capital cities, definitions) — the assistant already knows these.
|
||
- Greetings, acknowledgments, meta-conversation ("Thanks!", "Got it").
|
||
- Restatements or paraphrases of facts already in Existing Memories or Recently Extracted Memories.
|
||
|
||
Prefer specificity. "Pratik uses FastAPI for backend services" beats "Pratik does backend development." When a person is mentioned by a short name or nickname, capture the relationship if known ("Anushree is Pratik's wife") so future references resolve correctly.
|
||
""".strip()
|
||
|
||
|
||
class Mem0Manager:
|
||
"""
|
||
Ultra-minimal manager that bridges custom OpenAI endpoint with pure Mem0.
|
||
No custom logic - let Mem0 handle all memory intelligence.
|
||
"""
|
||
|
||
def __init__(self):
    """Build the Mem0 ``Memory``, the sync OpenAI client and the async HTTP client.

    Endpoints, model names and credentials all come from ``settings``.
    """
    logger.info(
        "Initializing Mem0Manager with custom endpoint",
        model=settings.default_model,
        embedding_model=settings.embedding_model,
        embedding_dims=settings.embedding_dims,
        qdrant_host=settings.qdrant_host,
    )

    llm_section = {
        "provider": "openai",
        "config": {
            "model": settings.default_model,
            "api_key": settings.openai_api_key,
            "openai_base_url": settings.openai_base_url,
            "temperature": 0.1,
            "top_p": None,
        },
    }

    # Embeddings go through the OpenAI-compatible LiteLLM proxy rather than
    # Ollama directly — the proxy is reachable from the container in all
    # deployments, Ollama may not be. The model name is the same
    # (qwen3-embedding:4b-q8_0); existing vectors generated via this path
    # stay compatible.
    embedder_section = {
        "provider": "openai",
        "config": {
            "model": settings.embedding_model,
            "api_key": settings.openai_api_key,
            "openai_base_url": settings.openai_base_url,
            "embedding_dims": settings.embedding_dims,
        },
    }

    vector_store_section = {
        "provider": "qdrant",
        "config": {
            "collection_name": settings.qdrant_collection_name,
            "host": settings.qdrant_host,
            "port": settings.qdrant_port,
            "embedding_model_dims": settings.embedding_dims,
            "on_disk": True,
        },
    }

    reranker_section = {
        "provider": "cohere",
        "config": {
            "api_key": settings.cohere_api_key,
            # v3.5 supersedes v3.0: 4096-token context, multilingual (our
            # users include Hindi/Hinglish content that the English-only v3
            # silently underperforms on).
            "model": "rerank-v3.5",
            # Raised from 10 → 50 so the rerank output cap does not truncate
            # below typical over-fetch sizes (the search calls in this class
            # request top_k up to ~3× the user's limit).
            "top_n": 50,
        },
    }

    mem0_config = {
        "version": "v1.1",
        "custom_instructions": CUSTOM_FACT_EXTRACTION_INSTRUCTIONS,
        "llm": llm_section,
        "embedder": embedder_section,
        "vector_store": vector_store_section,
        "reranker": reranker_section,
    }
    self.memory = Memory.from_config(mem0_config)

    # Synchronous SDK client: 60 second timeout for LLM calls, failed
    # requests retried up to 2 times.
    self.openai_client = OpenAI(
        api_key=settings.openai_api_key,
        base_url=settings.openai_base_url,
        timeout=60.0,
        max_retries=2,
    )

    # Async HTTP client used by openai_proxy_completion to forward
    # /v1/chat/completions to the upstream endpoint verbatim (preserves
    # tool_calls, reasoning_tokens, system_fingerprint, etc., which the
    # openai-python SDK can drop). Closed in the FastAPI lifespan shutdown
    # via aclose().
    proxy_timeout = httpx.Timeout(120.0, connect=10.0)
    self.async_http = httpx.AsyncClient(timeout=proxy_timeout)

    logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")
|
||
|
||
async def aclose(self) -> None:
    """Close the async HTTP client (called from the FastAPI lifespan shutdown).

    Best effort: any failure while closing is logged as a warning and
    swallowed rather than raised.
    """
    try:
        await self.async_http.aclose()
    except Exception as exc:
        logger.warning("async_http aclose failed", error=str(exc))
|
||
|
||
# Pure passthrough methods - no custom logic
|
||
@db_retry
|
||
@timed("add_memories")
|
||
async def add_memories(
    self,
    messages: List[Dict[str, str]],
    user_id: Optional[str] = "default",
    agent_id: Optional[str] = None,
    run_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Store a batch of chat messages in Mem0 with auto-generated metadata.

    Args:
        messages: Chat messages as dicts; model objects exposing
            ``model_dump()`` (or the legacy ``dict()``) are converted first.
        user_id: User scope passed to ``Memory.add``.
        agent_id: Optional agent scope passed to ``Memory.add``.
        run_id: Optional run scope passed to ``Memory.add``.
        metadata: Caller metadata; on key collisions it overrides the
            automatically generated fields.

    Returns:
        A dict with ``added_memories`` (always a list), a ``message`` and
        the ``hierarchy`` of ids the memories were stored under.

    Raises:
        Exception: whatever ``Memory.add`` raises, after it has been logged.
    """
    try:
        # Convert ChatMessage-style objects to plain dicts. Prefer Pydantic
        # v2's model_dump(); fall back to the legacy dict() for objects that
        # only provide that.
        formatted_messages = []
        for msg in messages:
            if hasattr(msg, "model_dump"):
                formatted_messages.append(msg.model_dump())
            elif hasattr(msg, "dict"):
                formatted_messages.append(msg.dict())
            else:
                formatted_messages.append(msg)

        auto_metadata = {
            "timestamp": datetime.now().isoformat(),
            "source": "chat_conversation",
            "message_count": len(formatted_messages),
            "auto_generated": True,
        }
        # User-supplied metadata takes precedence over the automatic fields.
        enhanced_metadata = {**auto_metadata, **(metadata or {})}

        result = self.memory.add(
            formatted_messages,
            user_id=user_id,
            agent_id=agent_id,
            run_id=run_id,
            metadata=enhanced_metadata,
        )

        return {
            "added_memories": result if isinstance(result, list) else [result],
            "message": "Memories added successfully",
            "hierarchy": {
                "user_id": user_id,
                "agent_id": agent_id,
                "run_id": run_id,
            },
        }
    except Exception as e:
        logger.error(f"Error adding memories: {e}")
        raise
|
||
|
||
@db_retry
|
||
@timed("search_memories")
|
||
async def search_memories(
    self,
    query: str,
    user_id: Optional[str] = "default",
    limit: int = 5,
    threshold: Optional[float] = None,
    filters: Optional[Dict[str, Any]] = None,
    # keyword_search: bool = False,
    # rerank: bool = False,
    # filter_memories: bool = False,
    agent_id: Optional[str] = None,
    run_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Search memories for ``query`` and return at most ``limit`` hits.

    An empty or whitespace-only query short-circuits to an empty result with
    an explanatory ``note``. Otherwise the search over-fetches
    ``max(limit * 3, 30)`` candidates with reranking enabled and keeps the
    first ``limit``. Errors are logged and re-raised.
    """
    try:
        # Minimal empty-query protection for API compatibility.
        if not (query and query.strip()):
            return {
                "memories": [],
                "total_count": 0,
                "query": query,
                "note": "Empty query provided, no results returned. Use a specific query to search memories.",
            }

        # mem0 v2: entity IDs must live inside the `filters` dict; `limit` is
        # now `top_k`. Over-fetch so the Cohere reranker (rerank=True) has
        # room to reorder, then truncate to the caller's requested limit.
        candidate_pool = max(limit * 3, 30)
        scoped_filters = _build_filters(user_id, agent_id, run_id, extra=filters)
        raw = self.memory.search(
            query=query,
            filters=scoped_filters,
            top_k=candidate_pool,
            threshold=threshold,
            rerank=True,
        )
        top_hits = raw.get("results", [])[:limit]
        return {
            "memories": top_hits,
            "total_count": len(top_hits),
            "query": query,
        }
    except Exception as exc:
        logger.error(f"Error searching memories: {exc}")
        raise
|
||
|
||
@db_retry
|
||
async def get_user_memories(
    self,
    user_id: str,
    limit: int = 10,
    agent_id: Optional[str] = None,
    run_id: Optional[str] = None,
    filters: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """List up to ``limit`` memories for the given user/agent/run scope.

    Returns the ``results`` list from ``Memory.get_all`` (empty list when the
    key is absent). Errors are logged and re-raised.
    """
    try:
        # mem0 v2: entity IDs must live inside the `filters` dict; `limit` is
        # now `top_k`.
        scope = _build_filters(user_id, agent_id, run_id, extra=filters)
        listing = self.memory.get_all(filters=scope, top_k=limit)
        return listing.get("results", [])
    except Exception as exc:
        logger.error(f"Error getting user memories: {exc}")
        raise
|
||
|
||
@db_retry
|
||
async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
    """Fetch one memory by id; return ``None`` when the lookup fails.

    NOTE(review): every ``Exception`` is caught here, so a transient
    connection error is indistinguishable from "not found" and never
    propagates to the ``db_retry`` wrapper applied to this method.
    """
    try:
        return self.memory.get(memory_id=memory_id)
    except Exception as exc:
        logger.debug(f"Memory {memory_id} not found or error: {exc}")
        return None
|
||
|
||
async def verify_memory_ownership(self, memory_id: str, user_id: str) -> bool:
|
||
"""Check if a memory belongs to a user. O(1) instead of O(n)."""
|
||
memory = await self.get_memory(memory_id)
|
||
if memory is None:
|
||
return False
|
||
return memory.get("user_id") == user_id
|
||
|
||
@db_retry
|
||
@timed("update_memory")
|
||
async def update_memory(
    self,
    memory_id: str,
    content: str,
) -> Dict[str, Any]:
    """Replace the text of memory ``memory_id`` with ``content`` via Mem0.

    Returns a confirmation message plus Mem0's raw result; errors are logged
    and re-raised.
    """
    try:
        outcome = self.memory.update(memory_id=memory_id, data=content)
    except Exception as exc:
        logger.error(f"Error updating memory: {exc}")
        raise
    return {"message": "Memory updated successfully", "result": outcome}
|
||
|
||
@db_retry
|
||
@timed("delete_memory")
|
||
async def delete_memory(self, memory_id: str) -> Dict[str, Any]:
    """Delete a single memory by id via Mem0.

    Returns a confirmation message; errors are logged and re-raised.
    """
    try:
        self.memory.delete(memory_id=memory_id)
    except Exception as exc:
        logger.error(f"Error deleting memory: {exc}")
        raise
    return {"message": "Memory deleted successfully"}
|
||
|
||
async def delete_user_memories(self, user_id: Optional[str]) -> Dict[str, Any]:
    """Delete every memory stored under ``user_id`` via ``Memory.delete_all``.

    Returns a confirmation message; errors are logged and re-raised.
    """
    try:
        self.memory.delete_all(user_id=user_id)
    except Exception as exc:
        logger.error(f"Error deleting user memories: {exc}")
        raise
    return {"message": "All user memories deleted successfully"}
|
||
|
||
async def get_memory_history(self, memory_id: str) -> Dict[str, Any]:
    """Return the change history Mem0 holds for ``memory_id``.

    The result echoes the id, carries Mem0's ``history`` value and a
    confirmation message; errors are logged and re-raised.
    """
    try:
        changes = self.memory.history(memory_id=memory_id)
    except Exception as exc:
        logger.error(f"Error getting memory history: {exc}")
        raise
    return {
        "memory_id": memory_id,
        "history": changes,
        "message": "Memory history retrieved successfully",
    }
|
||
|
||
async def get_graph_relationships(
|
||
self,
|
||
user_id: Optional[str],
|
||
agent_id: Optional[str],
|
||
run_id: Optional[str],
|
||
limit: int = 50,
|
||
) -> Dict[str, Any]:
|
||
"""Graph relationships — deprecated in mem0 v2 (OSS graph memory removed).
|
||
|
||
mem0 v2.0.0 deleted the OSS graph store (Neo4j/Memgraph/Kuzu/AGE drivers).
|
||
Entity relationships now influence ranking via a parallel `{collection}_entities`
|
||
Qdrant collection rather than being directly traversable. We return an empty
|
||
graph payload plus a `deprecated` marker so clients (frontend graph.html) can
|
||
render a clear "Graph view unavailable" state instead of erroring.
|
||
"""
|
||
return {
|
||
"relationships": [],
|
||
"entities": [],
|
||
"user_id": user_id,
|
||
"agent_id": agent_id,
|
||
"run_id": run_id,
|
||
"total_memories": 0,
|
||
"total_relationships": 0,
|
||
"deprecated": True,
|
||
"deprecation_note": (
|
||
"OSS graph memory was removed in mem0 v2.0.0. Use search/get_all for "
|
||
"memory retrieval; entity links now affect ranking only."
|
||
),
|
||
}
|
||
|
||
@timed("chat_with_memory")
|
||
async def chat_with_memory(
    self,
    message: str,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    run_id: Optional[str] = None,
    context: Optional[List[Dict[str, str]]] = None,
    # metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Answer ``message`` with memory-augmented context, then store the turn.

    Steps: search Mem0 (scoped by user/agent/run), build a system prompt from
    the top 10 reranked memories, call the configured default model, and add
    the user/assistant exchange back to Mem0 under the same scope.

    Args:
        message: The user's message.
        user_id: User scope for search and storage.
        agent_id: Optional agent scope for search and storage.
        run_id: Optional run scope for search and storage.
        context: Prior chat messages inserted between the system prompt and
            the new user message.

    Returns:
        On success: ``response``, ``memories_used``, ``model_used`` and a
        ``timing`` breakdown in seconds. On any failure: a dict with
        ``error``, an apology ``response``, ``memories_used`` 0 and
        ``model_used`` None (the exception is logged, not raised).
    """
    import time

    try:
        total_start_time = time.time()
        logger.info("Starting chat request", user_id=user_id)

        search_start_time = time.time()
        # Over-fetch for the Cohere reranker (rerank=True), then keep the
        # top 10 reranked memories for the system prompt.
        search_result = self.memory.search(
            query=message,
            filters=_build_filters(user_id, agent_id, run_id),
            top_k=30,
            threshold=0.3,
            rerank=True,
        )
        relevant_memories = search_result.get("results", [])[:10]
        memories_str = "\n".join(
            f"- {entry['memory']}" for entry in relevant_memories
        )
        search_time = time.time() - search_start_time
        logger.debug(
            "Memory search completed",
            search_time_s=round(search_time, 2),
            memories_found=len(relevant_memories),
        )

        prep_start_time = time.time()
        system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}"
        messages = [{"role": "system", "content": system_prompt}]

        if context:
            messages.extend(context)
            logger.debug("Added context messages", context_count=len(context))

        messages.append({"role": "user", "content": message})
        prep_time = time.time() - prep_start_time

        llm_start_time = time.time()
        response = self.openai_client.chat.completions.create(
            model=settings.default_model, messages=messages
        )
        # Strip leading whitespace — reasoning models (minimax-m2) leak
        # blank lines from their reasoning output into visible content.
        # lstrip(), not strip(), preserves intentional trailing whitespace
        # inside content (e.g., inside a code block).
        assistant_response = (response.choices[0].message.content or "").lstrip()
        llm_time = time.time() - llm_start_time
        logger.debug(
            "LLM call completed",
            llm_time_s=round(llm_time, 2),
            model=settings.default_model,
        )

        add_start_time = time.time()
        memory_messages = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": assistant_response},
        ]
        # Store under the same user/agent/run scope that the search above
        # filters on; writing with user_id alone would make agent- or
        # run-scoped turns invisible to that same scoped search later.
        self.memory.add(
            memory_messages,
            user_id=user_id,
            agent_id=agent_id,
            run_id=run_id,
        )
        add_time = time.time() - add_start_time

        total_time = time.time() - total_start_time
        logger.info(
            "Chat request completed",
            user_id=user_id,
            total_time_s=round(total_time, 2),
            search_time_s=round(search_time, 2),
            llm_time_s=round(llm_time, 2),
            add_time_s=round(add_time, 2),
            memories_used=len(relevant_memories),
            model=settings.default_model,
        )

        return {
            "response": assistant_response,
            "memories_used": len(relevant_memories),
            "model_used": settings.default_model,
            "timing": {
                "total": round(total_time, 2),
                "search": round(search_time, 2),
                "llm": round(llm_time, 2),
                "add": round(add_time, 2),
            },
        }

    except Exception as e:
        logger.error(
            "Error in chat_with_memory",
            error=str(e),
            user_id=user_id,
            exc_info=True,
        )
        return {
            "error": str(e),
            "response": "I apologize, but I encountered an error processing your request.",
            "memories_used": 0,
            "model_used": None,
        }
|
||
|
||
# ------------------------------------------------------------------
|
||
# OpenAI-compatible /v1/chat/completions proxy
|
||
# ------------------------------------------------------------------
|
||
# Design: forward client requests to the upstream endpoint verbatim,
|
||
# injecting a "Relevant memories" system message ONLY when the last
|
||
# message is from the user AND no tool flow is in progress. The
|
||
# opinionated chat_with_memory above is NOT used on this path because:
|
||
# (a) clients pick their own model, tools, response_format etc.,
|
||
# (b) we must preserve every upstream field (tool_calls, reasoning_tokens,
|
||
# system_fingerprint, finish_reason), which the openai-python SDK and
|
||
# typed Pydantic responses strip silently,
|
||
# (c) streaming must be real (SSE pass-through), not post-hoc word-chunking.
|
||
#
|
||
# Tool-call safety contract (Memori issue #434 — codified here):
|
||
# - We only PREPEND a system message; never mutate, reorder, or delete
|
||
# existing messages.
|
||
# - We never touch tool messages or tool_call_id values.
|
||
# - We skip injection on tool-result follow-ups (last role != "user").
|
||
# - We always run the post-stream mem0.add even if upstream errors
|
||
# mid-flight (via try/finally).
|
||
async def openai_proxy_completion(
    self,
    request_kwargs: Dict[str, Any],
    user_id: str,
) -> Union[Dict[str, Any], AsyncIterator[bytes]]:
    """Proxy a chat-completions request to the upstream LLM.

    A "Relevant memories" system message is injected only when the last
    message has role ``user`` and no ``tool`` message is present anywhere in
    the conversation. The caller's ``request_kwargs`` dict and its message
    dicts are left untouched; a shallow copy is forwarded.

    Args:
        request_kwargs: The client's request body (must contain a non-empty
            ``messages`` list).
        user_id: Scope for the memory search and the post-completion add.

    Returns:
        - dict (upstream JSON) when stream is falsy
        - async iterator of SSE bytes when stream is truthy

    Raises:
        ValueError: ``messages`` is missing or empty.
        httpx.HTTPStatusError: upstream answered a non-stream request with a
            status >= 400; the upstream response is attached to the error.
    """
    messages = request_kwargs.get("messages") or []
    if not messages:
        raise ValueError("messages array is empty")

    # Injection rule: last role == 'user' AND no prior tool message.
    last_msg = messages[-1] if isinstance(messages[-1], dict) else {}
    has_tool_message = any(
        isinstance(m, dict) and m.get("role") == "tool" for m in messages
    )
    inject_memory = last_msg.get("role") == "user" and not has_tool_message

    relevant: List[Dict[str, Any]] = []
    last_user_text = _extract_text(last_msg.get("content")) if inject_memory else ""
    if last_user_text:
        try:
            # Memory.search is synchronous; run it in a worker thread so the
            # event loop stays free while it works.
            search_result = await asyncio.to_thread(
                self.memory.search,
                query=last_user_text,
                filters=_build_filters(user_id),
                top_k=30,
                threshold=0.3,
                rerank=True,
            )
            relevant = (search_result.get("results") or [])[:10]
        except Exception as e:
            # Deliberate best effort: a failed search must not fail the chat.
            logger.warning(
                "memory search failed; proceeding without injection",
                error=str(e),
                user_id=user_id,
            )
            relevant = []

    if relevant:
        memories_str = "\n".join(f"- {m.get('memory', '')}" for m in relevant)
        mem_block = (
            "Relevant memories about the user "
            "(use when helpful, ignore otherwise):\n" + memories_str
        )
        # Merge into an existing leading system message if present,
        # otherwise prepend a new one. A new list and a new dict are built,
        # so no message object supplied by the caller is modified.
        first = messages[0]
        if isinstance(first, dict) and first.get("role") == "system":
            existing = _extract_text(first.get("content"))
            merged = f"{existing}\n\n{mem_block}" if existing else mem_block
            # Keep any other keys on the system message (e.g. "name").
            messages = [{**first, "content": merged}, *messages[1:]]
        else:
            messages = [{"role": "system", "content": mem_block}, *messages]
        logger.info(
            "memory injected",
            user_id=user_id,
            memory_count=len(relevant),
        )

    # Forward a shallow copy so the caller's dict is not modified.
    payload = {**request_kwargs, "messages": messages}

    # The same base URL is handed to the OpenAI SDK client, which appends
    # "/chat/completions" itself; when the configured base already ends in
    # "/v1", adding another "/v1" would produce ".../v1/v1/chat/completions".
    base_url = settings.openai_base_url.rstrip("/")
    path = "/chat/completions" if base_url.endswith("/v1") else "/v1/chat/completions"
    url = base_url + path
    headers = {
        "Authorization": f"Bearer {settings.openai_api_key}",
        "Content-Type": "application/json",
    }

    if payload.get("stream"):
        return self._stream_proxy(url, payload, headers, messages, user_id)

    resp = await self.async_http.post(url, json=payload, headers=headers)
    if resp.status_code >= 400:
        # The upstream body travels with the exception (``.response``) so the
        # handler can surface it; log a truncated copy for diagnosis.
        logger.warning(
            "upstream returned error",
            status_code=resp.status_code,
            body=resp.text[:500],
            user_id=user_id,
        )
        raise httpx.HTTPStatusError(
            f"Upstream {resp.status_code}", request=resp.request, response=resp
        )
    data = resp.json()

    # Fire-and-forget mem0.add in the background (tool-only responses are
    # skipped inside _post_completion_add). Hold a strong reference until the
    # task finishes so it cannot be garbage-collected mid-run.
    background_tasks = self.__dict__.setdefault("_background_tasks", set())
    task = asyncio.create_task(self._post_completion_add(messages, data, user_id))
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)
    return data
|
||
|
||
async def _stream_proxy(
    self,
    url: str,
    request_kwargs: Dict[str, Any],
    headers: Dict[str, str],
    messages: List[Dict[str, Any]],
    user_id: str,
) -> AsyncIterator[bytes]:
    """Stream upstream SSE bytes verbatim; accumulate content for post-stream add.

    Every upstream chunk is yielded unchanged. A copy is parsed on the side
    to collect ``delta.content`` text and to notice ``delta.tool_calls``.
    Upstream HTTP errors are reported to the client as one SSE ``error``
    event followed by ``[DONE]``. When the stream ends (normally or not),
    the collected text is stored in Mem0 together with the last user
    message, unless tool calls were seen or no text was collected.

    Args:
        url: Upstream chat-completions URL.
        request_kwargs: JSON body to forward.
        headers: HTTP headers to forward.
        messages: The forwarded message list; only its last entry is read,
            for the post-stream add.
        user_id: Scope for the post-stream ``Memory.add``.
    """
    accumulated_content: List[str] = []
    saw_tool_calls = False
    buffer = b""
    try:
        async with self.async_http.stream(
            "POST", url, json=request_kwargs, headers=headers
        ) as upstream:
            if upstream.status_code >= 400:
                body_bytes = await upstream.aread()
                err = {
                    "error": {
                        "message": (
                            f"Upstream {upstream.status_code}: "
                            f"{body_bytes.decode('utf-8', errors='replace')[:500]}"
                        )
                    }
                }
                yield f"data: {json.dumps(err)}\n\n".encode("utf-8")
                yield b"data: [DONE]\n\n"
                return

            async for chunk in upstream.aiter_bytes():
                if not chunk:
                    continue
                # Forward bytes verbatim — preserves the exact SSE wire format.
                yield chunk

                # Side-channel parse for content/tool_calls accumulation.
                # Line endings are normalised on the private copy only, so
                # CRLF-delimited events are recognised too; a CR left at the
                # end of one chunk is paired with its LF on the next pass.
                buffer = (buffer + chunk).replace(b"\r\n", b"\n")
                while b"\n\n" in buffer:
                    event_bytes, buffer = buffer.split(b"\n\n", 1)
                    for raw_line in event_bytes.split(b"\n"):
                        # The space after "data:" is optional in SSE.
                        if not raw_line.startswith(b"data:"):
                            continue
                        payload = raw_line[5:].strip()
                        if not payload or payload == b"[DONE]":
                            continue
                        try:
                            obj = json.loads(payload)
                            delta = (
                                (obj.get("choices") or [{}])[0].get("delta") or {}
                            )
                            content_piece = delta.get("content")
                            if isinstance(content_piece, str):
                                accumulated_content.append(content_piece)
                            if delta.get("tool_calls"):
                                saw_tool_calls = True
                        except (
                            ValueError,  # JSONDecodeError and UnicodeDecodeError
                            KeyError,
                            IndexError,
                            TypeError,
                            AttributeError,  # JSON that is not the expected dict shape
                        ):
                            # Best effort only: the chunk has already been
                            # forwarded, so an unparseable event must never
                            # interrupt the client's stream.
                            continue
    except httpx.HTTPError as e:
        logger.warning("upstream stream error", error=str(e), user_id=user_id)
        err = {"error": {"message": f"Upstream stream error: {e}"}}
        yield f"data: {json.dumps(err)}\n\n".encode("utf-8")
        yield b"data: [DONE]\n\n"
    finally:
        # Per Memori #434: post-stream mem0.add must run even on mid-stream
        # error so partial content is captured. Skipped for tool-only or
        # empty responses.
        full = "".join(accumulated_content).lstrip()
        if full and not saw_tool_calls:
            last_user_text = (
                _extract_text(messages[-1].get("content"))
                if messages
                and isinstance(messages[-1], dict)
                and messages[-1].get("role") == "user"
                else None
            )
            if last_user_text:
                try:
                    # Memory.add is synchronous; run it in a worker thread so
                    # the event loop is not blocked. The generator itself
                    # only finishes once this await completes.
                    await asyncio.to_thread(
                        self.memory.add,
                        [
                            {"role": "user", "content": last_user_text},
                            {"role": "assistant", "content": full},
                        ],
                        user_id=user_id,
                    )
                except Exception as e:
                    logger.warning(
                        "post-stream mem0.add failed",
                        error=str(e),
                        user_id=user_id,
                    )
|
||
|
||
async def _post_completion_add(
    self,
    messages: List[Dict[str, Any]],
    response_data: Dict[str, Any],
    user_id: str,
) -> None:
    """Store the final user/assistant exchange after a non-stream completion.

    Nothing is stored when the first choice has no non-empty string content
    (tool-only or non-text completion) or when the last request message is
    not a user message with extractable text. Any failure is logged as a
    warning and swallowed.
    """
    try:
        first_choice = (response_data.get("choices") or [{}])[0]
        reply = (first_choice.get("message") or {}).get("content")
        if not (reply and isinstance(reply, str)):
            return  # tool-only or non-text completion — skip
        reply = reply.lstrip()

        user_text = None
        if messages:
            tail = messages[-1]
            if isinstance(tail, dict) and tail.get("role") == "user":
                user_text = _extract_text(tail.get("content"))
        if not user_text:
            return

        exchange = [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": reply},
        ]
        await asyncio.to_thread(self.memory.add, exchange, user_id=user_id)
    except Exception as exc:
        logger.warning(
            "post-completion mem0.add failed", error=str(exc), user_id=user_id
        )
|
||
|
||
async def health_check(self) -> Dict[str, str]:
|
||
"""Basic health check - just connectivity."""
|
||
status = {}
|
||
|
||
# Check custom OpenAI endpoint
|
||
try:
|
||
models = self.openai_client.models.list()
|
||
status["openai_endpoint"] = "healthy"
|
||
except Exception as e:
|
||
status["openai_endpoint"] = f"unhealthy: {str(e)}"
|
||
|
||
# Check Mem0 memory
|
||
try:
|
||
self.memory.search(
|
||
query="test", filters={"user_id": "health_check"}, top_k=1
|
||
)
|
||
status["mem0_memory"] = "healthy"
|
||
except Exception as e:
|
||
status["mem0_memory"] = f"unhealthy: {str(e)}"
|
||
|
||
return status
|
||
|
||
|
||
# Global instance
|
||
mem0_manager = Mem0Manager()
|