312 lines
No EOL
12 KiB
Python
312 lines
No EOL
12 KiB
Python
"""Ultra-minimal Mem0 Manager - Pure Mem0 + Custom OpenAI Endpoint Only."""
|
|
|
|
import logging
|
|
from typing import Dict, List, Optional, Any
|
|
from mem0 import Memory
|
|
from openai import OpenAI
|
|
|
|
from config import settings
|
|
from monitoring import timed
|
|
|
|
# Module-scoped logger named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class Mem0Manager:
    """Bridge a custom OpenAI-compatible endpoint with Mem0.

    Memory extraction, deduplication and retrieval are left to Mem0. This
    class builds the Mem0 configuration, adapts caller arguments and shapes
    the dictionaries returned to callers.

    NOTE(review): ``self.memory`` and ``self.openai_client`` are called
    synchronously (without ``await``) from the ``async`` methods below, so
    every such call blocks the event loop until it returns. Consider
    ``asyncio.to_thread`` or async clients - TODO confirm the configured
    stores tolerate concurrent use before changing this.
    """

    # Defaults for the Ollama embedder. The embedding dimensionality must be
    # identical for the embedder and the pgvector store, so a single value
    # feeds both configs.
    _DEFAULT_EMBEDDER_MODEL = "hf.co/Qwen/Qwen3-Embedding-0.6B-GGUF:Q8_0"
    _DEFAULT_OLLAMA_BASE_URL = "http://host.docker.internal:11434"
    _DEFAULT_EMBEDDING_DIMS = 1024

    def __init__(self):
        """Create the Mem0 ``Memory`` instance and the OpenAI client.

        LLM, pgvector and Neo4j values are read from ``settings``. Embedder
        values fall back to the class defaults unless ``settings`` defines a
        non-``None`` ``embedder_model``, ``ollama_base_url`` or
        ``embedding_dims`` attribute.
        """
        embedding_dims = self._setting_or_default(
            "embedding_dims", self._DEFAULT_EMBEDDING_DIMS
        )
        config = {
            # NOTE(review): the "graph_store" section is what configures graph
            # memory; confirm whether the installed mem0 version reads this key.
            "enable_graph": True,
            "llm": {
                "provider": "openai",
                "config": {
                    "model": settings.default_model,
                    "api_key": settings.openai_api_key,
                    "openai_base_url": settings.openai_base_url,
                },
            },
            "embedder": {
                "provider": "ollama",
                "config": {
                    "model": self._setting_or_default(
                        "embedder_model", self._DEFAULT_EMBEDDER_MODEL
                    ),
                    "ollama_base_url": self._setting_or_default(
                        "ollama_base_url", self._DEFAULT_OLLAMA_BASE_URL
                    ),
                    "embedding_dims": embedding_dims,
                },
            },
            "vector_store": {
                "provider": "pgvector",
                "config": {
                    "dbname": settings.postgres_db,
                    "user": settings.postgres_user,
                    "password": settings.postgres_password,
                    "host": settings.postgres_host,
                    "port": settings.postgres_port,
                    "embedding_model_dims": embedding_dims,
                },
            },
            "graph_store": {
                "provider": "neo4j",
                "config": {
                    "url": settings.neo4j_uri,
                    "username": settings.neo4j_username,
                    "password": settings.neo4j_password,
                },
            },
        }

        self.memory = Memory.from_config(config)
        self.openai_client = OpenAI(
            api_key=settings.openai_api_key,
            base_url=settings.openai_base_url,
        )
        logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _setting_or_default(name: str, default: Any) -> Any:
        """Return ``settings.<name>`` when it exists and is not ``None``, else *default*."""
        value = getattr(settings, name, None)
        return default if value is None else value

    @staticmethod
    def _to_message_dict(message: Any) -> Any:
        """Convert a model-like chat message to a plain dict.

        Objects exposing ``model_dump()`` are converted with it; otherwise
        objects exposing ``dict()`` use that. Anything else, including plain
        dicts, is returned unchanged.
        """
        if hasattr(message, "model_dump"):
            return message.model_dump()
        if hasattr(message, "dict"):
            return message.dict()
        return message

    @staticmethod
    def _results_of(response: Any) -> List[Dict[str, Any]]:
        """Return the memory entries from a Mem0 search/get_all response.

        A dict response yields its ``"results"`` list (empty when absent). A
        bare list is returned as-is, to be defensive about older output formats.
        """
        if isinstance(response, list):
            return response
        return response.get("results") or []

    @staticmethod
    def _log_ignored(method: str, **params: Any) -> None:
        """Debug-log caller-supplied parameters that *method* does not forward to Mem0."""
        supplied = sorted(
            name for name, value in params.items()
            if value is not None and value is not False
        )
        if supplied:
            logger.debug("%s: ignoring unsupported parameters %s", method, supplied)

    # ------------------------------------------------------------------
    # Mem0 passthrough operations
    # ------------------------------------------------------------------

    @timed("add_memories")
    async def add_memories(
        self,
        messages: List[Dict[str, str]],
        user_id: str = "default",
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Store a conversation in Mem0 for ``user_id``.

        Args:
            messages: Chat messages as dicts, or model objects convertible via
                ``model_dump()`` / ``dict()``.
            user_id: Mem0 user scope for the new memories.
            agent_id: Optional; recorded in the memory metadata (it is not
                passed to Mem0 as a scoping argument).
            run_id: Optional; recorded in the memory metadata likewise.
            metadata: Extra metadata stored with the memories. The caller's
                dict is copied, never mutated.

        Returns:
            ``added_memories`` (Mem0's add result, wrapped in a list unless it
            already is one), a status ``message`` and the ``hierarchy`` ids.

        Raises:
            Exception: any error from Mem0 is logged and re-raised.
        """
        try:
            formatted_messages = [self._to_message_dict(msg) for msg in messages]

            # Copy first: `metadata or {}` would alias the caller's dict and the
            # assignments below would leak into it.
            combined_metadata = dict(metadata or {})
            if agent_id:
                combined_metadata["agent_id"] = agent_id
            if run_id:
                combined_metadata["run_id"] = run_id

            result = self.memory.add(
                formatted_messages,
                user_id=user_id,
                metadata=combined_metadata or None,
            )

            return {
                "added_memories": result if isinstance(result, list) else [result],
                "message": "Memories added successfully",
                "hierarchy": {"user_id": user_id, "agent_id": agent_id, "run_id": run_id},
            }
        except Exception as e:
            logger.exception("Error adding memories: %s", e)
            raise

    @timed("search_memories")
    async def search_memories(
        self,
        query: str,
        user_id: str = "default",
        limit: int = 5,
        threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
        keyword_search: bool = False,
        rerank: bool = False,
        filter_memories: bool = False,
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Semantic search over a user's memories.

        Only ``query``, ``user_id`` and ``limit`` are forwarded to Mem0. The
        remaining parameters are accepted for API compatibility but are
        currently ignored (and debug-logged when supplied).

        Returns:
            ``memories`` (Mem0 result entries), ``total_count`` and the echoed
            ``query``. An empty/whitespace query short-circuits with no results
            and an explanatory ``note``.

        Raises:
            Exception: any error from Mem0 is logged and re-raised.
        """
        try:
            self._log_ignored(
                "search_memories",
                threshold=threshold,
                filters=filters,
                keyword_search=keyword_search,
                rerank=rerank,
                filter_memories=filter_memories,
                agent_id=agent_id,
                run_id=run_id,
            )

            # Minimal empty query protection for API compatibility
            if not query or not query.strip():
                return {"memories": [], "total_count": 0, "query": query, "note": "Empty query provided, no results returned. Use a specific query to search memories."}

            hits = self._results_of(
                self.memory.search(query=query, user_id=user_id, limit=limit)
            )
            return {"memories": hits, "total_count": len(hits), "query": query}
        except Exception as e:
            logger.exception("Error searching memories: %s", e)
            raise

    async def get_user_memories(
        self,
        user_id: str,
        limit: int = 10,
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Return up to ``limit`` stored memories for ``user_id``.

        ``agent_id``, ``run_id`` and ``filters`` are accepted for API
        compatibility but are currently not forwarded to Mem0 (they are
        debug-logged when supplied).

        Raises:
            Exception: any error from Mem0 is logged and re-raised.
        """
        try:
            self._log_ignored(
                "get_user_memories", agent_id=agent_id, run_id=run_id, filters=filters
            )
            return self._results_of(self.memory.get_all(user_id=user_id, limit=limit))
        except Exception as e:
            logger.exception("Error getting user memories: %s", e)
            raise

    @timed("update_memory")
    async def update_memory(
        self,
        memory_id: str,
        content: str,
    ) -> Dict[str, Any]:
        """Replace the text of memory ``memory_id`` with ``content``.

        Returns:
            A status ``message`` and Mem0's raw update ``result``.

        Raises:
            Exception: any error from Mem0 is logged and re-raised.
        """
        try:
            result = self.memory.update(memory_id=memory_id, data=content)
            return {"message": "Memory updated successfully", "result": result}
        except Exception as e:
            logger.exception("Error updating memory: %s", e)
            raise

    @timed("delete_memory")
    async def delete_memory(self, memory_id: str) -> Dict[str, Any]:
        """Delete a single memory by id.

        Raises:
            Exception: any error from Mem0 is logged and re-raised.
        """
        try:
            self.memory.delete(memory_id=memory_id)
            return {"message": "Memory deleted successfully"}
        except Exception as e:
            logger.exception("Error deleting memory: %s", e)
            raise

    async def delete_user_memories(self, user_id: str) -> Dict[str, Any]:
        """Delete every memory stored for ``user_id``.

        Raises:
            Exception: any error from Mem0 is logged and re-raised.
        """
        try:
            self.memory.delete_all(user_id=user_id)
            return {"message": "All user memories deleted successfully"}
        except Exception as e:
            logger.exception("Error deleting user memories: %s", e)
            raise

    async def get_memory_history(self, memory_id: str) -> Dict[str, Any]:
        """Return Mem0's change history for ``memory_id``.

        Raises:
            Exception: any error from Mem0 is logged and re-raised.
        """
        try:
            history = self.memory.history(memory_id=memory_id)
            return {
                "memory_id": memory_id,
                "history": history,
                "message": "Memory history retrieved successfully",
            }
        except Exception as e:
            logger.exception("Error getting memory history: %s", e)
            raise

    async def get_graph_relationships(self, user_id: str, limit: int = 50) -> Dict[str, Any]:
        """Return graph relations and their endpoint entities for ``user_id``.

        Args:
            user_id: Mem0 user scope.
            limit: Maximum number of entries requested from ``get_all``
                (defaults to the previously hard-coded 50).

        Returns:
            ``relationships`` (Mem0's ``relations`` list), ``entities`` (unique
            ``source``/``target`` names in first-seen order), ``user_id``,
            ``total_memories`` (entries returned, so at most ``limit``) and
            ``total_relationships``. On any error the same keys are returned
            empty plus an ``error`` string; nothing is raised.
        """
        try:
            result = self.memory.get_all(user_id=user_id, limit=limit)

            relationships = (result.get("relations") or []) if isinstance(result, dict) else []

            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # entity list is deterministic (a set's iteration order is not).
            entity_names = dict.fromkeys(
                rel[key]
                for rel in relationships
                for key in ("source", "target")
                if key in rel
            )
            entities = [{"name": name} for name in entity_names]

            return {
                "relationships": relationships,
                "entities": entities,
                "user_id": user_id,
                "total_memories": len(self._results_of(result)),
                "total_relationships": len(relationships),
            }
        except Exception as e:
            logger.exception("Error getting graph relationships: %s", e)
            # Return empty but structured response on error
            return {
                "relationships": [],
                "entities": [],
                "user_id": user_id,
                "total_memories": 0,
                "total_relationships": 0,
                "error": str(e),
            }

    @timed("chat_with_memory")
    async def chat_with_memory(
        self,
        message: str,
        user_id: str = "default",
        context: Optional[List[Dict[str, str]]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Answer ``message`` using the user's memories, then store the exchange.

        Up to three relevant memories are inlined into a system prompt. Any
        ``context`` turns are sent to the model between the system prompt and
        the new user message. Only the current system/user/assistant exchange
        is added back to Mem0 (with a copy of ``metadata``, if given).

        Returns:
            ``response``, ``memories_used`` and ``model_used``. On any error a
            fallback response with an ``error`` string is returned; nothing is
            raised.
        """
        try:
            relevant_memories = self._results_of(
                self.memory.search(query=message, user_id=user_id, limit=3)
            )
            memories_str = "\n".join(f"- {entry['memory']}" for entry in relevant_memories)

            system_message = {
                "role": "system",
                "content": f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}",
            }
            user_message = {"role": "user", "content": message}
            history = [self._to_message_dict(turn) for turn in (context or [])]

            response = self.openai_client.chat.completions.create(
                model=settings.default_model,
                messages=[system_message, *history, user_message],
            )
            assistant_response = response.choices[0].message.content

            # Prior context turns are not re-added; only this exchange is stored.
            self.memory.add(
                [system_message, user_message, {"role": "assistant", "content": assistant_response}],
                user_id=user_id,
                metadata=dict(metadata) if metadata else None,
            )

            return {
                "response": assistant_response,
                "memories_used": len(relevant_memories),
                "model_used": settings.default_model,
            }
        except Exception as e:
            logger.exception("Error in chat_with_memory: %s", e)
            return {
                "error": str(e),
                "response": "I apologize, but I encountered an error processing your request.",
                "memories_used": 0,
                "model_used": None,
            }

    async def health_check(self) -> Dict[str, str]:
        """Probe the OpenAI endpoint and Mem0 search; report per-component status.

        Each value is ``"healthy"`` or ``"unhealthy: <error>"``; nothing is raised.
        """
        status = {}

        # Check custom OpenAI endpoint
        try:
            self.openai_client.models.list()
            status["openai_endpoint"] = "healthy"
        except Exception as e:
            status["openai_endpoint"] = f"unhealthy: {e}"

        # Check Mem0 memory
        try:
            self.memory.search(query="test", user_id="health_check", limit=1)
            status["mem0_memory"] = "healthy"
        except Exception as e:
            status["mem0_memory"] = f"unhealthy: {e}"

        return status
|
|
|
|
|
|
# Shared module-level manager. Importing this module runs Mem0Manager.__init__,
# which builds the Mem0 Memory and the OpenAI client.
mem0_manager: Mem0Manager = Mem0Manager()