production improvements: configurable embeddings, v1.1, O(1) ownership, retries
- Make Ollama URL configurable via OLLAMA_BASE_URL env var
- Add version: v1.1 to Mem0 config (required for latest features)
- Make embedding model and dimensions configurable
- Fix ownership check: O(1) lookup instead of fetching 10k records
- Add tenacity retry logic for database operations
This commit is contained in:
parent
50edce2d3c
commit
a228780146
4 changed files with 120 additions and 44 deletions
|
|
@ -11,39 +11,87 @@ class Settings(BaseSettings):
|
||||||
"""Application settings loaded from environment variables."""
|
"""Application settings loaded from environment variables."""
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
env_file=".env",
|
env_file=".env", case_sensitive=False, extra="ignore"
|
||||||
case_sensitive=False,
|
|
||||||
extra='ignore'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# API Configuration
|
# API Configuration
|
||||||
# Accept both OPENAI_API_KEY (from docker-compose) and OPENAI_COMPAT_API_KEY (from direct .env)
|
# Accept both OPENAI_API_KEY (from docker-compose) and OPENAI_COMPAT_API_KEY (from direct .env)
|
||||||
openai_api_key: str = Field(validation_alias=AliasChoices('OPENAI_API_KEY', 'OPENAI_COMPAT_API_KEY', 'openai_api_key'))
|
openai_api_key: str = Field(
|
||||||
openai_base_url: str = Field(validation_alias=AliasChoices('OPENAI_BASE_URL', 'OPENAI_COMPAT_BASE_URL', 'openai_base_url'))
|
validation_alias=AliasChoices(
|
||||||
cohere_api_key: str = Field(validation_alias=AliasChoices('COHERE_API_KEY', 'cohere_api_key'))
|
"OPENAI_API_KEY", "OPENAI_COMPAT_API_KEY", "openai_api_key"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
openai_base_url: str = Field(
|
||||||
|
validation_alias=AliasChoices(
|
||||||
|
"OPENAI_BASE_URL", "OPENAI_COMPAT_BASE_URL", "openai_base_url"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cohere_api_key: str = Field(
|
||||||
|
validation_alias=AliasChoices("COHERE_API_KEY", "cohere_api_key")
|
||||||
|
)
|
||||||
|
|
||||||
# Database Configuration
|
# Database Configuration
|
||||||
qdrant_host: str = Field(default="localhost", validation_alias=AliasChoices('QDRANT_HOST', 'qdrant_host'))
|
qdrant_host: str = Field(
|
||||||
qdrant_port: int = Field(default=6333, validation_alias=AliasChoices('QDRANT_PORT', 'qdrant_port'))
|
default="localhost", validation_alias=AliasChoices("QDRANT_HOST", "qdrant_host")
|
||||||
qdrant_collection_name: str = Field(default="mem0", validation_alias=AliasChoices('QDRANT_COLLECTION_NAME', 'qdrant_collection_name'))
|
)
|
||||||
|
qdrant_port: int = Field(
|
||||||
|
default=6333, validation_alias=AliasChoices("QDRANT_PORT", "qdrant_port")
|
||||||
|
)
|
||||||
|
qdrant_collection_name: str = Field(
|
||||||
|
default="mem0",
|
||||||
|
validation_alias=AliasChoices(
|
||||||
|
"QDRANT_COLLECTION_NAME", "qdrant_collection_name"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# Neo4j Configuration
|
# Neo4j Configuration
|
||||||
neo4j_uri: str = Field(default="bolt://localhost:7687", validation_alias=AliasChoices('NEO4J_URI', 'neo4j_uri'))
|
neo4j_uri: str = Field(
|
||||||
neo4j_username: str = Field(default="neo4j", validation_alias=AliasChoices('NEO4J_USERNAME', 'neo4j_username'))
|
default="bolt://localhost:7687",
|
||||||
neo4j_password: str = Field(default="mem0_neo4j_password", validation_alias=AliasChoices('NEO4J_PASSWORD', 'neo4j_password'))
|
validation_alias=AliasChoices("NEO4J_URI", "neo4j_uri"),
|
||||||
|
)
|
||||||
|
neo4j_username: str = Field(
|
||||||
|
default="neo4j",
|
||||||
|
validation_alias=AliasChoices("NEO4J_USERNAME", "neo4j_username"),
|
||||||
|
)
|
||||||
|
neo4j_password: str = Field(
|
||||||
|
default="mem0_neo4j_password",
|
||||||
|
validation_alias=AliasChoices("NEO4J_PASSWORD", "neo4j_password"),
|
||||||
|
)
|
||||||
|
|
||||||
# Application Configuration
|
# Application Configuration
|
||||||
log_level: str = Field(default="INFO", validation_alias=AliasChoices('LOG_LEVEL', 'log_level'))
|
log_level: str = Field(
|
||||||
cors_origins: str = Field(default="http://localhost:3000", validation_alias=AliasChoices('CORS_ORIGINS', 'cors_origins'))
|
default="INFO", validation_alias=AliasChoices("LOG_LEVEL", "log_level")
|
||||||
|
)
|
||||||
|
cors_origins: str = Field(
|
||||||
|
default="http://localhost:3000",
|
||||||
|
validation_alias=AliasChoices("CORS_ORIGINS", "cors_origins"),
|
||||||
|
)
|
||||||
|
|
||||||
# Model Configuration - Ultra-minimal (single model)
|
# Model Configuration - Ultra-minimal (single model)
|
||||||
default_model: str = Field(default="claude-sonnet-4", validation_alias=AliasChoices('DEFAULT_MODEL', 'default_model'))
|
default_model: str = Field(
|
||||||
|
default="claude-sonnet-4",
|
||||||
|
validation_alias=AliasChoices("DEFAULT_MODEL", "default_model"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embedder Configuration
|
||||||
|
ollama_base_url: str = Field(
|
||||||
|
default="http://host.docker.internal:11434",
|
||||||
|
validation_alias=AliasChoices("OLLAMA_BASE_URL", "ollama_base_url"),
|
||||||
|
)
|
||||||
|
embedding_model: str = Field(
|
||||||
|
default="qwen3-embedding:4b-q8_0",
|
||||||
|
validation_alias=AliasChoices("EMBEDDING_MODEL", "embedding_model"),
|
||||||
|
)
|
||||||
|
embedding_dims: int = Field(
|
||||||
|
default=2560, validation_alias=AliasChoices("EMBEDDING_DIMS", "embedding_dims")
|
||||||
|
)
|
||||||
|
|
||||||
# Authentication Configuration
|
# Authentication Configuration
|
||||||
# Format: JSON string mapping API keys to user IDs
|
# Format: JSON string mapping API keys to user IDs
|
||||||
# Example: {"api_key_123": "alice", "api_key_456": "bob"}
|
# Example: {"api_key_123": "alice", "api_key_456": "bob"}
|
||||||
api_keys: str = Field(default="{}", validation_alias=AliasChoices('API_KEYS', 'api_keys'))
|
api_keys: str = Field(
|
||||||
|
default="{}", validation_alias=AliasChoices("API_KEYS", "api_keys")
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cors_origins_list(self) -> List[str]:
|
def cors_origins_list(self) -> List[str]:
|
||||||
|
|
|
||||||
|
|
@ -464,13 +464,10 @@ async def update_memory(
|
||||||
detail=f"Access denied: You can only update your own memories (authenticated as '{authenticated_user}')",
|
detail=f"Access denied: You can only update your own memories (authenticated as '{authenticated_user}')",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify the memory actually belongs to the authenticated user
|
# Verify memory ownership with O(1) lookup instead of fetching all memories
|
||||||
user_memories = await mem0_manager.get_user_memories(
|
if not await mem0_manager.verify_memory_ownership(
|
||||||
user_id=authenticated_user, limit=10000
|
update_request.memory_id, authenticated_user
|
||||||
)
|
):
|
||||||
memory_ids = {m.get("id") for m in user_memories if m.get("id")}
|
|
||||||
|
|
||||||
if update_request.memory_id not in memory_ids:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"Memory '{update_request.memory_id}' not found or access denied",
|
detail=f"Memory '{update_request.memory_id}' not found or access denied",
|
||||||
|
|
@ -506,13 +503,10 @@ async def delete_memory(
|
||||||
):
|
):
|
||||||
"""Delete a specific memory - verifies ownership before deletion."""
|
"""Delete a specific memory - verifies ownership before deletion."""
|
||||||
try:
|
try:
|
||||||
# Verify the memory actually belongs to the authenticated user
|
# Verify memory ownership with O(1) lookup instead of fetching all memories
|
||||||
user_memories = await mem0_manager.get_user_memories(
|
if not await mem0_manager.verify_memory_ownership(
|
||||||
user_id=authenticated_user, limit=10000
|
memory_id, authenticated_user
|
||||||
)
|
):
|
||||||
memory_ids = {m.get("id") for m in user_memories if m.get("id")}
|
|
||||||
|
|
||||||
if memory_id not in memory_ids:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"Memory '{memory_id}' not found or access denied",
|
detail=f"Memory '{memory_id}' not found or access denied",
|
||||||
|
|
@ -614,13 +608,8 @@ async def get_memory_history(
|
||||||
detail=f"Access denied: You can only view your own memory history (authenticated as '{authenticated_user}')",
|
detail=f"Access denied: You can only view your own memory history (authenticated as '{authenticated_user}')",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify the memory belongs to this user before returning history
|
# Verify memory ownership with O(1) lookup instead of fetching all memories
|
||||||
user_memories = await mem0_manager.get_user_memories(
|
if not await mem0_manager.verify_memory_ownership(memory_id, user_id):
|
||||||
user_id=user_id, limit=10000
|
|
||||||
)
|
|
||||||
memory_ids = {m.get("id") for m in user_memories if m.get("id")}
|
|
||||||
|
|
||||||
if memory_id not in memory_ids:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"Memory '{memory_id}' not found or access denied",
|
detail=f"Memory '{memory_id}' not found or access denied",
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,28 @@ from typing import Dict, List, Optional, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from mem0 import Memory
|
from mem0 import Memory
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
retry_if_exception_type,
|
||||||
|
before_sleep_log,
|
||||||
|
)
|
||||||
|
|
||||||
from config import settings
|
from config import settings
|
||||||
from monitoring import timed
|
from monitoring import timed
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Shared retry policy for operations that reach the backing stores (Qdrant, Neo4j).
db_retry = retry(
    # Only these exception types trigger another attempt; others propagate.
    retry=retry_if_exception_type((ConnectionError, TimeoutError, OSError)),
    # Stop once three attempts have been made.
    stop=stop_after_attempt(3),
    # Exponential backoff between attempts, clamped to the 1..10 range.
    wait=wait_exponential(multiplier=1, min=1, max=10),
    # Log at WARNING through the module logger before each backoff sleep.
    before_sleep=before_sleep_log(logger, logging.WARNING),
    # Re-raise the last underlying exception instead of tenacity's RetryError.
    reraise=True,
)
|
||||||
|
|
||||||
# Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility
|
# Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility
|
||||||
from mem0.llms.openai import OpenAILLM
|
from mem0.llms.openai import OpenAILLM
|
||||||
|
|
||||||
|
|
@ -48,6 +64,7 @@ class Mem0Manager:
|
||||||
settings,
|
settings,
|
||||||
)
|
)
|
||||||
config = {
|
config = {
|
||||||
|
"version": "v1.1",
|
||||||
"enable_graph": True,
|
"enable_graph": True,
|
||||||
"llm": {
|
"llm": {
|
||||||
"provider": "openai",
|
"provider": "openai",
|
||||||
|
|
@ -56,16 +73,15 @@ class Mem0Manager:
|
||||||
"api_key": settings.openai_api_key,
|
"api_key": settings.openai_api_key,
|
||||||
"openai_base_url": settings.openai_base_url,
|
"openai_base_url": settings.openai_base_url,
|
||||||
"temperature": 0.1,
|
"temperature": 0.1,
|
||||||
"top_p": None, # Don't use top_p with Claude models
|
"top_p": None,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"embedder": {
|
"embedder": {
|
||||||
"provider": "ollama",
|
"provider": "ollama",
|
||||||
"config": {
|
"config": {
|
||||||
"model": "qwen3-embedding:4b-q8_0",
|
"model": settings.embedding_model,
|
||||||
# "api_key": settings.embedder_api_key,
|
"ollama_base_url": settings.ollama_base_url,
|
||||||
"ollama_base_url": "http://172.17.0.1:11434",
|
"embedding_dims": settings.embedding_dims,
|
||||||
"embedding_dims": 2560,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"vector_store": {
|
"vector_store": {
|
||||||
|
|
@ -74,7 +90,7 @@ class Mem0Manager:
|
||||||
"collection_name": settings.qdrant_collection_name,
|
"collection_name": settings.qdrant_collection_name,
|
||||||
"host": settings.qdrant_host,
|
"host": settings.qdrant_host,
|
||||||
"port": settings.qdrant_port,
|
"port": settings.qdrant_port,
|
||||||
"embedding_model_dims": 2560,
|
"embedding_model_dims": settings.embedding_dims,
|
||||||
"on_disk": True,
|
"on_disk": True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -106,6 +122,7 @@ class Mem0Manager:
|
||||||
logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")
|
logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")
|
||||||
|
|
||||||
# Pure passthrough methods - no custom logic
|
# Pure passthrough methods - no custom logic
|
||||||
|
@db_retry
|
||||||
@timed("add_memories")
|
@timed("add_memories")
|
||||||
async def add_memories(
|
async def add_memories(
|
||||||
self,
|
self,
|
||||||
|
|
@ -161,6 +178,7 @@ class Mem0Manager:
|
||||||
logger.error(f"Error adding memories: {e}")
|
logger.error(f"Error adding memories: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@db_retry
|
||||||
@timed("search_memories")
|
@timed("search_memories")
|
||||||
async def search_memories(
|
async def search_memories(
|
||||||
self,
|
self,
|
||||||
|
|
@ -204,6 +222,7 @@ class Mem0Manager:
|
||||||
logger.error(f"Error searching memories: {e}")
|
logger.error(f"Error searching memories: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@db_retry
|
||||||
async def get_user_memories(
|
async def get_user_memories(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|
@ -227,6 +246,24 @@ class Mem0Manager:
|
||||||
logger.error(f"Error getting user memories: {e}")
|
logger.error(f"Error getting user memories: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@db_retry
|
||||||
|
async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a single memory by ID.

    Args:
        memory_id: Identifier forwarded to ``self.memory.get``.

    Returns:
        The value returned by ``self.memory.get`` for this ID, or ``None``
        when the lookup fails with a non-connection error.

    Raises:
        OSError: Connection-level failures (this includes ``ConnectionError``
            and ``TimeoutError``) are propagated rather than reported as
            "not found". This lets the ``db_retry`` wrapper on this method
            retry them, and keeps an outage from looking like a missing
            memory.
    """
    try:
        return self.memory.get(memory_id=memory_id)
    except OSError:
        # ConnectionError and TimeoutError are OSError subclasses. Swallowing
        # them here would leave the retry policy with nothing to react to.
        raise
    except Exception as e:
        # Best-effort: any other failure is treated as "no such memory".
        logger.debug("Memory %s not found or error: %s", memory_id, e)
        return None
|
||||||
|
|
||||||
|
async def verify_memory_ownership(self, memory_id: str, user_id: str) -> bool:
    """Report whether the memory identified by ``memory_id`` belongs to ``user_id``.

    Performs one ``get_memory`` lookup by ID and compares the record's
    ``"user_id"`` entry with the supplied user.

    Args:
        memory_id: ID of the memory to check.
        user_id: User expected to own the memory.

    Returns:
        ``False`` when ``get_memory`` yields ``None``; otherwise the result of
        comparing the record's ``"user_id"`` value to ``user_id``.
    """
    record = await self.get_memory(memory_id)
    return record is not None and record.get("user_id") == user_id
|
||||||
|
|
||||||
|
@db_retry
|
||||||
@timed("update_memory")
|
@timed("update_memory")
|
||||||
async def update_memory(
|
async def update_memory(
|
||||||
self,
|
self,
|
||||||
|
|
@ -241,6 +278,7 @@ class Mem0Manager:
|
||||||
logger.error(f"Error updating memory: {e}")
|
logger.error(f"Error updating memory: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@db_retry
|
||||||
@timed("delete_memory")
|
@timed("delete_memory")
|
||||||
async def delete_memory(self, memory_id: str) -> Dict[str, Any]:
|
async def delete_memory(self, memory_id: str) -> Dict[str, Any]:
|
||||||
"""Delete memory - pure Mem0 passthrough."""
|
"""Delete memory - pure Mem0 passthrough."""
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ ollama
|
||||||
# Utilities
|
# Utilities
|
||||||
pydantic
|
pydantic
|
||||||
pydantic-settings
|
pydantic-settings
|
||||||
|
tenacity
|
||||||
python-dotenv
|
python-dotenv
|
||||||
httpx
|
httpx
|
||||||
aiofiles
|
aiofiles
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue