knowledge-base/backend/mem0_manager.py
"""Ultra-minimal Mem0 Manager - Pure Mem0 + Custom OpenAI Endpoint Only."""
import logging
from typing import Dict, List, Optional, Any
from datetime import datetime
from mem0 import Memory
from openai import OpenAI
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
    before_sleep_log,
)
from config import settings
from monitoring import timed
logger = logging.getLogger(__name__)

# Retry decorator for database operations (Qdrant, Neo4j)
db_retry = retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    retry=retry_if_exception_type((ConnectionError, TimeoutError, OSError)),
    before_sleep=before_sleep_log(logger, logging.WARNING),
    reraise=True,
)
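
# Usage sketch (hypothetical function, not part of this module): tenacity's
# @retry transparently wraps coroutine functions, so db_retry can guard any
# async call that may hit a transient Qdrant/Neo4j network failure:
#
#     @db_retry
#     async def ping_vector_store() -> None:
#         # Raises ConnectionError on transient failure; db_retry backs off
#         # exponentially (1s minimum, 10s cap) across up to 3 attempts,
#         # then re-raises because reraise=True.
#         ...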

# Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility
from mem0.llms.openai import OpenAILLM

_original_generate_response = OpenAILLM.generate_response


def patched_generate_response(
    self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs
):
    # Remove 'store' parameter as LiteLLM doesn't support it
    if hasattr(self.config, "store"):
        self.config.store = None
    # Remove 'top_p' to avoid conflict with temperature for Claude models
    if hasattr(self.config, "top_p"):
        self.config.top_p = None
    return _original_generate_response(
        self, messages, response_format, tools, tool_choice, **kwargs
    )


OpenAILLM.generate_response = patched_generate_response
logger.info("Applied LiteLLM compatibility patch: disabled 'store' parameter")


class Mem0Manager:
"""
Ultra-minimal manager that bridges custom OpenAI endpoint with pure Mem0.
No custom logic - let Mem0 handle all memory intelligence.
"""
def __init__(self):
# Custom endpoint configuration with graph memory enabled
logger.info(
"Initializing ultra-minimal Mem0Manager with custom endpoint with settings:",
settings,
)
config = {
"version": "v1.1",
"enable_graph": True,
"llm": {
"provider": "openai",
"config": {
"model": settings.default_model,
"api_key": settings.openai_api_key,
"openai_base_url": settings.openai_base_url,
"temperature": 0.1,
"top_p": None,
},
},
"embedder": {
"provider": "ollama",
"config": {
"model": settings.embedding_model,
"ollama_base_url": settings.ollama_base_url,
"embedding_dims": settings.embedding_dims,
},
},
"vector_store": {
"provider": "qdrant",
"config": {
"collection_name": settings.qdrant_collection_name,
"host": settings.qdrant_host,
"port": settings.qdrant_port,
"embedding_model_dims": settings.embedding_dims,
"on_disk": True,
},
},
"graph_store": {
"provider": "neo4j",
"config": {
"url": settings.neo4j_uri,
"username": settings.neo4j_username,
"password": settings.neo4j_password,
},
},
"reranker": {
"provider": "cohere",
"config": {
"api_key": settings.cohere_api_key,
"model": "rerank-english-v3.0",
"top_n": 10,
},
},
}
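
        # The config above assumes `config.settings` exposes these fields
        # (all referenced in this module; ollama_base_url is supplied via
        # the OLLAMA_BASE_URL environment variable): default_model,
        # openai_api_key, openai_base_url, embedding_model, ollama_base_url,
        # embedding_dims, qdrant_collection_name, qdrant_host, qdrant_port,
        # neo4j_uri, neo4j_username, neo4j_password, cohere_api_key.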
        self.memory = Memory.from_config(config)
        self.openai_client = OpenAI(
            api_key=settings.openai_api_key,
            base_url=settings.openai_base_url,
            timeout=60.0,  # 60 second timeout for LLM calls
            max_retries=2,  # Retry failed requests up to 2 times
        )
        logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")

    # Pure passthrough methods - no custom logic
    @db_retry
    @timed("add_memories")
    async def add_memories(
        self,
        messages: List[Dict[str, str]],
        user_id: Optional[str] = "default",
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Add memories - simplified native Mem0 pattern (10 lines vs 45)."""
        try:
            # Convert ChatMessage objects to dict if needed
            formatted_messages = []
            for msg in messages:
                if hasattr(msg, "dict"):
                    formatted_messages.append(msg.dict())
                else:
                    formatted_messages.append(msg)
            # Auto-enhance metadata for better memory quality
            combined_metadata = metadata or {}
            auto_metadata = {
                "timestamp": datetime.now().isoformat(),
                "source": "chat_conversation",
                "message_count": len(formatted_messages),
                "auto_generated": True,
            }
            # Merge user metadata with auto metadata (user metadata takes precedence)
            enhanced_metadata = {**auto_metadata, **combined_metadata}
            # Direct Mem0 add with enhanced metadata
            result = self.memory.add(
                formatted_messages,
                user_id=user_id,
                agent_id=agent_id,
                run_id=run_id,
                metadata=enhanced_metadata,
            )
            return {
                "added_memories": result if isinstance(result, list) else [result],
                "message": "Memories added successfully",
                "hierarchy": {
                    "user_id": user_id,
                    "agent_id": agent_id,
                    "run_id": run_id,
                },
            }
        except Exception as e:
            logger.error(f"Error adding memories: {e}")
            raise
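
    # Usage sketch (values hypothetical):
    #
    #     await mem0_manager.add_memories(
    #         messages=[
    #             {"role": "user", "content": "I prefer dark roast coffee"},
    #             {"role": "assistant", "content": "Noted - dark roast it is."},
    #         ],
    #         user_id="alice",
    #         metadata={"source": "preferences_ui"},  # overrides auto "source"
    #     )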

    @db_retry
    @timed("search_memories")
    async def search_memories(
        self,
        query: str,
        user_id: Optional[str] = "default",
        limit: int = 5,
        threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
        # keyword_search: bool = False,
        # rerank: bool = False,
        # filter_memories: bool = False,
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Search memories - native Mem0 pattern."""
        try:
            # Minimal empty query protection for API compatibility
            if not query or query.strip() == "":
                return {
                    "memories": [],
                    "total_count": 0,
                    "query": query,
                    "note": "Empty query provided, no results returned. Use a specific query to search memories.",
                }
            # Direct Mem0 search - trust native handling
            result = self.memory.search(
                query=query,
                user_id=user_id,
                agent_id=agent_id,
                run_id=run_id,
                limit=limit,
                threshold=threshold,
                filters=filters,
            )
            return {
                "memories": result.get("results", []),
                "total_count": len(result.get("results", [])),
                "query": query,
            }
        except Exception as e:
            logger.error(f"Error searching memories: {e}")
            raise
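
    # Usage sketch (values hypothetical) - `threshold` and `filters` pass
    # straight through to Mem0's native search:
    #
    #     hits = await mem0_manager.search_memories(
    #         query="coffee preferences", user_id="alice", limit=3, threshold=0.5
    #     )
    #     for m in hits["memories"]:
    #         print(m.get("memory"), m.get("score"))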

    @db_retry
    async def get_user_memories(
        self,
        user_id: str,
        limit: int = 10,
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Get all memories for a user - native Mem0 pattern."""
        try:
            # Direct Mem0 get_all call - trust native parameter handling
            result = self.memory.get_all(
                user_id=user_id,
                limit=limit,
                agent_id=agent_id,
                run_id=run_id,
                filters=filters,
            )
            return result.get("results", [])
        except Exception as e:
            logger.error(f"Error getting user memories: {e}")
            raise

    @db_retry
    async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
        """Get a single memory by ID. Returns None if not found."""
        try:
            result = self.memory.get(memory_id=memory_id)
            return result
        except Exception as e:
            logger.debug(f"Memory {memory_id} not found or error: {e}")
            return None

    async def verify_memory_ownership(self, memory_id: str, user_id: str) -> bool:
        """Check if a memory belongs to a user. O(1) instead of O(n)."""
        memory = await self.get_memory(memory_id)
        if memory is None:
            return False
        return memory.get("user_id") == user_id
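
    # Usage sketch (IDs hypothetical) - a single Mem0 get() per check is what
    # makes this O(1), versus paging through a user's full memory list:
    #
    #     if not await mem0_manager.verify_memory_ownership(mem_id, "alice"):
    #         raise PermissionError("memory does not belong to this user")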

    @db_retry
    @timed("update_memory")
    async def update_memory(
        self,
        memory_id: str,
        content: str,
    ) -> Dict[str, Any]:
        """Update memory - pure Mem0 passthrough."""
        try:
            result = self.memory.update(memory_id=memory_id, data=content)
            return {"message": "Memory updated successfully", "result": result}
        except Exception as e:
            logger.error(f"Error updating memory: {e}")
            raise

    @db_retry
    @timed("delete_memory")
    async def delete_memory(self, memory_id: str) -> Dict[str, Any]:
        """Delete memory - pure Mem0 passthrough."""
        try:
            self.memory.delete(memory_id=memory_id)
            return {"message": "Memory deleted successfully"}
        except Exception as e:
            logger.error(f"Error deleting memory: {e}")
            raise

    async def delete_user_memories(self, user_id: Optional[str]) -> Dict[str, Any]:
        """Delete all user memories - pure Mem0 passthrough."""
        try:
            self.memory.delete_all(user_id=user_id)
            return {"message": "All user memories deleted successfully"}
        except Exception as e:
            logger.error(f"Error deleting user memories: {e}")
            raise

    async def get_memory_history(self, memory_id: str) -> Dict[str, Any]:
        """Get memory change history - pure Mem0 passthrough."""
        try:
            history = self.memory.history(memory_id=memory_id)
            return {
                "memory_id": memory_id,
                "history": history,
                "message": "Memory history retrieved successfully",
            }
        except Exception as e:
            logger.error(f"Error getting memory history: {e}")
            raise

    async def get_graph_relationships(
        self,
        user_id: Optional[str],
        agent_id: Optional[str],
        run_id: Optional[str],
        limit: int = 50,
    ) -> Dict[str, Any]:
        """Get graph relationships - using correct Mem0 get_all() method."""
        try:
            # Use get_all() to retrieve memories with graph relationships
            result = self.memory.get_all(
                user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit
            )
            # Extract relationships from Mem0's response structure
            relationships = result.get("relations", [])
            # Derive entities from the relationship endpoints
            entities = []
            if "results" in result:
                entity_set = set()
                for rel in relationships:
                    if "source" in rel:
                        entity_set.add(rel["source"])
                    if "target" in rel:
                        entity_set.add(rel["target"])
                entities = [{"name": entity} for entity in entity_set]
            return {
                "relationships": relationships,
                "entities": entities,
                "user_id": user_id,
                "agent_id": agent_id,
                "run_id": run_id,
                "total_memories": len(result.get("results", [])),
                "total_relationships": len(relationships),
            }
        except Exception as e:
            logger.error(f"Error getting graph relationships: {e}")
            # Return empty but structured response on error
            return {
                "relationships": [],
                "entities": [],
                "user_id": user_id,
                "agent_id": agent_id,
                "run_id": run_id,
                "total_memories": 0,
                "total_relationships": 0,
                "error": str(e),
            }
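
    # Usage sketch (IDs hypothetical) - returns Mem0's "relations" plus
    # entities derived from the relationship endpoints:
    #
    #     graph = await mem0_manager.get_graph_relationships(
    #         user_id="alice", agent_id=None, run_id=None, limit=25
    #     )
    #     print(graph["total_relationships"],
    #           [e["name"] for e in graph["entities"]])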
@timed("chat_with_memory")
async def chat_with_memory(
self,
message: str,
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
context: Optional[List[Dict[str, str]]] = None,
# metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Chat with memory - native Mem0 pattern with detailed timing."""
import time
try:
total_start_time = time.time()
logger.info("Starting chat request", user_id=user_id)
search_start_time = time.time()
search_result = self.memory.search(
query=message,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
limit=10,
threshold=0.3,
)
relevant_memories = search_result.get("results", [])
memories_str = "\n".join(
f"- {entry['memory']}" for entry in relevant_memories
)
search_time = time.time() - search_start_time
logger.debug(
"Memory search completed",
search_time_s=round(search_time, 2),
memories_found=len(relevant_memories),
)
prep_start_time = time.time()
system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}"
messages = [{"role": "system", "content": system_prompt}]
if context:
messages.extend(context)
logger.debug("Added context messages", context_count=len(context))
messages.append({"role": "user", "content": message})
prep_time = time.time() - prep_start_time
llm_start_time = time.time()
response = self.openai_client.chat.completions.create(
model=settings.default_model, messages=messages
)
assistant_response = response.choices[0].message.content
llm_time = time.time() - llm_start_time
logger.debug(
"LLM call completed",
llm_time_s=round(llm_time, 2),
model=settings.default_model,
)
add_start_time = time.time()
memory_messages = [
{"role": "user", "content": message},
{"role": "assistant", "content": assistant_response},
]
self.memory.add(memory_messages, user_id=user_id)
add_time = time.time() - add_start_time
total_time = time.time() - total_start_time
logger.info(
"Chat request completed",
user_id=user_id,
total_time_s=round(total_time, 2),
search_time_s=round(search_time, 2),
llm_time_s=round(llm_time, 2),
add_time_s=round(add_time, 2),
memories_used=len(relevant_memories),
model=settings.default_model,
)
return {
"response": assistant_response,
"memories_used": len(relevant_memories),
"model_used": settings.default_model,
"timing": {
"total": round(total_time, 2),
"search": round(search_time, 2),
"llm": round(llm_time, 2),
"add": round(add_time, 2),
},
}
except Exception as e:
logger.error(
"Error in chat_with_memory",
error=str(e),
user_id=user_id,
exc_info=True,
)
return {
"error": str(e),
"response": "I apologize, but I encountered an error processing your request.",
"memories_used": 0,
"model_used": None,
}
async def health_check(self) -> Dict[str, str]:
"""Basic health check - just connectivity."""
status = {}
# Check custom OpenAI endpoint
try:
models = self.openai_client.models.list()
status["openai_endpoint"] = "healthy"
except Exception as e:
status["openai_endpoint"] = f"unhealthy: {str(e)}"
# Check Mem0 memory
try:
self.memory.search(query="test", user_id="health_check", limit=1)
status["mem0_memory"] = "healthy"
except Exception as e:
status["mem0_memory"] = f"unhealthy: {str(e)}"
return status
# Global instance
mem0_manager = Mem0Manager()
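

# Usage sketch (hypothetical smoke test; assumes live Qdrant, Neo4j, Ollama,
# and the OpenAI-compatible endpoint configured in `config.settings`). The
# public methods are coroutines, so they must be driven from an event loop:
if __name__ == "__main__":
    import asyncio

    async def _smoke_test() -> None:
        # Connectivity check first, then one full chat round-trip.
        print(await mem0_manager.health_check())
        reply = await mem0_manager.chat_with_memory(
            message="What do you remember about me?", user_id="smoke_test_user"
        )
        print(reply.get("response"), reply.get("timing"))

    asyncio.run(_smoke_test())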