production improvements: configurable embeddings, v1.1, O(1) ownership, retries
- Make Ollama URL configurable via OLLAMA_BASE_URL env var
- Add version: v1.1 to Mem0 config (required for latest features)
- Make embedding model and dimensions configurable
- Fix ownership check: O(1) lookup instead of fetching 10k records
- Add tenacity retry logic for database operations
parent 50edce2d3c
commit a228780146
4 changed files with 120 additions and 44 deletions
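Before the diff itself, a standalone sketch of what the newly configurable embedder settings look like in practice. This is not the project's config.py; the env values below are illustrative only (a local Ollama with a smaller embedding model), chosen to show that OLLAMA_BASE_URL, EMBEDDING_MODEL and EMBEDDING_DIMS now drive the embedder instead of hard-coded values.

# Standalone sketch (not the project's config.py): how pydantic-settings
# reads the new embedder variables from the environment.
import os

from pydantic import AliasChoices, Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class EmbedderSettings(BaseSettings):
    model_config = SettingsConfigDict(case_sensitive=False, extra="ignore")

    ollama_base_url: str = Field(
        default="http://host.docker.internal:11434",
        validation_alias=AliasChoices("OLLAMA_BASE_URL", "ollama_base_url"),
    )
    embedding_model: str = Field(
        default="qwen3-embedding:4b-q8_0",
        validation_alias=AliasChoices("EMBEDDING_MODEL", "embedding_model"),
    )
    embedding_dims: int = Field(
        default=2560, validation_alias=AliasChoices("EMBEDDING_DIMS", "embedding_dims")
    )


# Illustrative values only -- point at a local Ollama and a smaller model.
os.environ["OLLAMA_BASE_URL"] = "http://localhost:11434"
os.environ["EMBEDDING_MODEL"] = "nomic-embed-text"
os.environ["EMBEDDING_DIMS"] = "768"

s = EmbedderSettings()
print(s.ollama_base_url, s.embedding_model, s.embedding_dims)

Anything left unset falls back to the defaults added in config.py (host.docker.internal:11434, qwen3-embedding:4b-q8_0, 2560 dims).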
@@ -11,39 +11,87 @@ class Settings(BaseSettings):
     """Application settings loaded from environment variables."""

     model_config = SettingsConfigDict(
-        env_file=".env",
-        case_sensitive=False,
-        extra='ignore'
+        env_file=".env", case_sensitive=False, extra="ignore"
     )

     # API Configuration
     # Accept both OPENAI_API_KEY (from docker-compose) and OPENAI_COMPAT_API_KEY (from direct .env)
-    openai_api_key: str = Field(validation_alias=AliasChoices('OPENAI_API_KEY', 'OPENAI_COMPAT_API_KEY', 'openai_api_key'))
-    openai_base_url: str = Field(validation_alias=AliasChoices('OPENAI_BASE_URL', 'OPENAI_COMPAT_BASE_URL', 'openai_base_url'))
-    cohere_api_key: str = Field(validation_alias=AliasChoices('COHERE_API_KEY', 'cohere_api_key'))
+    openai_api_key: str = Field(
+        validation_alias=AliasChoices(
+            "OPENAI_API_KEY", "OPENAI_COMPAT_API_KEY", "openai_api_key"
+        )
+    )
+    openai_base_url: str = Field(
+        validation_alias=AliasChoices(
+            "OPENAI_BASE_URL", "OPENAI_COMPAT_BASE_URL", "openai_base_url"
+        )
+    )
+    cohere_api_key: str = Field(
+        validation_alias=AliasChoices("COHERE_API_KEY", "cohere_api_key")
+    )

     # Database Configuration
-    qdrant_host: str = Field(default="localhost", validation_alias=AliasChoices('QDRANT_HOST', 'qdrant_host'))
-    qdrant_port: int = Field(default=6333, validation_alias=AliasChoices('QDRANT_PORT', 'qdrant_port'))
-    qdrant_collection_name: str = Field(default="mem0", validation_alias=AliasChoices('QDRANT_COLLECTION_NAME', 'qdrant_collection_name'))
+    qdrant_host: str = Field(
+        default="localhost", validation_alias=AliasChoices("QDRANT_HOST", "qdrant_host")
+    )
+    qdrant_port: int = Field(
+        default=6333, validation_alias=AliasChoices("QDRANT_PORT", "qdrant_port")
+    )
+    qdrant_collection_name: str = Field(
+        default="mem0",
+        validation_alias=AliasChoices(
+            "QDRANT_COLLECTION_NAME", "qdrant_collection_name"
+        ),
+    )

     # Neo4j Configuration
-    neo4j_uri: str = Field(default="bolt://localhost:7687", validation_alias=AliasChoices('NEO4J_URI', 'neo4j_uri'))
-    neo4j_username: str = Field(default="neo4j", validation_alias=AliasChoices('NEO4J_USERNAME', 'neo4j_username'))
-    neo4j_password: str = Field(default="mem0_neo4j_password", validation_alias=AliasChoices('NEO4J_PASSWORD', 'neo4j_password'))
+    neo4j_uri: str = Field(
+        default="bolt://localhost:7687",
+        validation_alias=AliasChoices("NEO4J_URI", "neo4j_uri"),
+    )
+    neo4j_username: str = Field(
+        default="neo4j",
+        validation_alias=AliasChoices("NEO4J_USERNAME", "neo4j_username"),
+    )
+    neo4j_password: str = Field(
+        default="mem0_neo4j_password",
+        validation_alias=AliasChoices("NEO4J_PASSWORD", "neo4j_password"),
+    )

     # Application Configuration
-    log_level: str = Field(default="INFO", validation_alias=AliasChoices('LOG_LEVEL', 'log_level'))
-    cors_origins: str = Field(default="http://localhost:3000", validation_alias=AliasChoices('CORS_ORIGINS', 'cors_origins'))
+    log_level: str = Field(
+        default="INFO", validation_alias=AliasChoices("LOG_LEVEL", "log_level")
+    )
+    cors_origins: str = Field(
+        default="http://localhost:3000",
+        validation_alias=AliasChoices("CORS_ORIGINS", "cors_origins"),
+    )

     # Model Configuration - Ultra-minimal (single model)
-    default_model: str = Field(default="claude-sonnet-4", validation_alias=AliasChoices('DEFAULT_MODEL', 'default_model'))
+    default_model: str = Field(
+        default="claude-sonnet-4",
+        validation_alias=AliasChoices("DEFAULT_MODEL", "default_model"),
+    )

+    # Embedder Configuration
+    ollama_base_url: str = Field(
+        default="http://host.docker.internal:11434",
+        validation_alias=AliasChoices("OLLAMA_BASE_URL", "ollama_base_url"),
+    )
+    embedding_model: str = Field(
+        default="qwen3-embedding:4b-q8_0",
+        validation_alias=AliasChoices("EMBEDDING_MODEL", "embedding_model"),
+    )
+    embedding_dims: int = Field(
+        default=2560, validation_alias=AliasChoices("EMBEDDING_DIMS", "embedding_dims")
+    )
+
     # Authentication Configuration
     # Format: JSON string mapping API keys to user IDs
     # Example: {"api_key_123": "alice", "api_key_456": "bob"}
-    api_keys: str = Field(default="{}", validation_alias=AliasChoices('API_KEYS', 'api_keys'))
-
+    api_keys: str = Field(
+        default="{}", validation_alias=AliasChoices("API_KEYS", "api_keys")
+    )

     @property
     def cors_origins_list(self) -> List[str]:
@@ -464,13 +464,10 @@ async def update_memory(
             detail=f"Access denied: You can only update your own memories (authenticated as '{authenticated_user}')",
         )

-    # Verify the memory actually belongs to the authenticated user
-    user_memories = await mem0_manager.get_user_memories(
-        user_id=authenticated_user, limit=10000
-    )
-    memory_ids = {m.get("id") for m in user_memories if m.get("id")}
-
-    if update_request.memory_id not in memory_ids:
+    # Verify memory ownership with O(1) lookup instead of fetching all memories
+    if not await mem0_manager.verify_memory_ownership(
+        update_request.memory_id, authenticated_user
+    ):
         raise HTTPException(
             status_code=404,
             detail=f"Memory '{update_request.memory_id}' not found or access denied",
@@ -506,13 +503,10 @@ async def delete_memory(
 ):
     """Delete a specific memory - verifies ownership before deletion."""
     try:
-        # Verify the memory actually belongs to the authenticated user
-        user_memories = await mem0_manager.get_user_memories(
-            user_id=authenticated_user, limit=10000
-        )
-        memory_ids = {m.get("id") for m in user_memories if m.get("id")}
-
-        if memory_id not in memory_ids:
+        # Verify memory ownership with O(1) lookup instead of fetching all memories
+        if not await mem0_manager.verify_memory_ownership(
+            memory_id, authenticated_user
+        ):
             raise HTTPException(
                 status_code=404,
                 detail=f"Memory '{memory_id}' not found or access denied",
@@ -614,13 +608,8 @@ async def get_memory_history(
             detail=f"Access denied: You can only view your own memory history (authenticated as '{authenticated_user}')",
         )

-    # Verify the memory belongs to this user before returning history
-    user_memories = await mem0_manager.get_user_memories(
-        user_id=user_id, limit=10000
-    )
-    memory_ids = {m.get("id") for m in user_memories if m.get("id")}
-
-    if memory_id not in memory_ids:
+    # Verify memory ownership with O(1) lookup instead of fetching all memories
+    if not await mem0_manager.verify_memory_ownership(memory_id, user_id):
         raise HTTPException(
             status_code=404,
             detail=f"Memory '{memory_id}' not found or access denied",
@@ -5,12 +5,28 @@ from typing import Dict, List, Optional, Any
 from datetime import datetime
 from mem0 import Memory
 from openai import OpenAI
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    retry_if_exception_type,
+    before_sleep_log,
+)

 from config import settings
 from monitoring import timed

 logger = logging.getLogger(__name__)

+# Retry decorator for database operations (Qdrant, Neo4j)
+db_retry = retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=1, max=10),
+    retry=retry_if_exception_type((ConnectionError, TimeoutError, OSError)),
+    before_sleep=before_sleep_log(logger, logging.WARNING),
+    reraise=True,
+)
+
 # Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility
 from mem0.llms.openai import OpenAILLM
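A quick illustration of what the db_retry policy above does, applied to a throwaway function rather than the real Qdrant/Neo4j calls (the flaky_ping helper and its failure counter are purely for demonstration): connection-type errors are retried up to three times with exponential backoff, then re-raised.

# Demonstration only: the same tenacity policy on a fake flaky call.
import logging

from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
    before_sleep_log,
)

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("db_retry_demo")

db_retry = retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    retry=retry_if_exception_type((ConnectionError, TimeoutError, OSError)),
    before_sleep=before_sleep_log(logger, logging.WARNING),
    reraise=True,
)

attempts = {"n": 0}


@db_retry
def flaky_ping() -> str:
    # Fails twice with a retryable error, then succeeds on the third attempt.
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ConnectionError("database unreachable")
    return "ok"


print(flaky_ping())  # retried twice (with backoff sleeps), then prints "ok"

In the manager itself the same decorator wraps the async Mem0 passthrough methods; tenacity handles both sync and async callables, so the demo uses a plain function for brevity.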
@@ -48,6 +64,7 @@ class Mem0Manager:
             settings,
         )
         config = {
+            "version": "v1.1",
             "enable_graph": True,
             "llm": {
                 "provider": "openai",
@@ -56,16 +73,15 @@ class Mem0Manager:
                     "api_key": settings.openai_api_key,
                     "openai_base_url": settings.openai_base_url,
                     "temperature": 0.1,
-                    "top_p": None,  # Don't use top_p with Claude models
+                    "top_p": None,
                 },
             },
             "embedder": {
                 "provider": "ollama",
                 "config": {
-                    "model": "qwen3-embedding:4b-q8_0",
-                    # "api_key": settings.embedder_api_key,
-                    "ollama_base_url": "http://172.17.0.1:11434",
-                    "embedding_dims": 2560,
+                    "model": settings.embedding_model,
+                    "ollama_base_url": settings.ollama_base_url,
+                    "embedding_dims": settings.embedding_dims,
                 },
             },
             "vector_store": {
@@ -74,7 +90,7 @@ class Mem0Manager:
                     "collection_name": settings.qdrant_collection_name,
                     "host": settings.qdrant_host,
                     "port": settings.qdrant_port,
-                    "embedding_model_dims": 2560,
+                    "embedding_model_dims": settings.embedding_dims,
                     "on_disk": True,
                 },
             },
@@ -106,6 +122,7 @@ class Mem0Manager:
         logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")

     # Pure passthrough methods - no custom logic
+    @db_retry
     @timed("add_memories")
     async def add_memories(
         self,
@@ -161,6 +178,7 @@ class Mem0Manager:
             logger.error(f"Error adding memories: {e}")
             raise

+    @db_retry
     @timed("search_memories")
     async def search_memories(
         self,
@@ -204,6 +222,7 @@ class Mem0Manager:
             logger.error(f"Error searching memories: {e}")
             raise

+    @db_retry
     async def get_user_memories(
         self,
         user_id: str,
@@ -227,6 +246,24 @@ class Mem0Manager:
             logger.error(f"Error getting user memories: {e}")
             raise

+    @db_retry
+    async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
+        """Get a single memory by ID. Returns None if not found."""
+        try:
+            result = self.memory.get(memory_id=memory_id)
+            return result
+        except Exception as e:
+            logger.debug(f"Memory {memory_id} not found or error: {e}")
+            return None
+
+    async def verify_memory_ownership(self, memory_id: str, user_id: str) -> bool:
+        """Check if a memory belongs to a user. O(1) instead of O(n)."""
+        memory = await self.get_memory(memory_id)
+        if memory is None:
+            return False
+        return memory.get("user_id") == user_id
+
+    @db_retry
     @timed("update_memory")
     async def update_memory(
         self,
@@ -241,6 +278,7 @@ class Mem0Manager:
             logger.error(f"Error updating memory: {e}")
             raise

+    @db_retry
     @timed("delete_memory")
     async def delete_memory(self, memory_id: str) -> Dict[str, Any]:
         """Delete memory - pure Mem0 passthrough."""
@@ -19,6 +19,7 @@ ollama
 # Utilities
 pydantic
 pydantic-settings
+tenacity
 python-dotenv
 httpx
 aiofiles