Compare commits
10 commits
cloud-back
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 82accabc73 | |||
| 6f9b545c15 | |||
| 5bcecf4649 | |||
| 638a591dc5 | |||
| a190527076 | |||
| 9e86c30548 | |||
| 2c1d73a1ec | |||
| a228780146 | |||
| 50edce2d3c | |||
| 35c1bbec4e |
11 changed files with 2088 additions and 522 deletions
|
|
@ -1,7 +1,7 @@
|
||||||
"""Simple API key authentication for Mem0 Interface."""
|
"""Simple API key authentication for Mem0 Interface."""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import HTTPException, Security, status
|
from fastapi import HTTPException, Security, status, Header
|
||||||
from fastapi.security import APIKeyHeader
|
from fastapi.security import APIKeyHeader
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
|
|
@ -19,7 +19,9 @@ class AuthService:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize auth service with API key to user mapping."""
|
"""Initialize auth service with API key to user mapping."""
|
||||||
self.api_key_to_user = settings.api_key_mapping
|
self.api_key_to_user = settings.api_key_mapping
|
||||||
logger.info(f"Auth service initialized with {len(self.api_key_to_user)} API keys")
|
logger.info(
|
||||||
|
f"Auth service initialized with {len(self.api_key_to_user)} API keys"
|
||||||
|
)
|
||||||
|
|
||||||
def verify_api_key(self, api_key: str) -> str:
|
def verify_api_key(self, api_key: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -37,8 +39,7 @@ class AuthService:
|
||||||
if api_key not in self.api_key_to_user:
|
if api_key not in self.api_key_to_user:
|
||||||
logger.warning(f"Invalid API key attempted: {api_key[:10]}...")
|
logger.warning(f"Invalid API key attempted: {api_key[:10]}...")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||||
detail="Invalid API key"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
user_id = self.api_key_to_user[api_key]
|
user_id = self.api_key_to_user[api_key]
|
||||||
|
|
@ -68,7 +69,7 @@ class AuthService:
|
||||||
)
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=f"Access denied: You can only access your own memories"
|
detail=f"Access denied: You can only access your own memories",
|
||||||
)
|
)
|
||||||
|
|
||||||
return authenticated_user_id
|
return authenticated_user_id
|
||||||
|
|
@ -91,9 +92,46 @@ async def get_current_user(api_key: str = Security(api_key_header)) -> str:
|
||||||
return auth_service.verify_api_key(api_key)
|
return auth_service.verify_api_key(api_key)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user_openai(
    authorization: Optional[str] = Header(None),
    x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
) -> str:
    """
    FastAPI dependency for OpenAI-compatible authentication.
    Accepts the API key from either an ``Authorization: Bearer <key>`` header
    or an ``X-API-Key`` header; the Bearer token wins when both are usable.

    Args:
        authorization: Raw Authorization header value, if sent
        x_api_key: Raw X-API-Key header value, if sent

    Returns:
        str: Authenticated user_id, as returned by auth_service.verify_api_key

    Raises:
        HTTPException: 401 if neither header yields an API key; any
            exception raised by auth_service.verify_api_key propagates
    """
    api_key = None

    # Try Bearer token first (OpenAI standard). The auth scheme name is
    # matched case-insensitively, and an empty token is ignored so that a
    # supplied X-API-Key can still be used.
    if authorization:
        scheme, _, credentials = authorization.strip().partition(" ")
        credentials = credentials.strip()
        if scheme.lower() == "bearer" and credentials:
            api_key = credentials
            logger.debug("Extracted API key from Authorization Bearer token")

    # Fall back to X-API-Key header
    if api_key is None and x_api_key:
        api_key = x_api_key
        logger.debug("Extracted API key from X-API-Key header")

    if api_key is None:
        logger.warning("No API key provided in Authorization or X-API-Key headers")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API key. Provide either 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header",
        )

    return auth_service.verify_api_key(api_key)
|
||||||
|
|
||||||
|
|
||||||
async def verify_user_access(
|
async def verify_user_access(
|
||||||
api_key: str = Security(api_key_header),
|
api_key: str = Security(api_key_header), user_id: Optional[str] = None
|
||||||
user_id: Optional[str] = None
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
FastAPI dependency to verify user can access the requested user_id.
|
FastAPI dependency to verify user can access the requested user_id.
|
||||||
|
|
@ -114,7 +152,7 @@ async def verify_user_access(
|
||||||
)
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Access denied: You can only access your own memories"
|
detail="Access denied: You can only access your own memories",
|
||||||
)
|
)
|
||||||
|
|
||||||
return authenticated_user_id
|
return authenticated_user_id
|
||||||
|
|
|
||||||
|
|
@ -11,39 +11,87 @@ class Settings(BaseSettings):
|
||||||
"""Application settings loaded from environment variables."""
|
"""Application settings loaded from environment variables."""
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
env_file=".env",
|
env_file=".env", case_sensitive=False, extra="ignore"
|
||||||
case_sensitive=False,
|
|
||||||
extra='ignore'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# API Configuration
|
# API Configuration
|
||||||
# Accept both OPENAI_API_KEY (from docker-compose) and OPENAI_COMPAT_API_KEY (from direct .env)
|
# Accept both OPENAI_API_KEY (from docker-compose) and OPENAI_COMPAT_API_KEY (from direct .env)
|
||||||
openai_api_key: str = Field(validation_alias=AliasChoices('OPENAI_API_KEY', 'OPENAI_COMPAT_API_KEY', 'openai_api_key'))
|
openai_api_key: str = Field(
|
||||||
openai_base_url: str = Field(validation_alias=AliasChoices('OPENAI_BASE_URL', 'OPENAI_COMPAT_BASE_URL', 'openai_base_url'))
|
validation_alias=AliasChoices(
|
||||||
cohere_api_key: str = Field(validation_alias=AliasChoices('COHERE_API_KEY', 'cohere_api_key'))
|
"OPENAI_API_KEY", "OPENAI_COMPAT_API_KEY", "openai_api_key"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
openai_base_url: str = Field(
|
||||||
|
validation_alias=AliasChoices(
|
||||||
|
"OPENAI_BASE_URL", "OPENAI_COMPAT_BASE_URL", "openai_base_url"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cohere_api_key: str = Field(
|
||||||
|
validation_alias=AliasChoices("COHERE_API_KEY", "cohere_api_key")
|
||||||
|
)
|
||||||
|
|
||||||
# Database Configuration
|
# Database Configuration
|
||||||
qdrant_host: str = Field(default="localhost", validation_alias=AliasChoices('QDRANT_HOST', 'qdrant_host'))
|
qdrant_host: str = Field(
|
||||||
qdrant_port: int = Field(default=6333, validation_alias=AliasChoices('QDRANT_PORT', 'qdrant_port'))
|
default="localhost", validation_alias=AliasChoices("QDRANT_HOST", "qdrant_host")
|
||||||
qdrant_collection_name: str = Field(default="mem0", validation_alias=AliasChoices('QDRANT_COLLECTION_NAME', 'qdrant_collection_name'))
|
)
|
||||||
|
qdrant_port: int = Field(
|
||||||
|
default=6333, validation_alias=AliasChoices("QDRANT_PORT", "qdrant_port")
|
||||||
|
)
|
||||||
|
qdrant_collection_name: str = Field(
|
||||||
|
default="mem0",
|
||||||
|
validation_alias=AliasChoices(
|
||||||
|
"QDRANT_COLLECTION_NAME", "qdrant_collection_name"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# Neo4j Configuration
|
# Neo4j Configuration
|
||||||
neo4j_uri: str = Field(default="bolt://localhost:7687", validation_alias=AliasChoices('NEO4J_URI', 'neo4j_uri'))
|
neo4j_uri: str = Field(
|
||||||
neo4j_username: str = Field(default="neo4j", validation_alias=AliasChoices('NEO4J_USERNAME', 'neo4j_username'))
|
default="bolt://localhost:7687",
|
||||||
neo4j_password: str = Field(default="mem0_neo4j_password", validation_alias=AliasChoices('NEO4J_PASSWORD', 'neo4j_password'))
|
validation_alias=AliasChoices("NEO4J_URI", "neo4j_uri"),
|
||||||
|
)
|
||||||
|
neo4j_username: str = Field(
|
||||||
|
default="neo4j",
|
||||||
|
validation_alias=AliasChoices("NEO4J_USERNAME", "neo4j_username"),
|
||||||
|
)
|
||||||
|
neo4j_password: str = Field(
|
||||||
|
default="mem0_neo4j_password",
|
||||||
|
validation_alias=AliasChoices("NEO4J_PASSWORD", "neo4j_password"),
|
||||||
|
)
|
||||||
|
|
||||||
# Application Configuration
|
# Application Configuration
|
||||||
log_level: str = Field(default="INFO", validation_alias=AliasChoices('LOG_LEVEL', 'log_level'))
|
log_level: str = Field(
|
||||||
cors_origins: str = Field(default="http://localhost:3000", validation_alias=AliasChoices('CORS_ORIGINS', 'cors_origins'))
|
default="INFO", validation_alias=AliasChoices("LOG_LEVEL", "log_level")
|
||||||
|
)
|
||||||
|
cors_origins: str = Field(
|
||||||
|
default="http://localhost:3000",
|
||||||
|
validation_alias=AliasChoices("CORS_ORIGINS", "cors_origins"),
|
||||||
|
)
|
||||||
|
|
||||||
# Model Configuration - Ultra-minimal (single model)
|
# Model Configuration - Ultra-minimal (single model)
|
||||||
default_model: str = Field(default="claude-sonnet-4", validation_alias=AliasChoices('DEFAULT_MODEL', 'default_model'))
|
default_model: str = Field(
|
||||||
|
default="claude-sonnet-4",
|
||||||
|
validation_alias=AliasChoices("DEFAULT_MODEL", "default_model"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embedder Configuration
|
||||||
|
ollama_base_url: str = Field(
|
||||||
|
default="http://host.docker.internal:11434",
|
||||||
|
validation_alias=AliasChoices("OLLAMA_BASE_URL", "ollama_base_url"),
|
||||||
|
)
|
||||||
|
embedding_model: str = Field(
|
||||||
|
default="qwen3-embedding:4b-q8_0",
|
||||||
|
validation_alias=AliasChoices("EMBEDDING_MODEL", "embedding_model"),
|
||||||
|
)
|
||||||
|
embedding_dims: int = Field(
|
||||||
|
default=2560, validation_alias=AliasChoices("EMBEDDING_DIMS", "embedding_dims")
|
||||||
|
)
|
||||||
|
|
||||||
# Authentication Configuration
|
# Authentication Configuration
|
||||||
# Format: JSON string mapping API keys to user IDs
|
# Format: JSON string mapping API keys to user IDs
|
||||||
# Example: {"api_key_123": "alice", "api_key_456": "bob"}
|
# Example: {"api_key_123": "alice", "api_key_456": "bob"}
|
||||||
api_keys: str = Field(default="{}", validation_alias=AliasChoices('API_KEYS', 'api_keys'))
|
api_keys: str = Field(
|
||||||
|
default="{}", validation_alias=AliasChoices("API_KEYS", "api_keys")
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cors_origins_list(self) -> List[str]:
|
def cors_origins_list(self) -> List[str]:
|
||||||
|
|
|
||||||
621
backend/main.py
621
backend/main.py
|
|
@ -1,25 +1,55 @@
|
||||||
"""Main FastAPI application for Mem0 Interface POC."""
|
"""Main FastAPI application for Mem0 Interface POC."""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security
|
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
import structlog
|
import structlog
|
||||||
|
import asyncio
|
||||||
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
|
|
||||||
|
# Rate limiter key: prefer the caller's API key (X-API-Key or Authorization
# Bearer header); fall back to the client IP address when no key is sent.
def get_rate_limit_key(request: Request) -> str:
    """Return the rate-limit bucket key for a request.

    The API key is looked up in the ``x-api-key`` header first and then in an
    ``Authorization: Bearer`` header. When a key is found, the bucket id is
    ``apikey:`` plus the first 16 hex characters of the key's SHA-256 digest,
    so the raw key is never used as an identifier and keys that share a
    prefix do not share a bucket. Without a key, the value of
    ``get_remote_address(request)`` is returned.
    """
    # Local import keeps this block self-contained; the module's import
    # section does not need to change.
    import hashlib

    api_key = request.headers.get("x-api-key", "").strip()
    if not api_key:
        scheme, _, credentials = (
            request.headers.get("authorization", "").strip().partition(" ")
        )
        if scheme.lower() == "bearer":
            api_key = credentials.strip()

    if api_key:
        digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
        return f"apikey:{digest[:16]}"
    return get_remote_address(request)
|
||||||
|
|
||||||
|
|
||||||
|
limiter = Limiter(key_func=get_rate_limit_key)
|
||||||
from models import (
|
from models import (
|
||||||
ChatRequest, MemoryAddRequest, MemoryAddResponse,
|
ChatRequest,
|
||||||
MemorySearchRequest, MemorySearchResponse, MemoryUpdateRequest,
|
MemoryAddRequest,
|
||||||
MemoryItem, GraphResponse, HealthResponse, ErrorResponse,
|
MemoryAddResponse,
|
||||||
GlobalStatsResponse, UserStatsResponse
|
MemorySearchRequest,
|
||||||
|
MemorySearchResponse,
|
||||||
|
MemoryUpdateRequest,
|
||||||
|
MemoryItem,
|
||||||
|
GraphResponse,
|
||||||
|
HealthResponse,
|
||||||
|
ErrorResponse,
|
||||||
|
GlobalStatsResponse,
|
||||||
|
UserStatsResponse,
|
||||||
|
OpenAIChatCompletionRequest,
|
||||||
|
OpenAIChatCompletionResponse,
|
||||||
|
OpenAIChoice,
|
||||||
|
OpenAIChoiceMessage,
|
||||||
|
OpenAIUsage,
|
||||||
)
|
)
|
||||||
from mem0_manager import mem0_manager
|
from mem0_manager import mem0_manager
|
||||||
from auth import get_current_user, auth_service
|
from auth import get_current_user, get_current_user_openai, auth_service
|
||||||
|
|
||||||
# Configure structured logging
|
# Configure structured logging
|
||||||
structlog.configure(
|
structlog.configure(
|
||||||
|
|
@ -32,7 +62,7 @@ structlog.configure(
|
||||||
structlog.processors.StackInfoRenderer(),
|
structlog.processors.StackInfoRenderer(),
|
||||||
structlog.processors.format_exc_info,
|
structlog.processors.format_exc_info,
|
||||||
structlog.processors.UnicodeDecoder(),
|
structlog.processors.UnicodeDecoder(),
|
||||||
structlog.processors.JSONRenderer()
|
structlog.processors.JSONRenderer(),
|
||||||
],
|
],
|
||||||
context_class=dict,
|
context_class=dict,
|
||||||
logger_factory=structlog.stdlib.LoggerFactory(),
|
logger_factory=structlog.stdlib.LoggerFactory(),
|
||||||
|
|
@ -58,9 +88,27 @@ async def lifespan(app: FastAPI):
|
||||||
else:
|
else:
|
||||||
logger.info("All services are healthy")
|
logger.info("All services are healthy")
|
||||||
|
|
||||||
|
# Start MCP session manager if available
|
||||||
|
mcp_context = None
|
||||||
|
try:
|
||||||
|
from mcp_server import mcp_lifespan
|
||||||
|
|
||||||
|
mcp_context = mcp_lifespan()
|
||||||
|
await mcp_context.__aenter__()
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("MCP server not available")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start MCP session manager: {e}")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown
|
# Shutdown
|
||||||
|
if mcp_context:
|
||||||
|
try:
|
||||||
|
await mcp_context.__aexit__(None, None, None)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error stopping MCP session manager: {e}")
|
||||||
|
|
||||||
logger.info("Shutting down Mem0 Interface POC")
|
logger.info("Shutting down Mem0 Interface POC")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -69,19 +117,47 @@ app = FastAPI(
|
||||||
title="Mem0 Interface POC",
|
title="Mem0 Interface POC",
|
||||||
description="Minimal but fully functional Mem0 interface with PostgreSQL and Neo4j integration",
|
description="Minimal but fully functional Mem0 interface with PostgreSQL and Neo4j integration",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
lifespan=lifespan
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add CORS middleware - Allow all origins for development
|
# Add rate limiter to app state and exception handler
|
||||||
|
app.state.limiter = limiter
|
||||||
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
|
|
||||||
|
# Add CORS middleware - Allow all origins (secured via API key auth)
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"], # Allow all origins for development
|
allow_origins=["*"], # Allow all origins - secured via API key authentication
|
||||||
allow_credentials=False, # Must be False when allow_origins=["*"]
|
allow_credentials=False, # Must be False when allow_origins=["*"]
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Request size limit middleware - prevent DoS via large payloads
|
||||||
|
MAX_REQUEST_SIZE = 10 * 1024 * 1024 # 10MB limit
|
||||||
|
|
||||||
|
|
||||||
|
@app.middleware("http")
async def limit_request_size(request, call_next):
    """Reject requests whose declared Content-Length exceeds MAX_REQUEST_SIZE.

    Only the Content-Length header is inspected. Requests without the header,
    or with a value that is not an integer, are passed on unchanged.
    """
    declared_length = request.headers.get("content-length")
    if not declared_length:
        return await call_next(request)

    try:
        too_large = int(declared_length) > MAX_REQUEST_SIZE
    except ValueError:
        # Invalid content-length header, let it through for other validation
        too_large = False

    if not too_large:
        return await call_next(request)

    return JSONResponse(
        status_code=413,
        content={
            "error": "Request payload too large",
            "max_size_bytes": MAX_REQUEST_SIZE,
            "max_size_mb": MAX_REQUEST_SIZE / (1024 * 1024),
        },
    )
|
||||||
|
|
||||||
|
|
||||||
# Request logging middleware with monitoring
|
# Request logging middleware with monitoring
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def log_requests(request, call_next):
|
async def log_requests(request, call_next):
|
||||||
|
|
@ -94,19 +170,19 @@ async def log_requests(request, call_next):
|
||||||
# Extract user_id from request if available
|
# Extract user_id from request if available
|
||||||
user_id = None
|
user_id = None
|
||||||
if request.method == "POST":
|
if request.method == "POST":
|
||||||
# Try to extract user_id from request body for POST requests
|
|
||||||
try:
|
try:
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
if body:
|
if body:
|
||||||
import json
|
|
||||||
data = json.loads(body)
|
data = json.loads(body)
|
||||||
user_id = data.get('user_id')
|
user_id = data.get("user_id")
|
||||||
except:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass # Non-JSON body, user_id extraction not possible
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Could not extract user_id from request body", error=str(e))
|
||||||
elif "user_id" in str(request.url.path):
|
elif "user_id" in str(request.url.path):
|
||||||
# Extract user_id from path for GET requests
|
# Extract user_id from path for GET requests
|
||||||
path_parts = request.url.path.split('/')
|
path_parts = request.url.path.split("/")
|
||||||
if len(path_parts) > 2 and path_parts[-2] in ['memories', 'stats']:
|
if len(path_parts) > 2 and path_parts[-2] in ["memories", "stats"]:
|
||||||
user_id = path_parts[-1]
|
user_id = path_parts[-1]
|
||||||
|
|
||||||
# Log start of request
|
# Log start of request
|
||||||
|
|
@ -115,7 +191,7 @@ async def log_requests(request, call_next):
|
||||||
correlation_id=correlation_id,
|
correlation_id=correlation_id,
|
||||||
method=request.method,
|
method=request.method,
|
||||||
path=request.url.path,
|
path=request.url.path,
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
|
|
@ -136,7 +212,7 @@ async def log_requests(request, call_next):
|
||||||
status_code=response.status_code,
|
status_code=response.status_code,
|
||||||
process_time_ms=round(process_time_ms, 2),
|
process_time_ms=round(process_time_ms, 2),
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
slow_request=True
|
slow_request=True,
|
||||||
)
|
)
|
||||||
elif response.status_code >= 400:
|
elif response.status_code >= 400:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
@ -147,7 +223,7 @@ async def log_requests(request, call_next):
|
||||||
status_code=response.status_code,
|
status_code=response.status_code,
|
||||||
process_time_ms=round(process_time_ms, 2),
|
process_time_ms=round(process_time_ms, 2),
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
slow_request=False
|
slow_request=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -158,7 +234,7 @@ async def log_requests(request, call_next):
|
||||||
status_code=response.status_code,
|
status_code=response.status_code,
|
||||||
process_time_ms=round(process_time_ms, 2),
|
process_time_ms=round(process_time_ms, 2),
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
slow_request=False
|
slow_request=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
@ -167,11 +243,23 @@ async def log_requests(request, call_next):
|
||||||
# Exception handlers
|
# Exception handlers
|
||||||
@app.exception_handler(Exception)
|
@app.exception_handler(Exception)
|
||||||
async def global_exception_handler(request, exc):
|
async def global_exception_handler(request, exc):
|
||||||
"""Global exception handler."""
|
"""Global exception handler - logs details but returns generic message."""
|
||||||
logger.error(f"Unhandled exception: {exc}", exc_info=True)
|
# Log full exception details for debugging (internal only)
|
||||||
|
logger.error(
|
||||||
|
"Unhandled exception",
|
||||||
|
exc_info=True,
|
||||||
|
path=request.url.path,
|
||||||
|
method=request.method,
|
||||||
|
error_type=type(exc).__name__,
|
||||||
|
error_message=str(exc),
|
||||||
|
)
|
||||||
|
# Return generic error to client - don't expose internal details
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
content={"error": "Internal server error", "detail": str(exc)}
|
content={
|
||||||
|
"error": "Internal server error",
|
||||||
|
"message": "An unexpected error occurred",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -181,50 +269,59 @@ async def health_check():
|
||||||
"""Check the health of all services."""
|
"""Check the health of all services."""
|
||||||
try:
|
try:
|
||||||
services = await mem0_manager.health_check()
|
services = await mem0_manager.health_check()
|
||||||
overall_status = "healthy" if all("healthy" in status for status in services.values()) else "degraded"
|
overall_status = (
|
||||||
|
"healthy"
|
||||||
|
if all("healthy" in status for status in services.values())
|
||||||
|
else "degraded"
|
||||||
|
)
|
||||||
|
|
||||||
return HealthResponse(
|
return HealthResponse(
|
||||||
status=overall_status,
|
status=overall_status,
|
||||||
services=services,
|
services=services,
|
||||||
timestamp=datetime.utcnow().isoformat()
|
timestamp=datetime.utcnow().isoformat(),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Health check failed: {e}")
|
logger.error(f"Health check failed: {e}", exc_info=True)
|
||||||
return HealthResponse(
|
return HealthResponse(
|
||||||
status="unhealthy",
|
status="unhealthy",
|
||||||
services={"error": str(e)},
|
services={"error": "Health check failed - see logs for details"},
|
||||||
timestamp=datetime.utcnow().isoformat()
|
timestamp=datetime.utcnow().isoformat(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Core chat endpoint with memory enhancement
|
# Core chat endpoint with memory enhancement
|
||||||
@app.post("/chat")
|
@app.post("/chat")
|
||||||
|
@limiter.limit("30/minute") # Chat is expensive - limit to 30/min
|
||||||
async def chat_with_memory(
|
async def chat_with_memory(
|
||||||
request: ChatRequest,
|
request: Request,
|
||||||
authenticated_user: str = Depends(get_current_user)
|
chat_request: ChatRequest,
|
||||||
|
authenticated_user: str = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Ultra-minimal chat endpoint - pure Mem0 + custom endpoint."""
|
"""Ultra-minimal chat endpoint - pure Mem0 + custom endpoint."""
|
||||||
try:
|
try:
|
||||||
# Verify user can only access their own data
|
# Verify user can only access their own data
|
||||||
if authenticated_user != request.user_id:
|
if authenticated_user != chat_request.user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail=f"Access denied: You can only chat as yourself (authenticated as '{authenticated_user}')"
|
detail=f"Access denied: You can only chat as yourself (authenticated as '{authenticated_user}')",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Processing chat request for user: {request.user_id}")
|
logger.info(f"Processing chat request for user: {chat_request.user_id}")
|
||||||
|
|
||||||
# Convert ChatMessage objects to dict format if context provided
|
# Convert ChatMessage objects to dict format if context provided
|
||||||
context_dict = None
|
context_dict = None
|
||||||
if request.context:
|
if chat_request.context:
|
||||||
context_dict = [{"role": msg.role, "content": msg.content} for msg in request.context]
|
context_dict = [
|
||||||
|
{"role": msg.role, "content": msg.content}
|
||||||
|
for msg in chat_request.context
|
||||||
|
]
|
||||||
|
|
||||||
result = await mem0_manager.chat_with_memory(
|
result = await mem0_manager.chat_with_memory(
|
||||||
message=request.message,
|
message=chat_request.message,
|
||||||
user_id=request.user_id,
|
user_id=chat_request.user_id,
|
||||||
agent_id=request.agent_id,
|
agent_id=chat_request.agent_id,
|
||||||
run_id=request.run_id,
|
run_id=chat_request.run_id,
|
||||||
context=context_dict
|
context=context_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
@ -233,32 +330,173 @@ async def chat_with_memory(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in chat endpoint: {e}")
|
logger.error(f"Error in chat endpoint: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="An internal error occurred. Please try again later.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_openai_response(
    completion_id: str, model: str, content: str, created: int
):
    """Yield an OpenAI-style SSE stream that replays an already complete response.

    Emits, in order: a role chunk, one content chunk per group of three
    whitespace-delimited tokens, a final chunk with ``finish_reason="stop"``,
    and the ``data: [DONE]`` terminator. Concatenating every content delta
    reproduces ``content`` exactly, including newlines and repeated spaces.

    Args:
        completion_id: Value for the "id" field of every chunk
        model: Value for the "model" field of every chunk
        content: Full assistant text to stream; None is treated as empty
        created: Value for the "created" field of every chunk

    Yields:
        str: One ``data: ...`` SSE event per iteration
    """
    # Local import keeps this block self-contained.
    import re

    def _event(delta, finish_reason=None):
        # Every chunk shares the same envelope; only delta/finish_reason vary.
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [
                {"index": 0, "delta": delta, "finish_reason": finish_reason}
            ],
        }
        return f"data: {json.dumps(payload)}\n\n"

    # First chunk with role
    yield _event({"role": "assistant", "content": ""})

    # Each token keeps the whitespace that precedes it, and trailing
    # whitespace becomes its own token, so no formatting is lost when the
    # tokens are joined back together.
    tokens = re.findall(r"\s*\S+|\s+\Z", content or "")
    chunk_size = 3  # tokens per chunk

    for start in range(0, len(tokens), chunk_size):
        yield _event({"content": "".join(tokens[start : start + chunk_size])})
        # Short pause between chunks so clients render the text incrementally.
        await asyncio.sleep(0.05)

    # Final chunk
    yield _event({}, "stop")
    yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
@app.post("/chat/completions")
@limiter.limit("30/minute")
async def openai_chat_completions(
    request: Request,
    completion_request: OpenAIChatCompletionRequest,
    authenticated_user: str = Depends(get_current_user_openai),
):
    """OpenAI-compatible chat completions endpoint with mem0 memory integration.

    The last message with role "user" is sent as the prompt; every other
    message, in its original order, is passed as context. The reply is
    returned either as a single OpenAIChatCompletionResponse or, when
    ``completion_request.stream`` is true, as an SSE stream.

    Args:
        request: Incoming HTTP request; not read here
            (NOTE(review): presumably required by the limiter decorator - confirm)
        completion_request: Parsed OpenAI-style request body
        authenticated_user: user_id resolved by get_current_user_openai

    Raises:
        HTTPException: 400 if no message has role "user"; 500 with a generic
            message on any unexpected error (details are only logged)
    """
    # Local import, as in the original block.
    import uuid

    try:
        user_id = authenticated_user
        logger.info(
            f"OpenAI chat completion for user: {user_id} (streaming={completion_request.stream})"
        )

        # Locate the last user message by position so the prompt and the
        # context can never overlap or drop a message.
        messages = completion_request.messages
        last_user_index = next(
            (
                i
                for i in range(len(messages) - 1, -1, -1)
                if messages[i].get("role") == "user"
            ),
            None,
        )
        if last_user_index is None:
            raise HTTPException(
                status_code=400,
                detail="No user messages provided. Include at least one message with role='user'.",
            )

        last_message = messages[last_user_index].get("content", "")
        # Everything except the prompt itself; None when nothing else was sent.
        context = (
            messages[:last_user_index] + messages[last_user_index + 1 :]
        ) or None

        # Call chat_with_memory
        result = await mem0_manager.chat_with_memory(
            message=last_message,
            user_id=user_id,
            context=context,
        )

        completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
        created_time = int(time.time())
        assistant_content = result.get("response", "")

        if completion_request.stream:
            return StreamingResponse(
                stream_openai_response(
                    completion_id=completion_id,
                    model=settings.default_model,
                    content=assistant_content,
                    created=created_time,
                ),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
            )

        return OpenAIChatCompletionResponse(
            id=completion_id,
            object="chat.completion",
            created=created_time,
            model=settings.default_model,
            choices=[
                OpenAIChoice(
                    index=0,
                    message=OpenAIChoiceMessage(
                        role="assistant", content=assistant_content
                    ),
                    finish_reason="stop",
                )
            ],
            # Token counts are not computed here; zeros are placeholders.
            usage=OpenAIUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in OpenAI chat completions: {e}", exc_info=True)
        # Return a generic message so internal error text is not sent to the client.
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
|
||||||
|
|
||||||
|
|
||||||
# Memory management endpoints - pure Mem0 passthroughs
|
# Memory management endpoints - pure Mem0 passthroughs
|
||||||
@app.post("/memories")
|
@app.post("/memories")
|
||||||
|
@limiter.limit("60/minute") # Memory operations - 60/min
|
||||||
async def add_memories(
|
async def add_memories(
|
||||||
request: MemoryAddRequest,
|
request: Request,
|
||||||
authenticated_user: str = Depends(get_current_user)
|
memory_request: MemoryAddRequest,
|
||||||
|
authenticated_user: str = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Add memories - pure Mem0 passthrough."""
|
"""Add memories - pure Mem0 passthrough."""
|
||||||
try:
|
try:
|
||||||
# Verify user can only add to their own memories
|
# Verify user can only add to their own memories
|
||||||
if authenticated_user != request.user_id:
|
if authenticated_user != memory_request.user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail=f"Access denied: You can only add memories for yourself (authenticated as '{authenticated_user}')"
|
detail=f"Access denied: You can only add memories for yourself (authenticated as '{authenticated_user}')",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Adding memories for user: {request.user_id}")
|
logger.info(f"Adding memories for user: {memory_request.user_id}")
|
||||||
|
|
||||||
result = await mem0_manager.add_memories(
|
result = await mem0_manager.add_memories(
|
||||||
messages=request.messages,
|
messages=memory_request.messages,
|
||||||
user_id=request.user_id,
|
user_id=memory_request.user_id,
|
||||||
agent_id=request.agent_id,
|
agent_id=memory_request.agent_id,
|
||||||
run_id=request.run_id,
|
run_id=memory_request.run_id,
|
||||||
metadata=request.metadata
|
metadata=memory_request.metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
@ -267,33 +505,40 @@ async def add_memories(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error adding memories: {e}")
|
logger.error(f"Error adding memories: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="An internal error occurred. Please try again later.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/memories/search")
|
@app.post("/memories/search")
|
||||||
|
@limiter.limit("120/minute") # Search is lighter - 120/min
|
||||||
async def search_memories(
|
async def search_memories(
|
||||||
request: MemorySearchRequest,
|
request: Request,
|
||||||
authenticated_user: str = Depends(get_current_user)
|
search_request: MemorySearchRequest,
|
||||||
|
authenticated_user: str = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Search memories - pure Mem0 passthrough."""
|
"""Search memories - pure Mem0 passthrough."""
|
||||||
try:
|
try:
|
||||||
# Verify user can only search their own memories
|
# Verify user can only search their own memories
|
||||||
if authenticated_user != request.user_id:
|
if authenticated_user != search_request.user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail=f"Access denied: You can only search your own memories (authenticated as '{authenticated_user}')"
|
detail=f"Access denied: You can only search your own memories (authenticated as '{authenticated_user}')",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Searching memories for user: {request.user_id}, query: {request.query}")
|
logger.info(
|
||||||
|
f"Searching memories for user: {search_request.user_id}, query: {search_request.query}"
|
||||||
|
)
|
||||||
|
|
||||||
result = await mem0_manager.search_memories(
|
result = await mem0_manager.search_memories(
|
||||||
query=request.query,
|
query=search_request.query,
|
||||||
user_id=request.user_id,
|
user_id=search_request.user_id,
|
||||||
limit=request.limit,
|
limit=search_request.limit,
|
||||||
threshold=request.threshold or 0.2,
|
threshold=search_request.threshold or 0.2,
|
||||||
filters=request.filters,
|
filters=search_request.filters,
|
||||||
agent_id=request.agent_id,
|
agent_id=search_request.agent_id,
|
||||||
run_id=request.run_id
|
run_id=search_request.run_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
@ -302,16 +547,21 @@ async def search_memories(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error searching memories: {e}")
|
logger.error(f"Error searching memories: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="An internal error occurred. Please try again later.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/memories/{user_id}")
|
@app.get("/memories/{user_id}")
|
||||||
|
@limiter.limit("120/minute")
|
||||||
async def get_user_memories(
|
async def get_user_memories(
|
||||||
|
request: Request,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
authenticated_user: str = Depends(get_current_user),
|
authenticated_user: str = Depends(get_current_user),
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
agent_id: Optional[str] = None,
|
agent_id: Optional[str] = None,
|
||||||
run_id: Optional[str] = None
|
run_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Get all memories for a user with hierarchy filtering - pure Mem0 passthrough."""
|
"""Get all memories for a user with hierarchy filtering - pure Mem0 passthrough."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -319,16 +569,13 @@ async def get_user_memories(
|
||||||
if authenticated_user != user_id:
|
if authenticated_user != user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail=f"Access denied: You can only retrieve your own memories (authenticated as '{authenticated_user}')"
|
detail=f"Access denied: You can only retrieve your own memories (authenticated as '{authenticated_user}')",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Retrieving memories for user: {user_id}")
|
logger.info(f"Retrieving memories for user: {user_id}")
|
||||||
|
|
||||||
memories = await mem0_manager.get_user_memories(
|
memories = await mem0_manager.get_user_memories(
|
||||||
user_id=user_id,
|
user_id=user_id, limit=limit, agent_id=agent_id, run_id=run_id
|
||||||
limit=limit,
|
|
||||||
agent_id=agent_id,
|
|
||||||
run_id=run_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return memories
|
return memories
|
||||||
|
|
@ -337,28 +584,44 @@ async def get_user_memories(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving user memories: {e}")
|
logger.error(f"Error retrieving user memories: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="An internal error occurred. Please try again later.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.put("/memories")
|
@app.put("/memories")
|
||||||
|
@limiter.limit("60/minute")
|
||||||
async def update_memory(
|
async def update_memory(
|
||||||
request: MemoryUpdateRequest,
|
request: Request,
|
||||||
authenticated_user: str = Depends(get_current_user)
|
update_request: MemoryUpdateRequest,
|
||||||
|
authenticated_user: str = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Update memory - pure Mem0 passthrough."""
|
"""Update memory - verifies ownership before update."""
|
||||||
try:
|
try:
|
||||||
# Verify user owns the memory being updated
|
# Verify user owns the memory being updated
|
||||||
if authenticated_user != request.user_id:
|
if authenticated_user != update_request.user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail=f"Access denied: You can only update your own memories (authenticated as '{authenticated_user}')"
|
detail=f"Access denied: You can only update your own memories (authenticated as '{authenticated_user}')",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Updating memory: {request.memory_id}")
|
# Verify memory ownership with O(1) lookup instead of fetching all memories
|
||||||
|
if not await mem0_manager.verify_memory_ownership(
|
||||||
|
update_request.memory_id, authenticated_user
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Memory '{update_request.memory_id}' not found or access denied",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Updating memory: {update_request.memory_id}", user_id=authenticated_user
|
||||||
|
)
|
||||||
|
|
||||||
result = await mem0_manager.update_memory(
|
result = await mem0_manager.update_memory(
|
||||||
memory_id=request.memory_id,
|
memory_id=update_request.memory_id,
|
||||||
content=request.content,
|
content=update_request.content,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
@ -367,25 +630,31 @@ async def update_memory(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error updating memory: {e}")
|
logger.error(f"Error updating memory: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="An internal error occurred. Please try again later.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/memories/{memory_id}")
|
@app.delete("/memories/{memory_id}")
|
||||||
|
@limiter.limit("60/minute")
|
||||||
async def delete_memory(
|
async def delete_memory(
|
||||||
|
request: Request,
|
||||||
memory_id: str,
|
memory_id: str,
|
||||||
user_id: str, # Add user_id as query parameter for verification
|
authenticated_user: str = Depends(get_current_user),
|
||||||
authenticated_user: str = Depends(get_current_user)
|
|
||||||
):
|
):
|
||||||
"""Delete a specific memory."""
|
"""Delete a specific memory - verifies ownership before deletion."""
|
||||||
try:
|
try:
|
||||||
# Verify user owns the memory being deleted
|
# Verify memory ownership with O(1) lookup instead of fetching all memories
|
||||||
if authenticated_user != user_id:
|
if not await mem0_manager.verify_memory_ownership(
|
||||||
|
memory_id, authenticated_user
|
||||||
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=404,
|
||||||
detail=f"Access denied: You can only delete your own memories (authenticated as '{authenticated_user}')"
|
detail=f"Memory '{memory_id}' not found or access denied",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Deleting memory: {memory_id}")
|
logger.info(f"Deleting memory: {memory_id}", user_id=authenticated_user)
|
||||||
|
|
||||||
result = await mem0_manager.delete_memory(memory_id=memory_id)
|
result = await mem0_manager.delete_memory(memory_id=memory_id)
|
||||||
|
|
||||||
|
|
@ -395,13 +664,16 @@ async def delete_memory(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting memory: {e}")
|
logger.error(f"Error deleting memory: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="An internal error occurred. Please try again later.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/memories/user/{user_id}")
|
@app.delete("/memories/user/{user_id}")
|
||||||
|
@limiter.limit("10/minute") # Dangerous bulk delete - heavily rate limited
|
||||||
async def delete_user_memories(
|
async def delete_user_memories(
|
||||||
user_id: str,
|
request: Request, user_id: str, authenticated_user: str = Depends(get_current_user)
|
||||||
authenticated_user: str = Depends(get_current_user)
|
|
||||||
):
|
):
|
||||||
"""Delete all memories for a specific user."""
|
"""Delete all memories for a specific user."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -409,7 +681,7 @@ async def delete_user_memories(
|
||||||
if authenticated_user != user_id:
|
if authenticated_user != user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail=f"Access denied: You can only delete your own memories (authenticated as '{authenticated_user}')"
|
detail=f"Access denied: You can only delete your own memories (authenticated as '{authenticated_user}')",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Deleting all memories for user: {user_id}")
|
logger.info(f"Deleting all memories for user: {user_id}")
|
||||||
|
|
@ -422,14 +694,17 @@ async def delete_user_memories(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting user memories: {e}")
|
logger.error(f"Error deleting user memories: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="An internal error occurred. Please try again later.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Graph relationships endpoint - pure Mem0 passthrough
|
# Graph relationships endpoint - pure Mem0 passthrough
|
||||||
@app.get("/graph/relationships/{user_id}")
|
@app.get("/graph/relationships/{user_id}")
|
||||||
|
@limiter.limit("60/minute")
|
||||||
async def get_graph_relationships(
|
async def get_graph_relationships(
|
||||||
user_id: str,
|
request: Request, user_id: str, authenticated_user: str = Depends(get_current_user)
|
||||||
authenticated_user: str = Depends(get_current_user)
|
|
||||||
):
|
):
|
||||||
"""Get graph relationships - pure Mem0 passthrough."""
|
"""Get graph relationships - pure Mem0 passthrough."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -437,11 +712,13 @@ async def get_graph_relationships(
|
||||||
if authenticated_user != user_id:
|
if authenticated_user != user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail=f"Access denied: You can only view your own relationships (authenticated as '{authenticated_user}')"
|
detail=f"Access denied: You can only view your own relationships (authenticated as '{authenticated_user}')",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Retrieving graph relationships for user: {user_id}")
|
logger.info(f"Retrieving graph relationships for user: {user_id}")
|
||||||
result = await mem0_manager.get_graph_relationships(user_id=user_id, agent_id=None, run_id=None, limit=10000)
|
result = await mem0_manager.get_graph_relationships(
|
||||||
|
user_id=user_id, agent_id=None, run_id=None, limit=10000
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -449,68 +726,101 @@ async def get_graph_relationships(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving graph relationships: {e}")
|
logger.error(f"Error retrieving graph relationships: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="An internal error occurred. Please try again later.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Memory history endpoint - new feature
|
# Memory history endpoint - new feature
|
||||||
@app.get("/memories/{memory_id}/history")
@limiter.limit("120/minute")
async def get_memory_history(
    request: Request,
    memory_id: str,
    user_id: str,  # Required query param to verify ownership
    authenticated_user: str = Depends(get_current_user),
):
    """Return the change history of one memory after verifying ownership.

    Args:
        request: Incoming request; not read in the body.
            NOTE(review): presumably required by the rate-limit decorator - confirm.
        memory_id: Path parameter naming the memory.
        user_id: Query parameter; must equal the authenticated user.
        authenticated_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        The result of ``mem0_manager.get_memory_history``, unmodified.

    Raises:
        HTTPException: 403 when ``user_id`` differs from the authenticated
            user, 404 when the ownership check fails, 500 on any other error.
    """
    try:
        # Verify user can only access their own memory history
        if authenticated_user != user_id:
            raise HTTPException(
                status_code=403,
                detail=f"Access denied: You can only view your own memory history (authenticated as '{authenticated_user}')",
            )

        # Single ownership lookup instead of fetching every memory of the user
        if not await mem0_manager.verify_memory_ownership(memory_id, user_id):
            raise HTTPException(
                status_code=404,
                detail=f"Memory '{memory_id}' not found or access denied",
            )

        logger.info(f"Retrieving history for memory: {memory_id}", user_id=user_id)

        return await mem0_manager.get_memory_history(memory_id=memory_id)

    except HTTPException:
        # Deliberate 403/404 responses must reach the client unchanged
        raise
    except Exception as e:
        # The client only gets a generic message, so keep the traceback in the log
        logger.error(f"Error retrieving memory history: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
|
||||||
|
|
||||||
|
|
||||||
# Statistics and monitoring endpoints
|
# Statistics and monitoring endpoints
|
||||||
@app.get("/stats", response_model=GlobalStatsResponse)
@limiter.limit("60/minute")
async def get_global_stats(
    request: Request, authenticated_user: str = Depends(get_current_user)
):
    """Get global application statistics - requires authentication.

    Returns:
        A ``GlobalStatsResponse`` built from ``monitoring.stats``.
        ``total_memories`` falls back to 0 when the Mem0 probe call fails.

    Raises:
        HTTPException: 500 with a generic message on any unexpected error.
    """
    try:
        from monitoring import stats

        basic_stats = stats.get_global_stats()

        try:
            # Probe call: the result is discarded, only success or failure matters
            await mem0_manager.search_memories(
                query="*", user_id="__stats_check__", limit=1
            )
            total_memories = basic_stats["total_memories"]
        except Exception as e:
            # Best effort: report zero, but leave a trace of why
            logger.warning(f"Failed to get global memory count: {e}")
            total_memories = 0

        operations = basic_stats["memory_operations"]
        return GlobalStatsResponse(
            total_memories=total_memories,
            total_users=basic_stats["total_users"],
            api_calls_today=basic_stats["api_calls_today"],
            avg_response_time_ms=basic_stats["avg_response_time_ms"],
            memory_operations={
                "add": operations["add"],
                "search": operations["search"],
                "update": operations["update"],
                "delete": operations["delete"],
            },
            uptime_seconds=basic_stats["uptime_seconds"],
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting global stats: {e}")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later.",
        )
|
||||||
|
|
||||||
|
|
||||||
@app.get("/stats/{user_id}", response_model=UserStatsResponse)
|
@app.get("/stats/{user_id}", response_model=UserStatsResponse)
|
||||||
|
@limiter.limit("120/minute")
|
||||||
async def get_user_stats(
|
async def get_user_stats(
|
||||||
user_id: str,
|
request: Request, user_id: str, authenticated_user: str = Depends(get_current_user)
|
||||||
authenticated_user: str = Depends(get_current_user)
|
|
||||||
):
|
):
|
||||||
"""Get user-specific statistics."""
|
"""Get user-specific statistics."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -518,7 +828,7 @@ async def get_user_stats(
|
||||||
if authenticated_user != user_id:
|
if authenticated_user != user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail=f"Access denied: You can only view your own statistics (authenticated as '{authenticated_user}')"
|
detail=f"Access denied: You can only view your own statistics (authenticated as '{authenticated_user}')",
|
||||||
)
|
)
|
||||||
|
|
||||||
from monitoring import stats
|
from monitoring import stats
|
||||||
|
|
@ -528,63 +838,92 @@ async def get_user_stats(
|
||||||
|
|
||||||
# Get actual memory count for this user
|
# Get actual memory count for this user
|
||||||
try:
|
try:
|
||||||
user_memories = await mem0_manager.get_user_memories(user_id=user_id, limit=10000)
|
user_memories = await mem0_manager.get_user_memories(
|
||||||
|
user_id=user_id, limit=10000
|
||||||
|
)
|
||||||
memory_count = len(user_memories)
|
memory_count = len(user_memories)
|
||||||
except:
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get memory count for user {user_id}: {e}")
|
||||||
memory_count = 0
|
memory_count = 0
|
||||||
|
|
||||||
# Get relationship count for this user
|
# Get relationship count for this user
|
||||||
try:
|
try:
|
||||||
graph_data = await mem0_manager.get_graph_relationships(user_id=user_id, agent_id=None, run_id=None)
|
graph_data = await mem0_manager.get_graph_relationships(
|
||||||
relationship_count = len(graph_data.get('relationships', []))
|
user_id=user_id, agent_id=None, run_id=None
|
||||||
except:
|
)
|
||||||
|
relationship_count = len(graph_data.get("relationships", []))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get relationship count for user {user_id}: {e}")
|
||||||
relationship_count = 0
|
relationship_count = 0
|
||||||
|
|
||||||
return UserStatsResponse(
|
return UserStatsResponse(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
memory_count=memory_count,
|
memory_count=memory_count,
|
||||||
relationship_count=relationship_count,
|
relationship_count=relationship_count,
|
||||||
last_activity=basic_stats['last_activity'],
|
last_activity=basic_stats["last_activity"],
|
||||||
api_calls_today=basic_stats['api_calls_today'],
|
api_calls_today=basic_stats["api_calls_today"],
|
||||||
avg_response_time_ms=basic_stats['avg_response_time_ms']
|
avg_response_time_ms=basic_stats["avg_response_time_ms"],
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting user stats for {user_id}: {e}")
|
logger.error(f"Error getting user stats for {user_id}: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="An internal error occurred. Please try again later.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Utility endpoints
|
# Utility endpoints
|
||||||
@app.get("/models")
@limiter.limit("120/minute")
async def get_available_models(
    request: Request, authenticated_user: str = Depends(get_current_user)
):
    """Report the model configuration currently in use (authenticated callers only)."""
    # Both values are read straight from the application settings
    model_info = dict(
        current_model=settings.default_model,
        endpoint=settings.openai_base_url,
        note="Using single model with pure Mem0 intelligence",
    )
    return model_info
|
||||||
|
|
||||||
|
|
||||||
@app.get("/users")
@limiter.limit("60/minute")
async def get_active_users(
    request: Request, authenticated_user: str = Depends(get_current_user)
):
    """Placeholder for listing users that have stored memories (authenticated callers only)."""
    # No user enumeration is implemented yet; the response only explains that.
    placeholder = dict(
        message="This endpoint would return users with stored memories",
        note="Implementation depends on direct database access or Mem0 user enumeration capabilities",
    )
    return placeholder
|
||||||
|
|
||||||
|
|
||||||
|
# Mount MCP server at /mcp endpoint.
# Mounting is best effort: a failure is logged and the rest of the app still starts.
try:
    from mcp_server import create_mcp_app

    mcp_app = create_mcp_app()
    app.mount("/mcp", mcp_app)
    logger.info("MCP server mounted at /mcp")
except ImportError as e:
    # Reported as a warning: the MCP server or one of its dependencies is not importable
    logger.warning(f"MCP server not available (missing dependencies): {e}")
except Exception as e:
    # Unexpected failure while building or mounting the sub-app: keep the traceback
    logger.error(f"Failed to mount MCP server: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    import uvicorn

    print("Starting UVicorn server...")

    # Development entry point: binds all interfaces on port 8000 with auto-reload
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "log_level": settings.log_level.lower(),
        "reload": True,
    }
    uvicorn.run("main:app", **server_options)
|
||||||
|
|
|
||||||
240
backend/mcp_server.py
Normal file
240
backend/mcp_server.py
Normal file
|
|
@ -0,0 +1,240 @@
|
||||||
|
"""MCP Server for Mem0 Memory Service.
|
||||||
|
|
||||||
|
Exposes memory operations as MCP tools over HTTP with API key authentication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.routing import Mount
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Context variable for authenticated user_id
|
||||||
|
# Set by middleware, read by tools
|
||||||
|
# The empty-string default stands for "no authenticated user";
# get_authenticated_user() treats any falsy value as missing and raises.
current_user_id: ContextVar[str] = ContextVar("current_user_id", default="")
|
||||||
|
|
||||||
|
|
||||||
|
def get_authenticated_user() -> str:
    """Return the user_id stored in the ``current_user_id`` context variable.

    Raises:
        ValueError: If the context variable holds an empty (falsy) value.
    """
    authenticated = current_user_id.get()
    if authenticated:
        return authenticated
    raise ValueError("No authenticated user in context")
|
||||||
|
|
||||||
|
|
||||||
|
class MCPAuthMiddleware(BaseHTTPMiddleware):
    """Middleware to authenticate MCP requests using API key.

    The key is read from the ``x-api-key`` header or, failing that, from the
    ``Authorization`` header (with or without a ``Bearer`` prefix). It is
    mapped to a user_id through ``settings.api_key_mapping``, and that user_id
    is stored in ``current_user_id`` for the duration of the request.
    """

    @staticmethod
    def _extract_api_key(headers) -> str:
        """Return the API key carried by *headers*, or "" when none is present."""
        # Starlette header lookup is case-insensitive, so this also matches "X-API-Key"
        api_key = headers.get("x-api-key")
        if api_key:
            return api_key

        authorization = headers.get("Authorization", "")
        scheme, separator, credentials = authorization.partition(" ")
        # Strip only a leading Bearer scheme; never touch text inside the credential
        if separator and scheme.lower() == "bearer":
            return credentials.strip()
        # Backward compatible: a bare key in the Authorization header is accepted as-is
        return authorization

    async def dispatch(self, request, call_next):
        """Reject unauthenticated requests with 401, otherwise run the request.

        Args:
            request: Incoming Starlette request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            A 401 ``JSONResponse`` when the key is missing or unknown,
            otherwise the downstream response.
        """
        api_key = self._extract_api_key(request.headers)
        if not api_key:
            return JSONResponse(
                {"error": "Missing API key. Provide x-api-key or Authorization header."},
                status_code=401,
            )

        # Map API key to user_id
        user_id = settings.api_key_mapping.get(api_key)
        if not user_id:
            return JSONResponse(
                {"error": "Invalid API key"},
                status_code=401,
            )

        # Store user_id in context for tools to access; always restore it afterwards
        token = current_user_id.set(user_id)
        try:
            return await call_next(request)
        finally:
            current_user_id.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
# FastMCP server instance; the @mcp.tool() functions in this module register on it.
# streamable_http_path="/" since we mount at /mcp in main.py
mcp = FastMCP(
    "Mem0 Memory Service",
    stateless_http=True,
    json_response=True,
    streamable_http_path="/",
)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def add_memory(
    content: str = Field(description="Content to add to memory"),
    agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
    run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
    metadata: Optional[dict] = Field(default=None, description="Optional metadata to attach to the memory"),
) -> dict:
    """Add content to the user's memory.

    The user_id is automatically determined from the API key authentication.
    Use agent_id and run_id for multi-agent or session-based memory organization.
    """
    # NOTE(review): docstring left verbatim; FastMCP presumably publishes it
    # as the tool description - confirm before rewording.
    # Imported inside the function, as in every tool of this module
    from mem0_manager import mem0_manager

    user_id = get_authenticated_user()
    logger.info(f"MCP add_memory: user={user_id}, agent={agent_id}, run={run_id}")

    # The raw text is submitted as a single user-role message
    user_message = {"role": "user", "content": content}
    return await mem0_manager.add_memories(
        messages=[user_message],
        user_id=user_id,
        agent_id=agent_id,
        run_id=run_id,
        metadata=metadata,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def search_memory(
    query: str = Field(description="Search query to find relevant memories"),
    agent_id: Optional[str] = Field(default=None, description="Optional agent identifier to filter memories"),
    run_id: Optional[str] = Field(default=None, description="Optional run identifier to filter memories"),
    limit: int = Field(default=10, ge=1, le=100, description="Maximum number of results to return"),
) -> dict:
    """Search the user's memories.

    Returns memories most relevant to the query. The user_id is automatically
    determined from API key authentication.
    """
    # NOTE(review): docstring left verbatim; FastMCP presumably publishes it
    # as the tool description - confirm before rewording.
    from mem0_manager import mem0_manager

    user_id = get_authenticated_user()
    # Only the first 50 characters of the query are written to the log
    logger.info(f"MCP search_memory: user={user_id}, query={query[:50]}..., limit={limit}")

    search_scope = {"user_id": user_id, "agent_id": agent_id, "run_id": run_id}
    return await mem0_manager.search_memories(query=query, limit=limit, **search_scope)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def remove_memory(
    memory_id: str = Field(description="The ID of the memory to remove"),
) -> dict:
    """Remove a specific memory by its ID.

    Only memories belonging to the authenticated user can be deleted.
    Verifies ownership before deletion.
    """
    # NOTE(review): docstring left verbatim; FastMCP presumably publishes it
    # as the tool description - confirm before rewording.
    from mem0_manager import mem0_manager

    user_id = get_authenticated_user()
    logger.info(f"MCP remove_memory: user={user_id}, memory_id={memory_id}")

    # Use the same ownership check as the REST delete endpoint. Listing the
    # user's memories with limit=10000 missed anything beyond that limit, so
    # those memories could never be deleted.
    if not await mem0_manager.verify_memory_ownership(memory_id, user_id):
        # Same message for "missing" and "not yours"
        raise ValueError(f"Memory '{memory_id}' not found or access denied")

    return await mem0_manager.delete_memory(memory_id=memory_id)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def chat(
    message: str = Field(description="The user's message to chat with"),
    agent_id: Optional[str] = Field(default=None, description="Optional agent identifier for multi-agent scenarios"),
    run_id: Optional[str] = Field(default=None, description="Optional run identifier for session tracking"),
) -> str:
    """Chat with memory context.

    Retrieves relevant memories based on the message, generates a response
    using the configured LLM, and stores the conversation in memory.
    The user_id is automatically determined from API key authentication.
    """
    # NOTE(review): the docstring above is presumably published to MCP clients
    # as the tool description, so its wording is deliberately left unchanged.

    # Imported at call time rather than at module import time.
    from mem0_manager import mem0_manager

    # Identity comes from the request's authentication context, never from
    # tool arguments.
    user_id = get_authenticated_user()

    # Only the first 50 characters of the message are written to the log.
    logger.info(f"MCP chat: user={user_id}, agent={agent_id}, message={message[:50]}...")

    chat_args = {
        "message": message,
        "user_id": user_id,
        "agent_id": agent_id,
        "run_id": run_id,
    }
    outcome = await mem0_manager.chat_with_memory(**chat_args)

    # Only the reply text is exposed to the client; a result without a
    # "response" key yields an empty string.
    return outcome.get("response", "")
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
async def mcp_lifespan():
    """Run the MCP session manager for the duration of the ``async with`` body.

    Intended to be entered from the main FastAPI lifespan, because the
    lifespan of a mounted sub-application is not run automatically.

    The "stopped" message is logged when the body finishes normally, just
    before the session manager's own context exits. It is not logged when
    the body raises.
    """
    session_manager = mcp.session_manager
    async with session_manager.run():
        logger.info("MCP session manager started")
        yield
        logger.info("MCP session manager stopped")
|
||||||
|
|
||||||
|
|
||||||
|
def create_mcp_app() -> Starlette:
    """Build the Starlette application that serves the MCP endpoints.

    The returned app mounts the StreamableHTTP app at "/", then registers
    the authentication middleware followed by the CORS middleware.

    Note: this function does not start the MCP session manager. That must
    be done through mcp_lifespan() inside the main FastAPI lifespan.

    Returns:
        The configured Starlette application.
    """
    # The StreamableHTTP app answers at "/" because the whole application is
    # itself mounted under /mcp by the caller.
    mcp_routes = [Mount("/", app=mcp.streamable_http_app())]
    app = Starlette(routes=mcp_routes)

    # Authentication is registered first and CORS second; keep this order.
    app.add_middleware(MCPAuthMiddleware)

    # CORS settings for browser clients: any origin, no credentials, and the
    # Mcp-Session-Id response header made readable to the browser.
    cors_options = {
        "allow_origins": ["*"],
        "allow_credentials": False,
        "allow_methods": ["GET", "POST", "DELETE", "OPTIONS"],
        "allow_headers": ["*"],
        "expose_headers": ["Mcp-Session-Id"],
    }
    app.add_middleware(CORSMiddleware, **cors_options)

    return app
|
||||||
|
|
@ -5,24 +5,48 @@ from typing import Dict, List, Optional, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from mem0 import Memory
|
from mem0 import Memory
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
retry_if_exception_type,
|
||||||
|
before_sleep_log,
|
||||||
|
)
|
||||||
|
import structlog
|
||||||
|
|
||||||
from config import settings
|
from config import settings
|
||||||
from monitoring import timed
|
from monitoring import timed
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = structlog.get_logger(__name__)
|
||||||
|
|
||||||
|
# Retry decorator for database operations (Qdrant, Neo4j)
# Built once at import time and applied as ``@db_retry`` to manager methods.
db_retry = retry(
    # Give up after the third attempt.
    stop=stop_after_attempt(3),
    # Exponential back-off between attempts, clamped to the 1-10 second range.
    wait=wait_exponential(multiplier=1, min=1, max=10),
    # Only these exception types trigger a retry; any other exception
    # propagates immediately.
    retry=retry_if_exception_type((ConnectionError, TimeoutError, OSError)),
    # Emit a WARNING-level log entry before each sleep between attempts.
    before_sleep=before_sleep_log(logger, logging.WARNING),
    # NOTE(review): per the tenacity docs, reraise=True surfaces the last
    # underlying exception instead of wrapping it in RetryError - confirm.
    reraise=True,
)
|
||||||
|
|
||||||
# Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility
|
# Monkey-patch Mem0's OpenAI LLM to remove the 'store' parameter for LiteLLM compatibility
|
||||||
from mem0.llms.openai import OpenAILLM
|
from mem0.llms.openai import OpenAILLM
|
||||||
|
|
||||||
_original_generate_response = OpenAILLM.generate_response
|
_original_generate_response = OpenAILLM.generate_response
|
||||||
|
|
||||||
def patched_generate_response(self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs):
|
|
||||||
|
def patched_generate_response(
|
||||||
|
self, messages, response_format=None, tools=None, tool_choice="auto", **kwargs
|
||||||
|
):
|
||||||
# Remove 'store' parameter as LiteLLM doesn't support it
|
# Remove 'store' parameter as LiteLLM doesn't support it
|
||||||
if hasattr(self.config, 'store'):
|
if hasattr(self.config, "store"):
|
||||||
self.config.store = None
|
self.config.store = None
|
||||||
# Remove 'top_p' to avoid conflict with temperature for Claude models
|
# Remove 'top_p' to avoid conflict with temperature for Claude models
|
||||||
if hasattr(self.config, 'top_p'):
|
if hasattr(self.config, "top_p"):
|
||||||
self.config.top_p = None
|
self.config.top_p = None
|
||||||
return _original_generate_response(self, messages, response_format, tools, tool_choice, **kwargs)
|
return _original_generate_response(
|
||||||
|
self, messages, response_format, tools, tool_choice, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
OpenAILLM.generate_response = patched_generate_response
|
OpenAILLM.generate_response = patched_generate_response
|
||||||
logger.info("Applied LiteLLM compatibility patch: disabled 'store' parameter")
|
logger.info("Applied LiteLLM compatibility patch: disabled 'store' parameter")
|
||||||
|
|
@ -36,8 +60,16 @@ class Mem0Manager:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Custom endpoint configuration with graph memory enabled
|
# Custom endpoint configuration with graph memory enabled
|
||||||
logger.info("Initializing ultra-minimal Mem0Manager with custom endpoint with settings:", settings)
|
logger.info(
|
||||||
|
"Initializing Mem0Manager with custom endpoint",
|
||||||
|
model=settings.default_model,
|
||||||
|
embedding_model=settings.embedding_model,
|
||||||
|
embedding_dims=settings.embedding_dims,
|
||||||
|
qdrant_host=settings.qdrant_host,
|
||||||
|
neo4j_uri=settings.neo4j_uri,
|
||||||
|
)
|
||||||
config = {
|
config = {
|
||||||
|
"version": "v1.1",
|
||||||
"enable_graph": True,
|
"enable_graph": True,
|
||||||
"llm": {
|
"llm": {
|
||||||
"provider": "openai",
|
"provider": "openai",
|
||||||
|
|
@ -46,17 +78,16 @@ class Mem0Manager:
|
||||||
"api_key": settings.openai_api_key,
|
"api_key": settings.openai_api_key,
|
||||||
"openai_base_url": settings.openai_base_url,
|
"openai_base_url": settings.openai_base_url,
|
||||||
"temperature": 0.1,
|
"temperature": 0.1,
|
||||||
"top_p": None # Don't use top_p with Claude models
|
"top_p": None,
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
"embedder": {
|
"embedder": {
|
||||||
"provider": "ollama",
|
"provider": "ollama",
|
||||||
"config": {
|
"config": {
|
||||||
"model": "qwen3-embedding:4b-q8_0",
|
"model": settings.embedding_model,
|
||||||
# "api_key": settings.embedder_api_key,
|
"ollama_base_url": settings.ollama_base_url,
|
||||||
"ollama_base_url": "http://172.17.0.1:11434",
|
"embedding_dims": settings.embedding_dims,
|
||||||
"embedding_dims": 2560
|
},
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"vector_store": {
|
"vector_store": {
|
||||||
"provider": "qdrant",
|
"provider": "qdrant",
|
||||||
|
|
@ -64,38 +95,39 @@ class Mem0Manager:
|
||||||
"collection_name": settings.qdrant_collection_name,
|
"collection_name": settings.qdrant_collection_name,
|
||||||
"host": settings.qdrant_host,
|
"host": settings.qdrant_host,
|
||||||
"port": settings.qdrant_port,
|
"port": settings.qdrant_port,
|
||||||
"embedding_model_dims": 2560,
|
"embedding_model_dims": settings.embedding_dims,
|
||||||
"on_disk": True
|
"on_disk": True,
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
"graph_store": {
|
"graph_store": {
|
||||||
"provider": "neo4j",
|
"provider": "neo4j",
|
||||||
"config": {
|
"config": {
|
||||||
"url": settings.neo4j_uri,
|
"url": settings.neo4j_uri,
|
||||||
"username": settings.neo4j_username,
|
"username": settings.neo4j_username,
|
||||||
"password": settings.neo4j_password
|
"password": settings.neo4j_password,
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
"reranker": {
|
"reranker": {
|
||||||
"provider": "cohere",
|
"provider": "cohere",
|
||||||
"config": {
|
"config": {
|
||||||
"api_key": settings.cohere_api_key,
|
"api_key": settings.cohere_api_key,
|
||||||
"model": "rerank-english-v3.0",
|
"model": "rerank-english-v3.0",
|
||||||
"top_n": 10
|
"top_n": 10,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
self.memory = Memory.from_config(config)
|
self.memory = Memory.from_config(config)
|
||||||
self.openai_client = OpenAI(
|
self.openai_client = OpenAI(
|
||||||
api_key=settings.openai_api_key,
|
api_key=settings.openai_api_key,
|
||||||
base_url=settings.openai_base_url
|
base_url=settings.openai_base_url,
|
||||||
|
timeout=60.0, # 60 second timeout for LLM calls
|
||||||
|
max_retries=2, # Retry failed requests up to 2 times
|
||||||
)
|
)
|
||||||
logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")
|
logger.info("Initialized ultra-minimal Mem0Manager with custom endpoint")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Pure passthrough methods - no custom logic
|
# Pure passthrough methods - no custom logic
|
||||||
|
@db_retry
|
||||||
@timed("add_memories")
|
@timed("add_memories")
|
||||||
async def add_memories(
|
async def add_memories(
|
||||||
self,
|
self,
|
||||||
|
|
@ -103,14 +135,14 @@ class Mem0Manager:
|
||||||
user_id: Optional[str] = "default",
|
user_id: Optional[str] = "default",
|
||||||
agent_id: Optional[str] = None,
|
agent_id: Optional[str] = None,
|
||||||
run_id: Optional[str] = None,
|
run_id: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Add memories - simplified native Mem0 pattern (10 lines vs 45)."""
|
"""Add memories - simplified native Mem0 pattern (10 lines vs 45)."""
|
||||||
try:
|
try:
|
||||||
# Convert ChatMessage objects to dict if needed
|
# Convert ChatMessage objects to dict if needed
|
||||||
formatted_messages = []
|
formatted_messages = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if hasattr(msg, 'dict'):
|
if hasattr(msg, "dict"):
|
||||||
formatted_messages.append(msg.dict())
|
formatted_messages.append(msg.dict())
|
||||||
else:
|
else:
|
||||||
formatted_messages.append(msg)
|
formatted_messages.append(msg)
|
||||||
|
|
@ -123,26 +155,35 @@ class Mem0Manager:
|
||||||
"timestamp": datetime.now().isoformat(),
|
"timestamp": datetime.now().isoformat(),
|
||||||
"source": "chat_conversation",
|
"source": "chat_conversation",
|
||||||
"message_count": len(formatted_messages),
|
"message_count": len(formatted_messages),
|
||||||
"auto_generated": True
|
"auto_generated": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Merge user metadata with auto metadata (user metadata takes precedence)
|
# Merge user metadata with auto metadata (user metadata takes precedence)
|
||||||
enhanced_metadata = {**auto_metadata, **combined_metadata}
|
enhanced_metadata = {**auto_metadata, **combined_metadata}
|
||||||
|
|
||||||
# Direct Mem0 add with enhanced metadata
|
# Direct Mem0 add with enhanced metadata
|
||||||
result = self.memory.add(formatted_messages, user_id=user_id,
|
result = self.memory.add(
|
||||||
agent_id=agent_id, run_id=run_id,
|
formatted_messages,
|
||||||
metadata=enhanced_metadata)
|
user_id=user_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
run_id=run_id,
|
||||||
|
metadata=enhanced_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"added_memories": result if isinstance(result, list) else [result],
|
"added_memories": result if isinstance(result, list) else [result],
|
||||||
"message": "Memories added successfully",
|
"message": "Memories added successfully",
|
||||||
"hierarchy": {"user_id": user_id, "agent_id": agent_id, "run_id": run_id}
|
"hierarchy": {
|
||||||
|
"user_id": user_id,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"run_id": run_id,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error adding memories: {e}")
|
logger.error(f"Error adding memories: {e}")
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
|
@db_retry
|
||||||
@timed("search_memories")
|
@timed("search_memories")
|
||||||
async def search_memories(
|
async def search_memories(
|
||||||
self,
|
self,
|
||||||
|
|
@ -155,37 +196,79 @@ class Mem0Manager:
|
||||||
# rerank: bool = False,
|
# rerank: bool = False,
|
||||||
# filter_memories: bool = False,
|
# filter_memories: bool = False,
|
||||||
agent_id: Optional[str] = None,
|
agent_id: Optional[str] = None,
|
||||||
run_id: Optional[str] = None
|
run_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Search memories - native Mem0 pattern"""
|
"""Search memories - native Mem0 pattern"""
|
||||||
try:
|
try:
|
||||||
# Minimal empty query protection for API compatibility
|
# Minimal empty query protection for API compatibility
|
||||||
if not query or query.strip() == "":
|
if not query or query.strip() == "":
|
||||||
return {"memories": [], "total_count": 0, "query": query, "note": "Empty query provided, no results returned. Use a specific query to search memories."}
|
return {
|
||||||
|
"memories": [],
|
||||||
|
"total_count": 0,
|
||||||
|
"query": query,
|
||||||
|
"note": "Empty query provided, no results returned. Use a specific query to search memories.",
|
||||||
|
}
|
||||||
# Direct Mem0 search - trust native handling
|
# Direct Mem0 search - trust native handling
|
||||||
result = self.memory.search(query=query, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit, threshold=threshold, filters=filters)
|
result = self.memory.search(
|
||||||
return {"memories": result.get("results", []), "total_count": len(result.get("results", [])), "query": query}
|
query=query,
|
||||||
|
user_id=user_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
run_id=run_id,
|
||||||
|
limit=limit,
|
||||||
|
threshold=threshold,
|
||||||
|
filters=filters,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"memories": result.get("results", []),
|
||||||
|
"total_count": len(result.get("results", [])),
|
||||||
|
"query": query,
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error searching memories: {e}")
|
logger.error(f"Error searching memories: {e}")
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
|
@db_retry
|
||||||
async def get_user_memories(
|
async def get_user_memories(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
agent_id: Optional[str] = None,
|
agent_id: Optional[str] = None,
|
||||||
run_id: Optional[str] = None,
|
run_id: Optional[str] = None,
|
||||||
filters: Optional[Dict[str, Any]] = None
|
filters: Optional[Dict[str, Any]] = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Get all memories for a user - native Mem0 pattern."""
|
"""Get all memories for a user - native Mem0 pattern."""
|
||||||
try:
|
try:
|
||||||
# Direct Mem0 get_all call - trust native parameter handling
|
# Direct Mem0 get_all call - trust native parameter handling
|
||||||
result = self.memory.get_all(user_id=user_id, limit=limit, agent_id=agent_id, run_id=run_id, filters=filters)
|
result = self.memory.get_all(
|
||||||
|
user_id=user_id,
|
||||||
|
limit=limit,
|
||||||
|
agent_id=agent_id,
|
||||||
|
run_id=run_id,
|
||||||
|
filters=filters,
|
||||||
|
)
|
||||||
return result.get("results", [])
|
return result.get("results", [])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting user memories: {e}")
|
logger.error(f"Error getting user memories: {e}")
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
|
@db_retry
async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a single memory by its ID.

    Args:
        memory_id: Identifier passed to ``self.memory.get``.

    Returns:
        Whatever ``self.memory.get`` returns for ``memory_id``, or ``None``
        when the lookup raises any exception. The exception is logged at
        debug level and not propagated.
    """
    # NOTE(review): because every exception is caught here, the @db_retry
    # wrapper never sees a failure from this method - confirm that is intended.
    try:
        return self.memory.get(memory_id=memory_id)
    except Exception as e:
        logger.debug(f"Memory {memory_id} not found or error: {e}")
    return None
|
||||||
|
|
||||||
|
async def verify_memory_ownership(self, memory_id: str, user_id: str) -> bool:
    """Report whether the memory ``memory_id`` belongs to ``user_id``.

    Performs one lookup by ID through ``get_memory`` instead of scanning
    the user's memories.

    Args:
        memory_id: Identifier of the memory to check.
        user_id: User expected to own the memory.

    Returns:
        ``True`` only when the memory was found and its ``"user_id"`` field
        equals ``user_id``. ``False`` when ``get_memory`` returned ``None``
        or the field is absent or different.
    """
    record = await self.get_memory(memory_id)
    return record is not None and record.get("user_id") == user_id
|
||||||
|
|
||||||
|
@db_retry
|
||||||
@timed("update_memory")
|
@timed("update_memory")
|
||||||
async def update_memory(
|
async def update_memory(
|
||||||
self,
|
self,
|
||||||
|
|
@ -194,15 +277,13 @@ class Mem0Manager:
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Update memory - pure Mem0 passthrough."""
|
"""Update memory - pure Mem0 passthrough."""
|
||||||
try:
|
try:
|
||||||
result = self.memory.update(
|
result = self.memory.update(memory_id=memory_id, data=content)
|
||||||
memory_id=memory_id,
|
|
||||||
data=content
|
|
||||||
)
|
|
||||||
return {"message": "Memory updated successfully", "result": result}
|
return {"message": "Memory updated successfully", "result": result}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error updating memory: {e}")
|
logger.error(f"Error updating memory: {e}")
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
|
@db_retry
|
||||||
@timed("delete_memory")
|
@timed("delete_memory")
|
||||||
async def delete_memory(self, memory_id: str) -> Dict[str, Any]:
|
async def delete_memory(self, memory_id: str) -> Dict[str, Any]:
|
||||||
"""Delete memory - pure Mem0 passthrough."""
|
"""Delete memory - pure Mem0 passthrough."""
|
||||||
|
|
@ -211,7 +292,7 @@ class Mem0Manager:
|
||||||
return {"message": "Memory deleted successfully"}
|
return {"message": "Memory deleted successfully"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting memory: {e}")
|
logger.error(f"Error deleting memory: {e}")
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
async def delete_user_memories(self, user_id: Optional[str]) -> Dict[str, Any]:
|
async def delete_user_memories(self, user_id: Optional[str]) -> Dict[str, Any]:
|
||||||
"""Delete all user memories - pure Mem0 passthrough."""
|
"""Delete all user memories - pure Mem0 passthrough."""
|
||||||
|
|
@ -220,7 +301,7 @@ class Mem0Manager:
|
||||||
return {"message": "All user memories deleted successfully"}
|
return {"message": "All user memories deleted successfully"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting user memories: {e}")
|
logger.error(f"Error deleting user memories: {e}")
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
async def get_memory_history(self, memory_id: str) -> Dict[str, Any]:
|
async def get_memory_history(self, memory_id: str) -> Dict[str, Any]:
|
||||||
"""Get memory change history - pure Mem0 passthrough."""
|
"""Get memory change history - pure Mem0 passthrough."""
|
||||||
|
|
@ -229,22 +310,24 @@ class Mem0Manager:
|
||||||
return {
|
return {
|
||||||
"memory_id": memory_id,
|
"memory_id": memory_id,
|
||||||
"history": history,
|
"history": history,
|
||||||
"message": "Memory history retrieved successfully"
|
"message": "Memory history retrieved successfully",
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting memory history: {e}")
|
logger.error(f"Error getting memory history: {e}")
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
|
async def get_graph_relationships(
|
||||||
async def get_graph_relationships(self, user_id: Optional[str], agent_id: Optional[str], run_id: Optional[str], limit: int = 50) -> Dict[str, Any]:
|
self,
|
||||||
|
user_id: Optional[str],
|
||||||
|
agent_id: Optional[str],
|
||||||
|
run_id: Optional[str],
|
||||||
|
limit: int = 50,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Get graph relationships - using correct Mem0 get_all() method."""
|
"""Get graph relationships - using correct Mem0 get_all() method."""
|
||||||
try:
|
try:
|
||||||
# Use get_all() to retrieve memories with graph relationships
|
# Use get_all() to retrieve memories with graph relationships
|
||||||
result = self.memory.get_all(
|
result = self.memory.get_all(
|
||||||
user_id=user_id,
|
user_id=user_id, agent_id=agent_id, run_id=run_id, limit=limit
|
||||||
agent_id=agent_id,
|
|
||||||
run_id=run_id,
|
|
||||||
limit=limit
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract relationships from Mem0's response structure
|
# Extract relationships from Mem0's response structure
|
||||||
|
|
@ -272,7 +355,7 @@ class Mem0Manager:
|
||||||
"agent_id": agent_id,
|
"agent_id": agent_id,
|
||||||
"run_id": run_id,
|
"run_id": run_id,
|
||||||
"total_memories": len(result.get("results", [])),
|
"total_memories": len(result.get("results", [])),
|
||||||
"total_relationships": len(relationships)
|
"total_relationships": len(relationships),
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -286,7 +369,7 @@ class Mem0Manager:
|
||||||
"run_id": run_id,
|
"run_id": run_id,
|
||||||
"total_memories": 0,
|
"total_memories": 0,
|
||||||
"total_relationships": 0,
|
"total_relationships": 0,
|
||||||
"error": str(e)
|
"error": str(e),
|
||||||
}
|
}
|
||||||
|
|
||||||
@timed("chat_with_memory")
|
@timed("chat_with_memory")
|
||||||
|
|
@ -304,49 +387,70 @@ class Mem0Manager:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
total_start_time = time.time()
|
total_start_time = time.time()
|
||||||
print(f"\n🚀 Starting chat request for user: {user_id}")
|
logger.info("Starting chat request", user_id=user_id)
|
||||||
|
|
||||||
# Stage 1: Memory Search
|
|
||||||
search_start_time = time.time()
|
search_start_time = time.time()
|
||||||
search_result = self.memory.search(query=message, user_id=user_id, agent_id=agent_id, run_id=run_id, limit=10, threshold=0.3)
|
search_result = self.memory.search(
|
||||||
|
query=message,
|
||||||
|
user_id=user_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
run_id=run_id,
|
||||||
|
limit=10,
|
||||||
|
threshold=0.3,
|
||||||
|
)
|
||||||
relevant_memories = search_result.get("results", [])
|
relevant_memories = search_result.get("results", [])
|
||||||
memories_str = "\n".join(f"- {entry['memory']}" for entry in relevant_memories)
|
memories_str = "\n".join(
|
||||||
|
f"- {entry['memory']}" for entry in relevant_memories
|
||||||
|
)
|
||||||
search_time = time.time() - search_start_time
|
search_time = time.time() - search_start_time
|
||||||
print(f"🔍 Memory search took: {search_time:.2f}s (found {len(relevant_memories)} memories)")
|
logger.debug(
|
||||||
|
"Memory search completed",
|
||||||
|
search_time_s=round(search_time, 2),
|
||||||
|
memories_found=len(relevant_memories),
|
||||||
|
)
|
||||||
|
|
||||||
# Stage 2: Prepare LLM messages
|
|
||||||
prep_start_time = time.time()
|
prep_start_time = time.time()
|
||||||
system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}"
|
system_prompt = f"You are a helpful AI. Answer the question based on query and memories.\nUser Memories:\n{memories_str}"
|
||||||
messages = [{"role": "system", "content": system_prompt}]
|
messages = [{"role": "system", "content": system_prompt}]
|
||||||
|
|
||||||
# Add conversation context if provided (last 50 messages)
|
|
||||||
if context:
|
if context:
|
||||||
messages.extend(context)
|
messages.extend(context)
|
||||||
print(f"📝 Added {len(context)} context messages")
|
logger.debug("Added context messages", context_count=len(context))
|
||||||
|
|
||||||
# Add current user message
|
|
||||||
messages.append({"role": "user", "content": message})
|
messages.append({"role": "user", "content": message})
|
||||||
prep_time = time.time() - prep_start_time
|
prep_time = time.time() - prep_start_time
|
||||||
print(f"📋 Message preparation took: {prep_time:.3f}s")
|
|
||||||
|
|
||||||
# Stage 3: LLM Call
|
|
||||||
llm_start_time = time.time()
|
llm_start_time = time.time()
|
||||||
response = self.openai_client.chat.completions.create(model=settings.default_model, messages=messages)
|
response = self.openai_client.chat.completions.create(
|
||||||
|
model=settings.default_model, messages=messages
|
||||||
|
)
|
||||||
assistant_response = response.choices[0].message.content
|
assistant_response = response.choices[0].message.content
|
||||||
llm_time = time.time() - llm_start_time
|
llm_time = time.time() - llm_start_time
|
||||||
print(f"🤖 LLM call took: {llm_time:.2f}s (model: {settings.default_model})")
|
logger.debug(
|
||||||
|
"LLM call completed",
|
||||||
|
llm_time_s=round(llm_time, 2),
|
||||||
|
model=settings.default_model,
|
||||||
|
)
|
||||||
|
|
||||||
# Stage 4: Memory Add
|
|
||||||
add_start_time = time.time()
|
add_start_time = time.time()
|
||||||
memory_messages = [{"role": "user", "content": message}, {"role": "assistant", "content": assistant_response}]
|
memory_messages = [
|
||||||
|
{"role": "user", "content": message},
|
||||||
|
{"role": "assistant", "content": assistant_response},
|
||||||
|
]
|
||||||
self.memory.add(memory_messages, user_id=user_id)
|
self.memory.add(memory_messages, user_id=user_id)
|
||||||
add_time = time.time() - add_start_time
|
add_time = time.time() - add_start_time
|
||||||
print(f"💾 Memory add took: {add_time:.2f}s")
|
|
||||||
|
|
||||||
# Total timing summary
|
|
||||||
total_time = time.time() - total_start_time
|
total_time = time.time() - total_start_time
|
||||||
print(f"⏱️ TOTAL: {total_time:.2f}s | Search: {search_time:.2f}s | LLM: {llm_time:.2f}s | Add: {add_time:.2f}s | Prep: {prep_time:.3f}s")
|
logger.info(
|
||||||
print(f"📊 Breakdown: Search {(search_time/total_time)*100:.1f}% | LLM {(llm_time/total_time)*100:.1f}% | Add {(add_time/total_time)*100:.1f}%\n")
|
"Chat request completed",
|
||||||
|
user_id=user_id,
|
||||||
|
total_time_s=round(total_time, 2),
|
||||||
|
search_time_s=round(search_time, 2),
|
||||||
|
llm_time_s=round(llm_time, 2),
|
||||||
|
add_time_s=round(add_time, 2),
|
||||||
|
memories_used=len(relevant_memories),
|
||||||
|
model=settings.default_model,
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"response": assistant_response,
|
"response": assistant_response,
|
||||||
|
|
@ -356,17 +460,22 @@ class Mem0Manager:
|
||||||
"total": round(total_time, 2),
|
"total": round(total_time, 2),
|
||||||
"search": round(search_time, 2),
|
"search": round(search_time, 2),
|
||||||
"llm": round(llm_time, 2),
|
"llm": round(llm_time, 2),
|
||||||
"add": round(add_time, 2)
|
"add": round(add_time, 2),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in chat_with_memory: {e}")
|
logger.error(
|
||||||
|
"Error in chat_with_memory",
|
||||||
|
error=str(e),
|
||||||
|
user_id=user_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
"response": "I apologize, but I encountered an error processing your request.",
|
"response": "I apologize, but I encountered an error processing your request.",
|
||||||
"memories_used": 0,
|
"memories_used": 0,
|
||||||
"model_used": None
|
"model_used": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def health_check(self) -> Dict[str, str]:
|
async def health_check(self) -> Dict[str, str]:
|
||||||
|
|
|
||||||
|
|
@ -2,53 +2,115 @@
|
||||||
|
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
# Constants for input validation
|
||||||
|
MAX_MESSAGE_LENGTH = 50000 # ~12k tokens max per message
|
||||||
|
MAX_QUERY_LENGTH = 10000 # ~2.5k tokens max per query
|
||||||
|
MAX_USER_ID_LENGTH = 100 # Reasonable user ID length
|
||||||
|
MAX_MEMORY_ID_LENGTH = 100 # Memory IDs are typically UUIDs
|
||||||
|
MAX_CONTEXT_MESSAGES = 100 # Max conversation context messages
|
||||||
|
USER_ID_PATTERN = r"^[a-zA-Z0-9_\-\.@]+$" # Alphanumeric with common separators
|
||||||
|
|
||||||
|
|
||||||
# Request Models
|
# Request Models
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
"""Chat message structure."""
|
"""Chat message structure."""
|
||||||
role: str = Field(..., description="Message role (user, assistant, system)")
|
|
||||||
content: str = Field(..., description="Message content")
|
role: str = Field(
|
||||||
|
..., max_length=20, description="Message role (user, assistant, system)"
|
||||||
|
)
|
||||||
|
content: str = Field(
|
||||||
|
..., max_length=MAX_MESSAGE_LENGTH, description="Message content"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
class ChatRequest(BaseModel):
|
||||||
"""Ultra-minimal chat request."""
|
"""Ultra-minimal chat request."""
|
||||||
message: str = Field(..., description="User message")
|
|
||||||
user_id: Optional[str] = Field("default", description="User identifier")
|
message: str = Field(..., max_length=MAX_MESSAGE_LENGTH, description="User message")
|
||||||
agent_id: Optional[str] = Field(None, description="Agent identifier")
|
user_id: Optional[str] = Field(
|
||||||
run_id: Optional[str] = Field(None, description="Run identifier")
|
"default",
|
||||||
context: Optional[List[ChatMessage]] = Field(None, description="Previous conversation context")
|
max_length=MAX_USER_ID_LENGTH,
|
||||||
|
pattern=USER_ID_PATTERN,
|
||||||
|
description="User identifier (alphanumeric, _, -, ., @)",
|
||||||
|
)
|
||||||
|
agent_id: Optional[str] = Field(
|
||||||
|
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
|
||||||
|
)
|
||||||
|
run_id: Optional[str] = Field(
|
||||||
|
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
|
||||||
|
)
|
||||||
|
context: Optional[List[ChatMessage]] = Field(
|
||||||
|
None,
|
||||||
|
max_length=MAX_CONTEXT_MESSAGES,
|
||||||
|
description="Previous conversation context (max 100 messages)",
|
||||||
|
)
|
||||||
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
||||||
|
|
||||||
|
|
||||||
class MemoryAddRequest(BaseModel):
|
class MemoryAddRequest(BaseModel):
|
||||||
"""Request to add memories with hierarchy support - open-source compatible."""
|
"""Request to add memories with hierarchy support - open-source compatible."""
|
||||||
messages: List[ChatMessage] = Field(..., description="Messages to process")
|
|
||||||
user_id: Optional[str] = Field("default", description="User identifier")
|
messages: List[ChatMessage] = Field(
|
||||||
agent_id: Optional[str] = Field(None, description="Agent identifier")
|
...,
|
||||||
run_id: Optional[str] = Field(None, description="Run identifier")
|
max_length=MAX_CONTEXT_MESSAGES,
|
||||||
|
description="Messages to process (max 100 messages)",
|
||||||
|
)
|
||||||
|
user_id: Optional[str] = Field(
|
||||||
|
"default",
|
||||||
|
max_length=MAX_USER_ID_LENGTH,
|
||||||
|
pattern=USER_ID_PATTERN,
|
||||||
|
description="User identifier",
|
||||||
|
)
|
||||||
|
agent_id: Optional[str] = Field(
|
||||||
|
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
|
||||||
|
)
|
||||||
|
run_id: Optional[str] = Field(
|
||||||
|
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
|
||||||
|
)
|
||||||
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
||||||
|
|
||||||
|
|
||||||
class MemorySearchRequest(BaseModel):
|
class MemorySearchRequest(BaseModel):
|
||||||
"""Request to search memories with hierarchy filtering."""
|
"""Request to search memories with hierarchy filtering."""
|
||||||
query: str = Field(..., description="Search query")
|
|
||||||
user_id: Optional[str] = Field("default", description="User identifier")
|
|
||||||
agent_id: Optional[str] = Field(None, description="Agent identifier")
|
|
||||||
run_id: Optional[str] = Field(None, description="Run identifier")
|
|
||||||
limit: int = Field(5, description="Maximum number of results")
|
|
||||||
threshold: Optional[float] = Field(None, description="Minimum relevance score")
|
|
||||||
filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters")
|
|
||||||
|
|
||||||
# Hierarchy filters (open-source compatible)
|
query: str = Field(..., max_length=MAX_QUERY_LENGTH, description="Search query")
|
||||||
agent_id: Optional[str] = Field(None, description="Filter by agent identifier")
|
user_id: Optional[str] = Field(
|
||||||
run_id: Optional[str] = Field(None, description="Filter by run identifier")
|
"default",
|
||||||
|
max_length=MAX_USER_ID_LENGTH,
|
||||||
|
pattern=USER_ID_PATTERN,
|
||||||
|
description="User identifier",
|
||||||
|
)
|
||||||
|
agent_id: Optional[str] = Field(
|
||||||
|
None, max_length=MAX_USER_ID_LENGTH, description="Agent identifier"
|
||||||
|
)
|
||||||
|
run_id: Optional[str] = Field(
|
||||||
|
None, max_length=MAX_USER_ID_LENGTH, description="Run identifier"
|
||||||
|
)
|
||||||
|
limit: int = Field(5, ge=1, le=100, description="Maximum number of results (1-100)")
|
||||||
|
threshold: Optional[float] = Field(
|
||||||
|
None, ge=0.0, le=1.0, description="Minimum relevance score (0-1)"
|
||||||
|
)
|
||||||
|
filters: Optional[Dict[str, Any]] = Field(None, description="Additional filters")
|
||||||
|
|
||||||
|
|
||||||
class MemoryUpdateRequest(BaseModel):
|
class MemoryUpdateRequest(BaseModel):
|
||||||
"""Request to update a memory."""
|
"""Request to update a memory."""
|
||||||
memory_id: str = Field(..., description="Memory ID to update")
|
|
||||||
content: str = Field(..., description="New memory content")
|
memory_id: str = Field(
|
||||||
|
..., max_length=MAX_MEMORY_ID_LENGTH, description="Memory ID to update"
|
||||||
|
)
|
||||||
|
user_id: str = Field(
|
||||||
|
...,
|
||||||
|
max_length=MAX_USER_ID_LENGTH,
|
||||||
|
pattern=USER_ID_PATTERN,
|
||||||
|
description="User identifier for ownership verification",
|
||||||
|
)
|
||||||
|
content: str = Field(
|
||||||
|
..., max_length=MAX_MESSAGE_LENGTH, description="New memory content"
|
||||||
|
)
|
||||||
metadata: Optional[Dict[str, Any]] = Field(None, description="Updated metadata")
|
metadata: Optional[Dict[str, Any]] = Field(None, description="Updated metadata")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -57,19 +119,23 @@ class MemoryUpdateRequest(BaseModel):
|
||||||
|
|
||||||
class MemoryItem(BaseModel):
|
class MemoryItem(BaseModel):
|
||||||
"""Individual memory item."""
|
"""Individual memory item."""
|
||||||
|
|
||||||
id: str = Field(..., description="Memory unique identifier")
|
id: str = Field(..., description="Memory unique identifier")
|
||||||
memory: str = Field(..., description="Memory content")
|
memory: str = Field(..., description="Memory content")
|
||||||
user_id: Optional[str] = Field(None, description="Associated user ID")
|
user_id: Optional[str] = Field(None, description="Associated user ID")
|
||||||
agent_id: Optional[str] = Field(None, description="Associated agent ID")
|
agent_id: Optional[str] = Field(None, description="Associated agent ID")
|
||||||
run_id: Optional[str] = Field(None, description="Associated run ID")
|
run_id: Optional[str] = Field(None, description="Associated run ID")
|
||||||
metadata: Optional[Dict[str, Any]] = Field(None, description="Memory metadata")
|
metadata: Optional[Dict[str, Any]] = Field(None, description="Memory metadata")
|
||||||
score: Optional[float] = Field(None, description="Relevance score (for search results)")
|
score: Optional[float] = Field(
|
||||||
|
None, description="Relevance score (for search results)"
|
||||||
|
)
|
||||||
created_at: Optional[str] = Field(None, description="Creation timestamp")
|
created_at: Optional[str] = Field(None, description="Creation timestamp")
|
||||||
updated_at: Optional[str] = Field(None, description="Last update timestamp")
|
updated_at: Optional[str] = Field(None, description="Last update timestamp")
|
||||||
|
|
||||||
|
|
||||||
class MemorySearchResponse(BaseModel):
|
class MemorySearchResponse(BaseModel):
|
||||||
"""Memory search results - pure Mem0 structure."""
|
"""Memory search results - pure Mem0 structure."""
|
||||||
|
|
||||||
memories: List[MemoryItem] = Field(..., description="Found memories")
|
memories: List[MemoryItem] = Field(..., description="Found memories")
|
||||||
total_count: int = Field(..., description="Total number of memories found")
|
total_count: int = Field(..., description="Total number of memories found")
|
||||||
query: str = Field(..., description="Original search query")
|
query: str = Field(..., description="Original search query")
|
||||||
|
|
@ -77,27 +143,37 @@ class MemorySearchResponse(BaseModel):
|
||||||
|
|
||||||
class MemoryAddResponse(BaseModel):
|
class MemoryAddResponse(BaseModel):
|
||||||
"""Response from adding memories - pure Mem0 structure."""
|
"""Response from adding memories - pure Mem0 structure."""
|
||||||
added_memories: List[Dict[str, Any]] = Field(..., description="Memories that were added")
|
|
||||||
|
added_memories: List[Dict[str, Any]] = Field(
|
||||||
|
..., description="Memories that were added"
|
||||||
|
)
|
||||||
message: str = Field(..., description="Success message")
|
message: str = Field(..., description="Success message")
|
||||||
|
|
||||||
|
|
||||||
class GraphRelationship(BaseModel):
|
class GraphRelationship(BaseModel):
|
||||||
"""Graph relationship structure."""
|
"""Graph relationship structure."""
|
||||||
|
|
||||||
source: str = Field(..., description="Source entity")
|
source: str = Field(..., description="Source entity")
|
||||||
relationship: str = Field(..., description="Relationship type")
|
relationship: str = Field(..., description="Relationship type")
|
||||||
target: str = Field(..., description="Target entity")
|
target: str = Field(..., description="Target entity")
|
||||||
properties: Optional[Dict[str, Any]] = Field(None, description="Relationship properties")
|
properties: Optional[Dict[str, Any]] = Field(
|
||||||
|
None, description="Relationship properties"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GraphResponse(BaseModel):
|
class GraphResponse(BaseModel):
|
||||||
"""Graph relationships - pure Mem0 structure."""
|
"""Graph relationships - pure Mem0 structure."""
|
||||||
relationships: List[GraphRelationship] = Field(..., description="Found relationships")
|
|
||||||
|
relationships: List[GraphRelationship] = Field(
|
||||||
|
..., description="Found relationships"
|
||||||
|
)
|
||||||
entities: List[str] = Field(..., description="Unique entities")
|
entities: List[str] = Field(..., description="Unique entities")
|
||||||
user_id: str = Field(..., description="User identifier")
|
user_id: str = Field(..., description="User identifier")
|
||||||
|
|
||||||
|
|
||||||
class HealthResponse(BaseModel):
|
class HealthResponse(BaseModel):
|
||||||
"""Health check response."""
|
"""Health check response."""
|
||||||
|
|
||||||
status: str = Field(..., description="Service status")
|
status: str = Field(..., description="Service status")
|
||||||
services: Dict[str, str] = Field(..., description="Individual service statuses")
|
services: Dict[str, str] = Field(..., description="Individual service statuses")
|
||||||
timestamp: str = Field(..., description="Health check timestamp")
|
timestamp: str = Field(..., description="Health check timestamp")
|
||||||
|
|
@ -105,6 +181,7 @@ class HealthResponse(BaseModel):
|
||||||
|
|
||||||
class ErrorResponse(BaseModel):
|
class ErrorResponse(BaseModel):
|
||||||
"""Error response structure."""
|
"""Error response structure."""
|
||||||
|
|
||||||
error: str = Field(..., description="Error message")
|
error: str = Field(..., description="Error message")
|
||||||
detail: Optional[str] = Field(None, description="Detailed error information")
|
detail: Optional[str] = Field(None, description="Detailed error information")
|
||||||
status_code: int = Field(..., description="HTTP status code")
|
status_code: int = Field(..., description="HTTP status code")
|
||||||
|
|
@ -112,8 +189,10 @@ class ErrorResponse(BaseModel):
|
||||||
|
|
||||||
# Statistics and Monitoring Models
|
# Statistics and Monitoring Models
|
||||||
|
|
||||||
|
|
||||||
class MemoryOperationStats(BaseModel):
|
class MemoryOperationStats(BaseModel):
|
||||||
"""Memory operation statistics."""
|
"""Memory operation statistics."""
|
||||||
|
|
||||||
add: int = Field(..., description="Number of add operations")
|
add: int = Field(..., description="Number of add operations")
|
||||||
search: int = Field(..., description="Number of search operations")
|
search: int = Field(..., description="Number of search operations")
|
||||||
update: int = Field(..., description="Number of update operations")
|
update: int = Field(..., description="Number of update operations")
|
||||||
|
|
@ -122,19 +201,111 @@ class MemoryOperationStats(BaseModel):
|
||||||
|
|
||||||
class GlobalStatsResponse(BaseModel):
|
class GlobalStatsResponse(BaseModel):
|
||||||
"""Global application statistics."""
|
"""Global application statistics."""
|
||||||
|
|
||||||
total_memories: int = Field(..., description="Total memories across all users")
|
total_memories: int = Field(..., description="Total memories across all users")
|
||||||
total_users: int = Field(..., description="Total number of users")
|
total_users: int = Field(..., description="Total number of users")
|
||||||
api_calls_today: int = Field(..., description="Total API calls today")
|
api_calls_today: int = Field(..., description="Total API calls today")
|
||||||
avg_response_time_ms: float = Field(..., description="Average response time in milliseconds")
|
avg_response_time_ms: float = Field(
|
||||||
memory_operations: MemoryOperationStats = Field(..., description="Memory operation breakdown")
|
..., description="Average response time in milliseconds"
|
||||||
|
)
|
||||||
|
memory_operations: MemoryOperationStats = Field(
|
||||||
|
..., description="Memory operation breakdown"
|
||||||
|
)
|
||||||
uptime_seconds: float = Field(..., description="Application uptime in seconds")
|
uptime_seconds: float = Field(..., description="Application uptime in seconds")
|
||||||
|
|
||||||
|
|
||||||
class UserStatsResponse(BaseModel):
|
class UserStatsResponse(BaseModel):
|
||||||
"""User-specific statistics."""
|
"""User-specific statistics."""
|
||||||
|
|
||||||
user_id: str = Field(..., description="User identifier")
|
user_id: str = Field(..., description="User identifier")
|
||||||
memory_count: int = Field(..., description="Number of memories for this user")
|
memory_count: int = Field(..., description="Number of memories for this user")
|
||||||
relationship_count: int = Field(..., description="Number of graph relationships for this user")
|
relationship_count: int = Field(
|
||||||
|
..., description="Number of graph relationships for this user"
|
||||||
|
)
|
||||||
last_activity: Optional[str] = Field(None, description="Last activity timestamp")
|
last_activity: Optional[str] = Field(None, description="Last activity timestamp")
|
||||||
api_calls_today: int = Field(..., description="API calls made by this user today")
|
api_calls_today: int = Field(..., description="API calls made by this user today")
|
||||||
avg_response_time_ms: float = Field(..., description="Average response time for this user's requests")
|
avg_response_time_ms: float = Field(
|
||||||
|
..., description="Average response time for this user's requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# OpenAI-Compatible API Models
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIMessage(BaseModel):
|
||||||
|
"""OpenAI message format."""
|
||||||
|
|
||||||
|
role: str = Field(..., description="Message role (system, user, assistant)")
|
||||||
|
content: str = Field(..., description="Message content")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIChatCompletionRequest(BaseModel):
|
||||||
|
"""OpenAI chat completion request format."""
|
||||||
|
|
||||||
|
model: str = Field(..., description="Model to use (will use configured default)")
|
||||||
|
messages: List[Dict[str, str]] = Field(..., description="List of messages")
|
||||||
|
temperature: Optional[float] = Field(0.7, description="Sampling temperature")
|
||||||
|
max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
|
||||||
|
stream: Optional[bool] = Field(False, description="Whether to stream responses")
|
||||||
|
top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
|
||||||
|
n: Optional[int] = Field(1, description="Number of completions to generate")
|
||||||
|
stop: Optional[List[str]] = Field(None, description="Stop sequences")
|
||||||
|
presence_penalty: Optional[float] = Field(0, description="Presence penalty")
|
||||||
|
frequency_penalty: Optional[float] = Field(0, description="Frequency penalty")
|
||||||
|
user: Optional[str] = Field(
|
||||||
|
None, description="User identifier (ignored, uses API key)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIUsage(BaseModel):
|
||||||
|
"""Token usage information."""
|
||||||
|
|
||||||
|
prompt_tokens: int = Field(..., description="Tokens in the prompt")
|
||||||
|
completion_tokens: int = Field(..., description="Tokens in the completion")
|
||||||
|
total_tokens: int = Field(..., description="Total tokens used")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIChoiceMessage(BaseModel):
|
||||||
|
"""Message in a choice."""
|
||||||
|
|
||||||
|
role: str = Field(..., description="Role of the message")
|
||||||
|
content: str = Field(..., description="Content of the message")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIChoice(BaseModel):
|
||||||
|
"""Individual completion choice."""
|
||||||
|
|
||||||
|
index: int = Field(..., description="Choice index")
|
||||||
|
message: OpenAIChoiceMessage = Field(..., description="Message content")
|
||||||
|
finish_reason: str = Field(..., description="Reason for completion finish")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIChatCompletionResponse(BaseModel):
|
||||||
|
"""OpenAI chat completion response format."""
|
||||||
|
|
||||||
|
id: str = Field(..., description="Unique completion ID")
|
||||||
|
object: str = Field(default="chat.completion", description="Object type")
|
||||||
|
created: int = Field(..., description="Unix timestamp of creation")
|
||||||
|
model: str = Field(..., description="Model used for completion")
|
||||||
|
choices: List[OpenAIChoice] = Field(..., description="List of completion choices")
|
||||||
|
usage: Optional[OpenAIUsage] = Field(None, description="Token usage information")
|
||||||
|
|
||||||
|
|
||||||
|
# Streaming-specific models
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIStreamDelta(BaseModel):
|
||||||
|
"""Delta content in a streaming chunk."""
|
||||||
|
|
||||||
|
role: Optional[str] = Field(None, description="Role (only in first chunk)")
|
||||||
|
content: Optional[str] = Field(None, description="Incremental content")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIStreamChoice(BaseModel):
|
||||||
|
"""Individual streaming choice."""
|
||||||
|
|
||||||
|
index: int = Field(..., description="Choice index")
|
||||||
|
delta: OpenAIStreamDelta = Field(..., description="Delta content")
|
||||||
|
finish_reason: Optional[str] = Field(
|
||||||
|
None, description="Reason for completion finish"
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ ollama
|
||||||
# Utilities
|
# Utilities
|
||||||
pydantic
|
pydantic
|
||||||
pydantic-settings
|
pydantic-settings
|
||||||
|
tenacity
|
||||||
python-dotenv
|
python-dotenv
|
||||||
httpx
|
httpx
|
||||||
aiofiles
|
aiofiles
|
||||||
|
|
@ -31,3 +32,9 @@ python-json-logger
|
||||||
# CORS and Security
|
# CORS and Security
|
||||||
python-jose[cryptography]
|
python-jose[cryptography]
|
||||||
passlib[bcrypt]
|
passlib[bcrypt]
|
||||||
|
|
||||||
|
# Rate Limiting
|
||||||
|
slowapi
|
||||||
|
|
||||||
|
# MCP Server
|
||||||
|
mcp[server]>=1.0.0
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ services:
|
||||||
container_name: mem0-qdrant
|
container_name: mem0-qdrant
|
||||||
expose:
|
expose:
|
||||||
- "6333"
|
- "6333"
|
||||||
|
networks:
|
||||||
|
- mem0_network
|
||||||
volumes:
|
volumes:
|
||||||
- qdrant_data:/qdrant/storage
|
- qdrant_data:/qdrant/storage
|
||||||
command: >
|
command: >
|
||||||
|
|
@ -32,6 +34,8 @@ services:
|
||||||
expose:
|
expose:
|
||||||
- "7474" # HTTP - Internal only
|
- "7474" # HTTP - Internal only
|
||||||
- "7687" # Bolt - Internal only
|
- "7687" # Bolt - Internal only
|
||||||
|
networks:
|
||||||
|
- mem0_network
|
||||||
volumes:
|
volumes:
|
||||||
- neo4j_data:/data
|
- neo4j_data:/data
|
||||||
- neo4j_logs:/logs
|
- neo4j_logs:/logs
|
||||||
|
|
@ -65,8 +69,14 @@ services:
|
||||||
CORS_ORIGINS: ${CORS_ORIGINS:-http://localhost:3000}
|
CORS_ORIGINS: ${CORS_ORIGINS:-http://localhost:3000}
|
||||||
DEFAULT_MODEL: ${DEFAULT_MODEL:-claude-sonnet-4}
|
DEFAULT_MODEL: ${DEFAULT_MODEL:-claude-sonnet-4}
|
||||||
API_KEYS: ${API_KEYS:-{}}
|
API_KEYS: ${API_KEYS:-{}}
|
||||||
ports:
|
OLLAMA_BASE_URL: ${OLLAMA_BASE_URL:-http://host.docker.internal:11434}
|
||||||
- "${BACKEND_PORT:-8000}:8000"
|
EMBEDDING_MODEL: ${EMBEDDING_MODEL:-nomic-embed-text}
|
||||||
|
EMBEDDING_DIMS: ${EMBEDDING_DIMS:-2560}
|
||||||
|
expose:
|
||||||
|
- "8000"
|
||||||
|
networks:
|
||||||
|
- npm_network
|
||||||
|
- mem0_network
|
||||||
depends_on:
|
depends_on:
|
||||||
qdrant:
|
qdrant:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
|
@ -75,7 +85,8 @@ services:
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
volumes:
|
volumes:
|
||||||
- ./backend:/app
|
- ./backend:/app
|
||||||
command: ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
- ./frontend:/app/frontend
|
||||||
|
command: ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
qdrant_data:
|
qdrant_data:
|
||||||
|
|
@ -85,5 +96,6 @@ volumes:
|
||||||
neo4j_plugins:
|
neo4j_plugins:
|
||||||
|
|
||||||
networks:
|
networks:
|
||||||
default:
|
mem0_network:
|
||||||
name: mem0-network
|
npm_network:
|
||||||
|
external: true
|
||||||
|
|
|
||||||
|
|
@ -18,12 +18,106 @@
|
||||||
display: flex;
|
display: flex;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Login Screen */
|
||||||
|
.login-screen {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
width: 100%;
|
||||||
|
height: 100vh;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-screen.hidden {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-box {
|
||||||
|
background: white;
|
||||||
|
padding: 40px;
|
||||||
|
border-radius: 12px;
|
||||||
|
box-shadow: 0 10px 40px rgba(0, 0, 0, 0.2);
|
||||||
|
width: 100%;
|
||||||
|
max-width: 400px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-box h1 {
|
||||||
|
margin-bottom: 10px;
|
||||||
|
color: #333;
|
||||||
|
font-size: 28px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-box p {
|
||||||
|
color: #666;
|
||||||
|
font-size: 14px;
|
||||||
|
text-align: center;
|
||||||
|
margin-bottom: 30px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-box input {
|
||||||
|
width: 100%;
|
||||||
|
padding: 14px;
|
||||||
|
border: 2px solid #e0e0e0;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 14px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
outline: none;
|
||||||
|
transition: border-color 0.3s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-box input:focus {
|
||||||
|
border-color: #667eea;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-box button {
|
||||||
|
width: 100%;
|
||||||
|
padding: 14px;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 16px;
|
||||||
|
font-weight: 600;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: transform 0.2s, opacity 0.3s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-box button:hover {
|
||||||
|
transform: translateY(-2px);
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-box button:disabled {
|
||||||
|
opacity: 0.6;
|
||||||
|
cursor: not-allowed;
|
||||||
|
transform: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-error {
|
||||||
|
background: #ffe6e6;
|
||||||
|
border: 1px solid #ffcccc;
|
||||||
|
color: #cc0000;
|
||||||
|
padding: 12px;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
font-size: 14px;
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.login-error.show {
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
|
||||||
.container {
|
.container {
|
||||||
display: flex;
|
display: flex;
|
||||||
width: 100%;
|
width: 100%;
|
||||||
height: 100vh;
|
height: 100vh;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.container.hidden {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
/* Chat Section */
|
/* Chat Section */
|
||||||
.chat-section {
|
.chat-section {
|
||||||
flex: 1;
|
flex: 1;
|
||||||
|
|
@ -58,7 +152,12 @@
|
||||||
font-size: 14px;
|
font-size: 14px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.clear-chat-btn {
|
.header-buttons {
|
||||||
|
display: flex;
|
||||||
|
gap: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.clear-chat-btn, .logout-btn {
|
||||||
background: #f8f9fa;
|
background: #f8f9fa;
|
||||||
color: #666;
|
color: #666;
|
||||||
border: 1px solid #e0e0e0;
|
border: 1px solid #e0e0e0;
|
||||||
|
|
@ -73,16 +172,27 @@
|
||||||
transition: all 0.2s ease;
|
transition: all 0.2s ease;
|
||||||
}
|
}
|
||||||
|
|
||||||
.clear-chat-btn:hover {
|
.clear-chat-btn:hover, .logout-btn:hover {
|
||||||
background: #e9ecef;
|
background: #e9ecef;
|
||||||
border-color: #ced4da;
|
border-color: #ced4da;
|
||||||
color: #495057;
|
color: #495057;
|
||||||
}
|
}
|
||||||
|
|
||||||
.clear-chat-btn:active {
|
.clear-chat-btn:active, .logout-btn:active {
|
||||||
background: #dee2e6;
|
background: #dee2e6;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.logout-btn {
|
||||||
|
background: #fff3cd;
|
||||||
|
border-color: #ffc107;
|
||||||
|
color: #856404;
|
||||||
|
}
|
||||||
|
|
||||||
|
.logout-btn:hover {
|
||||||
|
background: #ffe69c;
|
||||||
|
border-color: #ffb300;
|
||||||
|
}
|
||||||
|
|
||||||
.chat-messages {
|
.chat-messages {
|
||||||
flex: 1;
|
flex: 1;
|
||||||
overflow-y: auto;
|
overflow-y: auto;
|
||||||
|
|
@ -281,17 +391,42 @@
|
||||||
</style>
|
</style>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div class="container">
|
<!-- Login Screen -->
|
||||||
|
<div class="login-screen" id="loginScreen">
|
||||||
|
<div class="login-box">
|
||||||
|
<h1>🧠 Mem0 Chat</h1>
|
||||||
|
<p>Enter your API key to access your memory-powered assistant</p>
|
||||||
|
|
||||||
|
<div class="login-error" id="loginError"></div>
|
||||||
|
|
||||||
|
<input
|
||||||
|
type="password"
|
||||||
|
id="apiKeyInput"
|
||||||
|
placeholder="Enter your API key (e.g., sk-xxxxx)"
|
||||||
|
autocomplete="off"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<button id="loginButton">Connect</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Main Chat Interface (hidden initially) -->
|
||||||
|
<div class="container hidden" id="mainContainer">
|
||||||
<!-- Chat Section -->
|
<!-- Chat Section -->
|
||||||
<div class="chat-section">
|
<div class="chat-section">
|
||||||
<div class="chat-header">
|
<div class="chat-header">
|
||||||
<div class="chat-header-content">
|
<div class="chat-header-content">
|
||||||
<h1>What can I help you with?</h1>
|
<h1>What can I help you with?</h1>
|
||||||
<p>Chat with your memories - User: pratik</p>
|
<p>Chat with your memories - User: <span id="currentUser">...</span></p>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="header-buttons">
|
||||||
<button class="clear-chat-btn" id="clearChatBtn" title="Clear chat history">
|
<button class="clear-chat-btn" id="clearChatBtn" title="Clear chat history">
|
||||||
🗑️ Clear Chat
|
🗑️ Clear Chat
|
||||||
</button>
|
</button>
|
||||||
|
<button class="logout-btn" id="logoutBtn" title="Logout">
|
||||||
|
🚪 Logout
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="chat-messages" id="chatMessages">
|
<div class="chat-messages" id="chatMessages">
|
||||||
|
|
@ -319,10 +454,20 @@
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
// Configuration
|
// Configuration
|
||||||
const API_BASE = 'http://localhost:8000';
|
const API_BASE = window.location.origin;
|
||||||
const USER_ID = 'pratik';
|
|
||||||
|
// State
|
||||||
|
let API_KEY = null;
|
||||||
|
let USER_ID = null;
|
||||||
|
|
||||||
// DOM Elements
|
// DOM Elements
|
||||||
|
const loginScreen = document.getElementById('loginScreen');
|
||||||
|
const mainContainer = document.getElementById('mainContainer');
|
||||||
|
const apiKeyInput = document.getElementById('apiKeyInput');
|
||||||
|
const loginButton = document.getElementById('loginButton');
|
||||||
|
const loginError = document.getElementById('loginError');
|
||||||
|
const logoutBtn = document.getElementById('logoutBtn');
|
||||||
|
const currentUser = document.getElementById('currentUser');
|
||||||
const chatMessages = document.getElementById('chatMessages');
|
const chatMessages = document.getElementById('chatMessages');
|
||||||
const messageInput = document.getElementById('messageInput');
|
const messageInput = document.getElementById('messageInput');
|
||||||
const sendButton = document.getElementById('sendButton');
|
const sendButton = document.getElementById('sendButton');
|
||||||
|
|
@ -336,19 +481,143 @@
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
document.addEventListener('DOMContentLoaded', function() {
|
document.addEventListener('DOMContentLoaded', function() {
|
||||||
loadChatHistory();
|
// Check if already logged in
|
||||||
loadMemories();
|
const savedApiKey = localStorage.getItem('apiKey');
|
||||||
|
const savedUserId = localStorage.getItem('userId');
|
||||||
|
|
||||||
|
if (savedApiKey && savedUserId) {
|
||||||
|
// Auto-login with saved credentials
|
||||||
|
API_KEY = savedApiKey;
|
||||||
|
USER_ID = savedUserId;
|
||||||
|
showMainInterface();
|
||||||
|
}
|
||||||
|
|
||||||
// Event listeners
|
// Event listeners
|
||||||
|
loginButton.addEventListener('click', handleLogin);
|
||||||
|
apiKeyInput.addEventListener('keydown', (e) => {
|
||||||
|
if (e.key === 'Enter') handleLogin();
|
||||||
|
});
|
||||||
|
logoutBtn.addEventListener('click', handleLogout);
|
||||||
sendButton.addEventListener('click', sendMessage);
|
sendButton.addEventListener('click', sendMessage);
|
||||||
messageInput.addEventListener('keydown', handleKeyDown);
|
messageInput.addEventListener('keydown', handleKeyDown);
|
||||||
messageInput.addEventListener('input', autoResizeTextarea);
|
messageInput.addEventListener('input', autoResizeTextarea);
|
||||||
refreshButton.addEventListener('click', loadMemories);
|
refreshButton.addEventListener('click', loadMemories);
|
||||||
clearChatBtn.addEventListener('click', clearChatWithConfirmation);
|
clearChatBtn.addEventListener('click', clearChatWithConfirmation);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Handle login
|
||||||
|
async function handleLogin() {
|
||||||
|
const apiKey = apiKeyInput.value.trim();
|
||||||
|
|
||||||
|
if (!apiKey) {
|
||||||
|
showLoginError('Please enter an API key');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
loginButton.disabled = true;
|
||||||
|
loginButton.textContent = 'Verifying...';
|
||||||
|
hideLoginError();
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Verify API key by calling /health with auth
|
||||||
|
const response = await fetch(`${API_BASE}/health`, {
|
||||||
|
headers: {
|
||||||
|
'X-API-Key': apiKey
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error('Invalid API key');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Confirm the key is accepted by an authenticated endpoint (only response.ok is checked; the body is not read)
|
||||||
|
// This sends a minimal request to /v1/chat/completions with the X-API-Key header
|
||||||
|
const userResponse = await fetch(`${API_BASE}/v1/chat/completions`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'X-API-Key': apiKey
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: [{ role: 'user', content: 'test' }],
|
||||||
|
stream: false
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!userResponse.ok) {
|
||||||
|
throw new Error('Failed to verify user');
|
||||||
|
}
|
||||||
|
|
||||||
|
// API key accepted; USER_ID is derived below from the key text itself, not from a server response
|
||||||
|
// We'll store both and use a simple username extraction
|
||||||
|
API_KEY = apiKey;
|
||||||
|
|
||||||
|
// Try to extract username from API key (e.g., sk-alice -> alice)
|
||||||
|
if (apiKey.startsWith('sk-')) {
|
||||||
|
const parts = apiKey.substring(3).split('-');
|
||||||
|
USER_ID = parts[0]; // Get first part after sk-
|
||||||
|
} else {
|
||||||
|
USER_ID = 'user';
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save to localStorage
|
||||||
|
localStorage.setItem('apiKey', API_KEY);
|
||||||
|
localStorage.setItem('userId', USER_ID);
|
||||||
|
|
||||||
|
// Show main interface
|
||||||
|
showMainInterface();
|
||||||
|
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Login error:', error);
|
||||||
|
showLoginError('Invalid API key. Please check and try again.');
|
||||||
|
loginButton.disabled = false;
|
||||||
|
loginButton.textContent = 'Connect';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show main interface
|
||||||
|
function showMainInterface() {
|
||||||
|
loginScreen.classList.add('hidden');
|
||||||
|
mainContainer.classList.remove('hidden');
|
||||||
|
currentUser.textContent = USER_ID;
|
||||||
|
|
||||||
|
// Load data
|
||||||
|
loadChatHistory();
|
||||||
|
loadMemories();
|
||||||
|
|
||||||
// Initialize textarea height
|
// Initialize textarea height
|
||||||
autoResizeTextarea();
|
autoResizeTextarea();
|
||||||
});
|
messageInput.focus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle logout
|
||||||
|
/**
 * Log the current user out after an explicit confirmation.
 *
 * On confirmation this removes the stored credentials from localStorage,
 * resets the in-memory API_KEY / USER_ID, and swaps the main container for
 * the login screen with an empty key field and no visible error.
 * If the user cancels the dialog, nothing changes.
 */
function handleLogout() {
    const confirmed = confirm('Are you sure you want to logout?');
    if (!confirmed) {
        return;
    }

    // Forget both the persisted and the in-memory credentials.
    for (const storageKey of ['apiKey', 'userId']) {
        localStorage.removeItem(storageKey);
    }
    API_KEY = null;
    USER_ID = null;

    // Return to a clean login screen.
    mainContainer.classList.add('hidden');
    loginScreen.classList.remove('hidden');
    apiKeyInput.value = '';
    hideLoginError();
}
|
||||||
|
|
||||||
|
// Show login error
|
||||||
|
/**
 * Display a login error.
 *
 * @param {string} message - Text written into the login error element.
 */
function showLoginError(message) {
    const errorBox = loginError;
    errorBox.textContent = message;
    // The 'show' class is what toggles the element's visible state.
    errorBox.classList.add('show');
}
|
||||||
|
|
||||||
|
// Hide login error
|
||||||
|
/**
 * Hide the login error by dropping its 'show' class.
 * The message text itself is left untouched.
 */
function hideLoginError() {
    const { classList } = loginError;
    classList.remove('show');
}
|
||||||
|
|
||||||
// Load chat history from localStorage
|
// Load chat history from localStorage
|
||||||
function loadChatHistory() {
|
function loadChatHistory() {
|
||||||
|
|
@ -419,6 +688,7 @@
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
'X-API-Key': API_KEY
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
message: message,
|
message: message,
|
||||||
|
|
@ -460,7 +730,11 @@
|
||||||
// Load memories from backend
|
// Load memories from backend
|
||||||
async function loadMemories() {
|
async function loadMemories() {
|
||||||
try{
|
try{
|
||||||
const response = await fetch(`${API_BASE}/memories/${USER_ID}?limit=50`);
|
const response = await fetch(`${API_BASE}/memories/${USER_ID}?limit=50`, {
|
||||||
|
headers: {
|
||||||
|
'X-API-Key': API_KEY
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error(`HTTP error! status: ${response.status}`);
|
throw new Error(`HTTP error! status: ${response.status}`);
|
||||||
|
|
@ -512,8 +786,11 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await fetch(`${API_BASE}/memories/${memoryId}`, {
|
const response = await fetch(`${API_BASE}/memories/${memoryId}?user_id=${USER_ID}`, {
|
||||||
method: 'DELETE'
|
method: 'DELETE',
|
||||||
|
headers: {
|
||||||
|
'X-API-Key': API_KEY
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
|
|
|
||||||
122
setup.sh
Executable file
122
setup.sh
Executable file
|
|
@ -0,0 +1,122 @@
|
||||||
|
#!/bin/bash

# Abort as soon as any unguarded command fails.
set -e

# Compose project name; it prefixes every named volume. Can be overridden
# from the environment, defaults to "mem0" when unset or empty.
: "${COMPOSE_PROJECT_NAME:=mem0}"

# Unprefixed names of the data volumes this script manages.
VOLUMES=(
    "qdrant_data"
    "neo4j_data"
    "neo4j_logs"
    "neo4j_import"
    "neo4j_plugins"
)

# Banner.
cat <<'EOF'
==========================================
 Mem0 Setup Script
==========================================

EOF
|
||||||
|
|
||||||
|
# Report whether at least one of the project's named volumes already exists.
#
# Reads:   VOLUMES, COMPOSE_PROJECT_NAME
# Outputs: the literal string "true" or "false" on stdout; callers capture it
#          with command substitution.
check_volumes_exist() {
    local exists=false
    local vol full_name existing

    # Query Docker once rather than once per candidate name. A Docker failure
    # yields an empty list, i.e. "false", as the original pipeline did.
    existing=$(docker volume ls -q || true)

    for vol in "${VOLUMES[@]}"; do
        full_name="${COMPOSE_PROJECT_NAME}_${vol}"
        # -F -x: literal whole-line match, so characters such as '.' in the
        # project name are not treated as regex metacharacters.
        if grep -Fxq -- "$full_name" <<<"$existing"; then
            exists=true
            break
        fi
    done
    echo "$exists"
}
|
||||||
|
|
||||||
|
# Print each of the project's volumes that currently exists, with its size.
#
# Reads:   VOLUMES, COMPOSE_PROJECT_NAME
# Outputs: a heading followed by one " - <volume> (<size>)" line per existing
#          volume; the size is "unknown" when Docker does not report one.
list_existing_volumes() {
    local vol full_name size existing usage

    # Both Docker queries are hoisted out of the loop: `docker system df -v`
    # in particular is slow, and its output does not change between volumes.
    existing=$(docker volume ls -q || true)
    usage=$(docker system df -v 2>/dev/null || true)

    echo "Existing volumes:"
    for vol in "${VOLUMES[@]}"; do
        full_name="${COMPOSE_PROJECT_NAME}_${vol}"
        if grep -Fxq -- "$full_name" <<<"$existing"; then
            # Volume rows are "<name> <links> <size>"; the size may be printed
            # as "0B" or as "36 B", so join every field after the second.
            size=$(awk -v name="$full_name" '
                $1 == name {
                    out = $3
                    for (i = 4; i <= NF; i++) out = out " " $i
                    print out
                    exit
                }' <<<"$usage")
            # The old `|| echo "unknown"` never ran because awk always exits 0;
            # default the empty result explicitly instead.
            echo " - $full_name (${size:-unknown})"
        fi
    done
}
|
||||||
|
|
||||||
|
# Stop the compose stack and delete every project volume that exists.
#
# Reads:   VOLUMES, COMPOSE_PROJECT_NAME
# Removal is best-effort: a volume that cannot be removed produces a warning
# on stderr but does not abort the script.
remove_volumes() {
    local vol full_name existing

    echo "Stopping containers..."
    # Ignore failures here: there may be nothing running to stop.
    docker compose down 2>/dev/null || true

    echo "Removing volumes..."
    # Snapshot the volume list once; the names checked below are distinct, so
    # removing one cannot change the answer for another.
    existing=$(docker volume ls -q || true)
    for vol in "${VOLUMES[@]}"; do
        full_name="${COMPOSE_PROJECT_NAME}_${vol}"
        # Literal whole-line match (no regex interpretation of the name).
        if grep -Fxq -- "$full_name" <<<"$existing"; then
            echo " Removing $full_name..."
            if ! docker volume rm "$full_name" 2>/dev/null; then
                echo " Warning: could not remove $full_name" >&2
            fi
        fi
    done
    echo "Volumes removed."
}
|
||||||
|
|
||||||
|
# Build the images, start the stack in the background, probe the health
# endpoint once after a short pause, and print a summary of service URLs.
build_and_start() {
    printf '\n%s\n' "Building containers..."
    docker compose build

    printf '\n%s\n' "Starting services..."
    docker compose up -d

    printf '\n%s\n' "Waiting for services to be healthy..."
    sleep 5

    printf '\n%s\n' "Checking health..."
    # A failing probe (or a failing jq) only prints a hint; it does not abort.
    curl -s http://localhost:8000/health 2>/dev/null | jq . \
        || echo "Health check not available yet. Services may still be starting."

    cat <<'EOF'

==========================================
 Setup Complete!
==========================================

Services:
 - Backend API: http://localhost:8000
 - API Docs: http://localhost:8000/docs
 - Health Check: http://localhost:8000/health

Logs: docker compose logs -f backend
EOF
}
|
||||||
|
|
||||||
|
# Entry point: when data volumes already exist, ask whether to keep them,
# wipe them, or quit; otherwise go straight to a fresh build and start.
if [ "$(check_volumes_exist)" = "true" ]; then
    echo "Existing data volumes detected!"
    echo ""
    list_existing_volumes
    echo ""
    echo "Options:"
    echo " 1) Keep existing data and start services"
    echo " 2) Reset everything (DELETE ALL DATA) and start fresh"
    echo " 3) Exit"
    echo ""
    # -r keeps backslashes literal. The `|| choice=""` guard stops `set -e`
    # from killing the script silently when stdin hits EOF (e.g. when run
    # non-interactively); the empty choice falls through to "Invalid option".
    read -r -p "Choose an option [1/2/3]: " choice || choice=""

    case "$choice" in
        1)
            echo ""
            echo "Keeping existing data..."
            build_and_start
            ;;
        2)
            echo ""
            # Destructive path: require the exact word "yes"; EOF counts as no.
            read -r -p "Are you sure you want to DELETE ALL DATA? Type 'yes' to confirm: " confirm || confirm=""
            if [ "$confirm" = "yes" ]; then
                remove_volumes
                build_and_start
            else
                echo "Aborted."
                exit 1
            fi
            ;;
        3)
            echo "Exiting."
            exit 0
            ;;
        *)
            echo "Invalid option. Exiting."
            exit 1
            ;;
    esac
else
    echo "No existing volumes found. Starting fresh setup..."
    build_and_start
fi
|
||||||
|
|
@ -19,13 +19,24 @@ import time
|
||||||
BASE_URL = "http://localhost:8000"
|
BASE_URL = "http://localhost:8000"
|
||||||
TEST_USER = f"test_user_{int(datetime.now().timestamp())}"
|
TEST_USER = f"test_user_{int(datetime.now().timestamp())}"
|
||||||
|
|
||||||
|
# API Key for authentication - set via environment or use default test key
|
||||||
|
import os
|
||||||
|
|
||||||
|
API_KEY = os.environ.get("MEM0_API_KEY", "test-api-key")
|
||||||
|
AUTH_HEADERS = {"X-API-Key": API_KEY}
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Mem0 Integration Tests - Real API Testing (Zero Mocking)",
|
description="Mem0 Integration Tests - Real API Testing (Zero Mocking)",
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
"-v",
|
||||||
|
action="store_true",
|
||||||
|
help="Show detailed output and API responses",
|
||||||
)
|
)
|
||||||
parser.add_argument("--verbose", "-v", action="store_true",
|
|
||||||
help="Show detailed output and API responses")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
verbose = args.verbose
|
verbose = args.verbose
|
||||||
|
|
@ -39,6 +50,9 @@ def main():
|
||||||
# Test sequence - order matters for data dependencies
|
# Test sequence - order matters for data dependencies
|
||||||
tests = [
|
tests = [
|
||||||
test_health_check,
|
test_health_check,
|
||||||
|
test_auth_required_endpoints,
|
||||||
|
test_ownership_verification,
|
||||||
|
test_request_size_limit,
|
||||||
test_empty_search_protection,
|
test_empty_search_protection,
|
||||||
test_add_memories_with_hierarchy,
|
test_add_memories_with_hierarchy,
|
||||||
test_search_memories_basic,
|
test_search_memories_basic,
|
||||||
|
|
@ -51,7 +65,7 @@ def main():
|
||||||
test_graph_relationships,
|
test_graph_relationships,
|
||||||
test_delete_specific_memory,
|
test_delete_specific_memory,
|
||||||
test_delete_all_user_memories,
|
test_delete_all_user_memories,
|
||||||
test_cleanup_verification
|
test_cleanup_verification,
|
||||||
]
|
]
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
@ -82,6 +96,7 @@ def main():
|
||||||
print("❌ Some tests failed! Check the output above.")
|
print("❌ Some tests failed! Check the output above.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def run_test(name, test_func, verbose):
|
def run_test(name, test_func, verbose):
|
||||||
"""Run a single test with error handling"""
|
"""Run a single test with error handling"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -102,6 +117,7 @@ def run_test(name, test_func, verbose):
|
||||||
print(f"❌ {name}: {e}")
|
print(f"❌ {name}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def log_response(response, verbose, context=""):
|
def log_response(response, verbose, context=""):
|
||||||
"""Log API response details if verbose"""
|
"""Log API response details if verbose"""
|
||||||
if verbose:
|
if verbose:
|
||||||
|
|
@ -111,22 +127,30 @@ def log_response(response, verbose, context=""):
|
||||||
if isinstance(data, dict) and len(data) < 5:
|
if isinstance(data, dict) and len(data) < 5:
|
||||||
print(f" {context} Response: {data}")
|
print(f" {context} Response: {data}")
|
||||||
else:
|
else:
|
||||||
print(f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}")
|
print(
|
||||||
|
f" {context} Response keys: {list(data.keys()) if isinstance(data, dict) else 'list'}"
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
print(f" {context} Response: {response.text[:100]}...")
|
print(f" {context} Response: {response.text[:100]}...")
|
||||||
|
|
||||||
|
|
||||||
# ================== TEST FUNCTIONS ==================
|
# ================== TEST FUNCTIONS ==================
|
||||||
|
|
||||||
|
|
||||||
def test_health_check(verbose):
|
def test_health_check(verbose):
|
||||||
"""Test service health endpoint"""
|
"""Test service health endpoint"""
|
||||||
response = requests.get(f"{BASE_URL}/health", timeout=10)
|
response = requests.get(
|
||||||
|
f"{BASE_URL}/health", timeout=10
|
||||||
|
) # Health doesn't require auth
|
||||||
log_response(response, verbose, "Health")
|
log_response(response, verbose, "Health")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "status" in data, "Health response missing 'status' field"
|
assert "status" in data, "Health response missing 'status' field"
|
||||||
assert data["status"] in ["healthy", "degraded"], f"Invalid status: {data['status']}"
|
assert data["status"] in ["healthy", "degraded"], (
|
||||||
|
f"Invalid status: {data['status']}"
|
||||||
|
)
|
||||||
|
|
||||||
# Check individual services
|
# Check individual services
|
||||||
assert "services" in data, "Health response missing 'services' field"
|
assert "services" in data, "Health response missing 'services' field"
|
||||||
|
|
@ -136,18 +160,19 @@ def test_health_check(verbose):
|
||||||
for service, status in data["services"].items():
|
for service, status in data["services"].items():
|
||||||
print(f" {service}: {status}")
|
print(f" {service}: {status}")
|
||||||
|
|
||||||
|
|
||||||
def test_empty_search_protection(verbose):
|
def test_empty_search_protection(verbose):
|
||||||
"""Test empty query protection (should not return 500 error)"""
|
"""Test empty query protection (should not return 500 error)"""
|
||||||
payload = {
|
payload = {"query": "", "user_id": TEST_USER, "limit": 5}
|
||||||
"query": "",
|
|
||||||
"user_id": TEST_USER,
|
|
||||||
"limit": 5
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=10)
|
response = requests.post(
|
||||||
|
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=10
|
||||||
|
)
|
||||||
log_response(response, verbose, "Empty Search")
|
log_response(response, verbose, "Empty Search")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Empty query failed with {response.status_code}"
|
assert response.status_code == 200, (
|
||||||
|
f"Empty query failed with {response.status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["memories"] == [], "Empty query should return empty memories list"
|
assert data["memories"] == [], "Empty query should return empty memories list"
|
||||||
|
|
@ -158,25 +183,39 @@ def test_empty_search_protection(verbose):
|
||||||
print(f" Empty search note: {data['note']}")
|
print(f" Empty search note: {data['note']}")
|
||||||
print(f" Total count: {data.get('total_count', 0)}")
|
print(f" Total count: {data.get('total_count', 0)}")
|
||||||
|
|
||||||
|
|
||||||
def test_add_memories_with_hierarchy(verbose):
|
def test_add_memories_with_hierarchy(verbose):
|
||||||
"""Test adding memories with multi-level hierarchy support"""
|
"""Test adding memories with multi-level hierarchy support"""
|
||||||
payload = {
|
payload = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "user", "content": "I work at TechCorp as a Senior Software Engineer"},
|
{
|
||||||
{"role": "user", "content": "My colleague Sarah from Marketing team helped with Q3 presentation"},
|
"role": "user",
|
||||||
{"role": "user", "content": "Meeting with John the Product Manager tomorrow about new feature development"}
|
"content": "I work at TechCorp as a Senior Software Engineer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "My colleague Sarah from Marketing team helped with Q3 presentation",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Meeting with John the Product Manager tomorrow about new feature development",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
"user_id": TEST_USER,
|
"user_id": TEST_USER,
|
||||||
"agent_id": "test_agent",
|
"agent_id": "test_agent",
|
||||||
"run_id": "test_run_001",
|
"run_id": "test_run_001",
|
||||||
"session_id": "test_session_001",
|
"session_id": "test_session_001",
|
||||||
"metadata": {"test": "integration", "scenario": "work_context"}
|
"metadata": {"test": "integration", "scenario": "work_context"},
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60)
|
response = requests.post(
|
||||||
|
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60
|
||||||
|
)
|
||||||
log_response(response, verbose, "Add Memories")
|
log_response(response, verbose, "Add Memories")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Add memories failed with {response.status_code}"
|
assert response.status_code == 200, (
|
||||||
|
f"Add memories failed with {response.status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "added_memories" in data, "Response missing 'added_memories'"
|
assert "added_memories" in data, "Response missing 'added_memories'"
|
||||||
|
|
@ -191,23 +230,26 @@ def test_add_memories_with_hierarchy(verbose):
|
||||||
relations = first_memory["relations"]
|
relations = first_memory["relations"]
|
||||||
if "added_entities" in relations and relations["added_entities"]:
|
if "added_entities" in relations and relations["added_entities"]:
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f" Graph extracted: {len(relations['added_entities'])} relationships")
|
print(
|
||||||
|
f" Graph extracted: {len(relations['added_entities'])} relationships"
|
||||||
|
)
|
||||||
print(f" Sample relations: {relations['added_entities'][:3]}")
|
print(f" Sample relations: {relations['added_entities'][:3]}")
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f" Added {len(memories)} memory blocks")
|
print(f" Added {len(memories)} memory blocks")
|
||||||
print(f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001")
|
print(
|
||||||
|
f" Hierarchy - Agent: test_agent, Run: test_run_001, Session: test_session_001"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_search_memories_basic(verbose):
|
def test_search_memories_basic(verbose):
|
||||||
"""Test basic memory search functionality"""
|
"""Test basic memory search functionality"""
|
||||||
# Test meaningful search
|
# Test meaningful search
|
||||||
payload = {
|
payload = {"query": "TechCorp", "user_id": TEST_USER, "limit": 10}
|
||||||
"query": "TechCorp",
|
|
||||||
"user_id": TEST_USER,
|
|
||||||
"limit": 10
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15)
|
response = requests.post(
|
||||||
|
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15
|
||||||
|
)
|
||||||
log_response(response, verbose, "Search")
|
log_response(response, verbose, "Search")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Search failed with {response.status_code}"
|
assert response.status_code == 200, f"Search failed with {response.status_code}"
|
||||||
|
|
@ -232,6 +274,7 @@ def test_search_memories_basic(verbose):
|
||||||
print(f" Found {data['total_count']} memories")
|
print(f" Found {data['total_count']} memories")
|
||||||
print(f" First memory: {memory['memory'][:50]}...")
|
print(f" First memory: {memory['memory'][:50]}...")
|
||||||
|
|
||||||
|
|
||||||
def test_search_memories_hierarchy_filters(verbose):
|
def test_search_memories_hierarchy_filters(verbose):
|
||||||
"""Test multi-level hierarchy filtering in search"""
|
"""Test multi-level hierarchy filtering in search"""
|
||||||
# Test with hierarchy filters
|
# Test with hierarchy filters
|
||||||
|
|
@ -241,13 +284,17 @@ def test_search_memories_hierarchy_filters(verbose):
|
||||||
"agent_id": "test_agent",
|
"agent_id": "test_agent",
|
||||||
"run_id": "test_run_001",
|
"run_id": "test_run_001",
|
||||||
"session_id": "test_session_001",
|
"session_id": "test_session_001",
|
||||||
"limit": 10
|
"limit": 10,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post(f"{BASE_URL}/memories/search", json=payload, timeout=15)
|
response = requests.post(
|
||||||
|
f"{BASE_URL}/memories/search", json=payload, headers=AUTH_HEADERS, timeout=15
|
||||||
|
)
|
||||||
log_response(response, verbose, "Hierarchy Search")
|
log_response(response, verbose, "Hierarchy Search")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Hierarchy search failed with {response.status_code}"
|
assert response.status_code == 200, (
|
||||||
|
f"Hierarchy search failed with {response.status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "memories" in data, "Hierarchy search response missing 'memories'"
|
assert "memories" in data, "Hierarchy search response missing 'memories'"
|
||||||
|
|
@ -257,7 +304,10 @@ def test_search_memories_hierarchy_filters(verbose):
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f" Found {len(data['memories'])} memories with hierarchy filters")
|
print(f" Found {len(data['memories'])} memories with hierarchy filters")
|
||||||
print(f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001")
|
print(
|
||||||
|
f" Filters: agent_id=test_agent, run_id=test_run_001, session_id=test_session_001"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_user_memories_with_hierarchy(verbose):
|
def test_get_user_memories_with_hierarchy(verbose):
|
||||||
"""Test retrieving user memories with hierarchy filtering"""
|
"""Test retrieving user memories with hierarchy filtering"""
|
||||||
|
|
@ -266,13 +316,20 @@ def test_get_user_memories_with_hierarchy(verbose):
|
||||||
"limit": 20,
|
"limit": 20,
|
||||||
"agent_id": "test_agent",
|
"agent_id": "test_agent",
|
||||||
"run_id": "test_run_001",
|
"run_id": "test_run_001",
|
||||||
"session_id": "test_session_001"
|
"session_id": "test_session_001",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}", params=params, timeout=15)
|
response = requests.get(
|
||||||
|
f"{BASE_URL}/memories/{TEST_USER}",
|
||||||
|
params=params,
|
||||||
|
headers=AUTH_HEADERS,
|
||||||
|
timeout=15,
|
||||||
|
)
|
||||||
log_response(response, verbose, "Get User Memories with Hierarchy")
|
log_response(response, verbose, "Get User Memories with Hierarchy")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Get user memories with hierarchy failed with {response.status_code}"
|
assert response.status_code == 200, (
|
||||||
|
f"Get user memories with hierarchy failed with {response.status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
memories = response.json()
|
memories = response.json()
|
||||||
assert isinstance(memories, list), "User memories should return a list"
|
assert isinstance(memories, list), "User memories should return a list"
|
||||||
|
|
@ -290,10 +347,13 @@ def test_get_user_memories_with_hierarchy(verbose):
|
||||||
if verbose:
|
if verbose:
|
||||||
print(" No memories found with hierarchy filters (may be expected)")
|
print(" No memories found with hierarchy filters (may be expected)")
|
||||||
|
|
||||||
|
|
||||||
def test_memory_history(verbose):
|
def test_memory_history(verbose):
|
||||||
"""Test memory history endpoint"""
|
"""Test memory history endpoint"""
|
||||||
# First get a memory to check history for
|
# First get a memory to check history for
|
||||||
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10)
|
response = requests.get(
|
||||||
|
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
|
||||||
|
)
|
||||||
assert response.status_code == 200, "Failed to get memory for history test"
|
assert response.status_code == 200, "Failed to get memory for history test"
|
||||||
|
|
||||||
memories = response.json()
|
memories = response.json()
|
||||||
|
|
@ -305,27 +365,38 @@ def test_memory_history(verbose):
|
||||||
memory_id = memories[0]["id"]
|
memory_id = memories[0]["id"]
|
||||||
|
|
||||||
# Test memory history endpoint
|
# Test memory history endpoint
|
||||||
response = requests.get(f"{BASE_URL}/memories/{memory_id}/history", timeout=15)
|
response = requests.get(
|
||||||
|
f"{BASE_URL}/memories/{memory_id}/history?user_id={TEST_USER}",
|
||||||
|
headers=AUTH_HEADERS,
|
||||||
|
timeout=15,
|
||||||
|
)
|
||||||
log_response(response, verbose, "Memory History")
|
log_response(response, verbose, "Memory History")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Memory history failed with {response.status_code}"
|
assert response.status_code == 200, (
|
||||||
|
f"Memory history failed with {response.status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "memory_id" in data, "History response missing 'memory_id'"
|
assert "memory_id" in data, "History response missing 'memory_id'"
|
||||||
assert "history" in data, "History response missing 'history'"
|
assert "history" in data, "History response missing 'history'"
|
||||||
assert "message" in data, "History response missing success message"
|
assert "message" in data, "History response missing success message"
|
||||||
assert data["memory_id"] == memory_id, f"Wrong memory_id in response: {data['memory_id']}"
|
assert data["memory_id"] == memory_id, (
|
||||||
|
f"Wrong memory_id in response: {data['memory_id']}"
|
||||||
|
)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f" Retrieved history for memory {memory_id}")
|
print(f" Retrieved history for memory {memory_id}")
|
||||||
print(f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}")
|
print(
|
||||||
|
f" History entries: {len(data['history']) if isinstance(data['history'], list) else 'N/A'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_update_memory(verbose):
|
def test_update_memory(verbose):
|
||||||
"""Test updating a specific memory"""
|
"""Test updating a specific memory"""
|
||||||
# First get a memory to update
|
# First get a memory to update
|
||||||
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10)
|
response = requests.get(
|
||||||
|
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
|
||||||
|
)
|
||||||
assert response.status_code == 200, "Failed to get memory for update test"
|
assert response.status_code == 200, "Failed to get memory for update test"
|
||||||
|
|
||||||
memories = response.json()
|
memories = response.json()
|
||||||
|
|
@ -337,10 +408,13 @@ def test_update_memory(verbose):
|
||||||
# Update the memory
|
# Update the memory
|
||||||
payload = {
|
payload = {
|
||||||
"memory_id": memory_id,
|
"memory_id": memory_id,
|
||||||
"content": f"UPDATED: {original_content}"
|
"user_id": TEST_USER,
|
||||||
|
"content": f"UPDATED: {original_content}",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.put(f"{BASE_URL}/memories", json=payload, timeout=10)
|
response = requests.put(
|
||||||
|
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=10
|
||||||
|
)
|
||||||
log_response(response, verbose, "Update")
|
log_response(response, verbose, "Update")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Update failed with {response.status_code}"
|
assert response.status_code == 200, f"Update failed with {response.status_code}"
|
||||||
|
|
@ -352,15 +426,15 @@ def test_update_memory(verbose):
|
||||||
print(f" Updated memory {memory_id}")
|
print(f" Updated memory {memory_id}")
|
||||||
print(f" Original: {original_content[:30]}...")
|
print(f" Original: {original_content[:30]}...")
|
||||||
|
|
||||||
|
|
||||||
def test_chat_with_memory(verbose):
|
def test_chat_with_memory(verbose):
|
||||||
"""Test memory-enhanced chat functionality"""
|
"""Test memory-enhanced chat functionality"""
|
||||||
payload = {
|
payload = {"message": "What company do I work for?", "user_id": TEST_USER}
|
||||||
"message": "What company do I work for?",
|
|
||||||
"user_id": TEST_USER
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=90)
|
response = requests.post(
|
||||||
|
f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=90
|
||||||
|
)
|
||||||
log_response(response, verbose, "Chat")
|
log_response(response, verbose, "Chat")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Chat failed with {response.status_code}"
|
assert response.status_code == 200, f"Chat failed with {response.status_code}"
|
||||||
|
|
@ -383,12 +457,15 @@ def test_chat_with_memory(verbose):
|
||||||
print(" Chat endpoint timed out (LLM API may be slow)")
|
print(" Chat endpoint timed out (LLM API may be slow)")
|
||||||
# Still test that the endpoint exists and accepts requests
|
# Still test that the endpoint exists and accepts requests
|
||||||
try:
|
try:
|
||||||
response = requests.post(f"{BASE_URL}/chat", json=payload, timeout=5)
|
response = requests.post(
|
||||||
|
f"{BASE_URL}/chat", json=payload, headers=AUTH_HEADERS, timeout=5
|
||||||
|
)
|
||||||
except requests.exceptions.ReadTimeout:
|
except requests.exceptions.ReadTimeout:
|
||||||
# This is expected - endpoint exists but processing is slow
|
# This is expected - endpoint exists but processing is slow
|
||||||
if verbose:
|
if verbose:
|
||||||
print(" Chat endpoint confirmed active (processing timeout expected)")
|
print(" Chat endpoint confirmed active (processing timeout expected)")
|
||||||
|
|
||||||
|
|
||||||
def test_graph_relationships_creation(verbose):
|
def test_graph_relationships_creation(verbose):
|
||||||
"""Test graph relationships creation with entity-rich memories"""
|
"""Test graph relationships creation with entity-rich memories"""
|
||||||
# Create a separate test user for graph relationship testing
|
# Create a separate test user for graph relationship testing
|
||||||
|
|
@ -397,41 +474,67 @@ def test_graph_relationships_creation(verbose):
|
||||||
# Add memories with clear entity relationships
|
# Add memories with clear entity relationships
|
||||||
payload = {
|
payload = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "user", "content": "John Smith works at Microsoft as a Senior Software Engineer"},
|
{
|
||||||
{"role": "user", "content": "John Smith is friends with Sarah Johnson who works at Google"},
|
"role": "user",
|
||||||
{"role": "user", "content": "Sarah Johnson lives in Seattle and loves hiking"},
|
"content": "John Smith works at Microsoft as a Senior Software Engineer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "John Smith is friends with Sarah Johnson who works at Google",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Sarah Johnson lives in Seattle and loves hiking",
|
||||||
|
},
|
||||||
{"role": "user", "content": "Microsoft is located in Redmond, Washington"},
|
{"role": "user", "content": "Microsoft is located in Redmond, Washington"},
|
||||||
{"role": "user", "content": "John Smith and Sarah Johnson both graduated from Stanford University"}
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "John Smith and Sarah Johnson both graduated from Stanford University",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
"user_id": graph_test_user,
|
"user_id": graph_test_user,
|
||||||
"metadata": {"test": "graph_relationships", "scenario": "entity_creation"}
|
"metadata": {"test": "graph_relationships", "scenario": "entity_creation"},
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post(f"{BASE_URL}/memories", json=payload, timeout=60)
|
response = requests.post(
|
||||||
|
f"{BASE_URL}/memories", json=payload, headers=AUTH_HEADERS, timeout=60
|
||||||
|
)
|
||||||
log_response(response, verbose, "Add Graph Memories")
|
log_response(response, verbose, "Add Graph Memories")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Add graph memories failed with {response.status_code}"
|
assert response.status_code == 200, (
|
||||||
|
f"Add graph memories failed with {response.status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "added_memories" in data, "Response missing 'added_memories'"
|
assert "added_memories" in data, "Response missing 'added_memories'"
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f" Added {len(data['added_memories'])} memories for graph relationship testing")
|
print(
|
||||||
|
f" Added {len(data['added_memories'])} memories for graph relationship testing"
|
||||||
|
)
|
||||||
|
|
||||||
# Wait a moment for graph processing (Mem0 graph extraction can be async)
|
# Wait a moment for graph processing (Mem0 graph extraction can be async)
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
|
|
||||||
# Test graph relationships endpoint
|
# Test graph relationships endpoint
|
||||||
response = requests.get(f"{BASE_URL}/graph/relationships/{graph_test_user}", timeout=15)
|
response = requests.get(
|
||||||
|
f"{BASE_URL}/graph/relationships/{graph_test_user}",
|
||||||
|
headers=AUTH_HEADERS,
|
||||||
|
timeout=15,
|
||||||
|
)
|
||||||
log_response(response, verbose, "Graph Relationships")
|
log_response(response, verbose, "Graph Relationships")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Graph relationships failed with {response.status_code}"
|
assert response.status_code == 200, (
|
||||||
|
f"Graph relationships failed with {response.status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
graph_data = response.json()
|
graph_data = response.json()
|
||||||
assert "relationships" in graph_data, "Graph response missing 'relationships'"
|
assert "relationships" in graph_data, "Graph response missing 'relationships'"
|
||||||
assert "entities" in graph_data, "Graph response missing 'entities'"
|
assert "entities" in graph_data, "Graph response missing 'entities'"
|
||||||
assert "user_id" in graph_data, "Graph response missing 'user_id'"
|
assert "user_id" in graph_data, "Graph response missing 'user_id'"
|
||||||
assert graph_data["user_id"] == graph_test_user, f"Wrong user_id in graph: {graph_data['user_id']}"
|
assert graph_data["user_id"] == graph_test_user, (
|
||||||
|
f"Wrong user_id in graph: {graph_data['user_id']}"
|
||||||
|
)
|
||||||
|
|
||||||
relationships = graph_data["relationships"]
|
relationships = graph_data["relationships"]
|
||||||
entities = graph_data["entities"]
|
entities = graph_data["entities"]
|
||||||
|
|
@ -451,16 +554,24 @@ def test_graph_relationships_creation(verbose):
|
||||||
|
|
||||||
# Print sample entities if they exist
|
# Print sample entities if they exist
|
||||||
if entities:
|
if entities:
|
||||||
print(f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}")
|
print(
|
||||||
|
f" Sample entities: {[e.get('name', str(e)) for e in entities[:5]]}"
|
||||||
|
)
|
||||||
|
|
||||||
# Verify relationship structure (if relationships exist)
|
# Verify relationship structure (if relationships exist)
|
||||||
for rel in relationships:
|
for rel in relationships:
|
||||||
assert "source" in rel or "from" in rel, f"Relationship missing source/from: {rel}"
|
assert "source" in rel or "from" in rel, (
|
||||||
|
f"Relationship missing source/from: {rel}"
|
||||||
|
)
|
||||||
assert "target" in rel or "to" in rel, f"Relationship missing target/to: {rel}"
|
assert "target" in rel or "to" in rel, f"Relationship missing target/to: {rel}"
|
||||||
assert "relationship" in rel or "type" in rel, f"Relationship missing type: {rel}"
|
assert "relationship" in rel or "type" in rel, (
|
||||||
|
f"Relationship missing type: {rel}"
|
||||||
|
)
|
||||||
|
|
||||||
# Clean up graph test user memories
|
# Clean up graph test user memories
|
||||||
cleanup_response = requests.delete(f"{BASE_URL}/memories/user/{graph_test_user}", timeout=15)
|
cleanup_response = requests.delete(
|
||||||
|
f"{BASE_URL}/memories/user/{graph_test_user}", headers=AUTH_HEADERS, timeout=15
|
||||||
|
)
|
||||||
assert cleanup_response.status_code == 200, "Failed to cleanup graph test memories"
|
assert cleanup_response.status_code == 200, "Failed to cleanup graph test memories"
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
|
|
@ -469,12 +580,17 @@ def test_graph_relationships_creation(verbose):
|
||||||
# Note: We expect some relationships even if graph extraction is basic
|
# Note: We expect some relationships even if graph extraction is basic
|
||||||
# The test passes if the endpoint works and returns proper structure
|
# The test passes if the endpoint works and returns proper structure
|
||||||
|
|
||||||
|
|
||||||
def test_graph_relationships(verbose):
|
def test_graph_relationships(verbose):
|
||||||
"""Test graph relationships endpoint"""
|
"""Test graph relationships endpoint"""
|
||||||
response = requests.get(f"{BASE_URL}/graph/relationships/{TEST_USER}", timeout=15)
|
response = requests.get(
|
||||||
|
f"{BASE_URL}/graph/relationships/{TEST_USER}", headers=AUTH_HEADERS, timeout=15
|
||||||
|
)
|
||||||
log_response(response, verbose, "Graph")
|
log_response(response, verbose, "Graph")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Graph endpoint failed with {response.status_code}"
|
assert response.status_code == 200, (
|
||||||
|
f"Graph endpoint failed with {response.status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "relationships" in data, "Graph response missing 'relationships'"
|
assert "relationships" in data, "Graph response missing 'relationships'"
|
||||||
|
|
@ -486,10 +602,13 @@ def test_graph_relationships(verbose):
|
||||||
print(f" Relationships: {len(data['relationships'])}")
|
print(f" Relationships: {len(data['relationships'])}")
|
||||||
print(f" Entities: {len(data['entities'])}")
|
print(f" Entities: {len(data['entities'])}")
|
||||||
|
|
||||||
|
|
||||||
def test_delete_specific_memory(verbose):
|
def test_delete_specific_memory(verbose):
|
||||||
"""Test deleting a specific memory"""
|
"""Test deleting a specific memory"""
|
||||||
# Get a memory to delete
|
# Get a memory to delete
|
||||||
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=1", timeout=10)
|
response = requests.get(
|
||||||
|
f"{BASE_URL}/memories/{TEST_USER}?limit=1", headers=AUTH_HEADERS, timeout=10
|
||||||
|
)
|
||||||
assert response.status_code == 200, "Failed to get memory for deletion test"
|
assert response.status_code == 200, "Failed to get memory for deletion test"
|
||||||
|
|
||||||
memories = response.json()
|
memories = response.json()
|
||||||
|
|
@ -498,7 +617,9 @@ def test_delete_specific_memory(verbose):
|
||||||
memory_id = memories[0]["id"]
|
memory_id = memories[0]["id"]
|
||||||
|
|
||||||
# Delete the memory
|
# Delete the memory
|
||||||
response = requests.delete(f"{BASE_URL}/memories/{memory_id}", timeout=10)
|
response = requests.delete(
|
||||||
|
f"{BASE_URL}/memories/{memory_id}", headers=AUTH_HEADERS, timeout=10
|
||||||
|
)
|
||||||
log_response(response, verbose, "Delete")
|
log_response(response, verbose, "Delete")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Delete failed with {response.status_code}"
|
assert response.status_code == 200, f"Delete failed with {response.status_code}"
|
||||||
|
|
@ -509,9 +630,12 @@ def test_delete_specific_memory(verbose):
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f" Deleted memory {memory_id}")
|
print(f" Deleted memory {memory_id}")
|
||||||
|
|
||||||
|
|
||||||
def test_delete_all_user_memories(verbose):
|
def test_delete_all_user_memories(verbose):
|
||||||
"""Test deleting all memories for a user"""
|
"""Test deleting all memories for a user"""
|
||||||
response = requests.delete(f"{BASE_URL}/memories/user/{TEST_USER}", timeout=15)
|
response = requests.delete(
|
||||||
|
f"{BASE_URL}/memories/user/{TEST_USER}", headers=AUTH_HEADERS, timeout=15
|
||||||
|
)
|
||||||
log_response(response, verbose, "Delete All")
|
log_response(response, verbose, "Delete All")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Delete all failed with {response.status_code}"
|
assert response.status_code == 200, f"Delete all failed with {response.status_code}"
|
||||||
|
|
@ -522,12 +646,17 @@ def test_delete_all_user_memories(verbose):
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"Deleted all memories for {TEST_USER}")
|
print(f"Deleted all memories for {TEST_USER}")
|
||||||
|
|
||||||
|
|
||||||
def test_cleanup_verification(verbose):
|
def test_cleanup_verification(verbose):
|
||||||
"""Verify cleanup was successful"""
|
"""Verify cleanup was successful"""
|
||||||
response = requests.get(f"{BASE_URL}/memories/{TEST_USER}?limit=10", timeout=10)
|
response = requests.get(
|
||||||
|
f"{BASE_URL}/memories/{TEST_USER}?limit=10", headers=AUTH_HEADERS, timeout=10
|
||||||
|
)
|
||||||
log_response(response, verbose, "Cleanup Check")
|
log_response(response, verbose, "Cleanup Check")
|
||||||
|
|
||||||
assert response.status_code == 200, f"Cleanup verification failed with {response.status_code}"
|
assert response.status_code == 200, (
|
||||||
|
f"Cleanup verification failed with {response.status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
memories = response.json()
|
memories = response.json()
|
||||||
assert isinstance(memories, list), "Should return list even if empty"
|
assert isinstance(memories, list), "Should return list even if empty"
|
||||||
|
|
@ -539,5 +668,79 @@ def test_cleanup_verification(verbose):
|
||||||
if verbose:
|
if verbose:
|
||||||
print(" Cleanup successful - no memories remain")
|
print(" Cleanup successful - no memories remain")
|
||||||
|
|
||||||
|
|
||||||
|
# ================== SECURITY TEST FUNCTIONS ==================
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_required_endpoints(verbose):
    """Check that protected endpoints reject requests sent without credentials.

    Every listed endpoint is called with no auth headers and must answer
    with 401 or 403. When ``verbose`` is truthy, one line is printed per
    endpoint showing the status code that was returned.
    """
    protected = [
        ("GET", f"{BASE_URL}/memories/{TEST_USER}"),
        ("POST", f"{BASE_URL}/memories/search"),
        ("GET", f"{BASE_URL}/stats"),
        ("GET", f"{BASE_URL}/models"),
        ("GET", f"{BASE_URL}/users"),
    ]

    for method, url in protected:
        # GET probes carry no body; any other method is sent as a POST with
        # a minimal search payload.
        if method == "GET":
            send, extra = requests.get, {}
        else:
            send = requests.post
            extra = {"json": {"query": "test", "user_id": TEST_USER}}
        response = send(url, timeout=5, **extra)

        # Either "unauthenticated" or "forbidden" counts as a rejection.
        assert response.status_code in (401, 403), (
            f"{method} {url} should require auth, got {response.status_code}"
        )

        if verbose:
            print(f" {method} {url}: {response.status_code} (auth required)")
|
||||||
|
|
||||||
|
|
||||||
|
def test_ownership_verification(verbose):
    """Check that an authenticated caller cannot read another user's memories.

    Requests the memories of a user id other than the one the test
    credentials belong to and expects a 403 or 404. When ``verbose`` is
    truthy, the returned status code is printed.
    """
    foreign_user = "other_user_not_me"
    target_url = f"{BASE_URL}/memories/{foreign_user}"

    resp = requests.get(target_url, headers=AUTH_HEADERS, timeout=5)
    status_code = resp.status_code

    # Denial may surface as "forbidden" or as "not found".
    assert status_code in (403, 404), (
        f"Accessing other user's memories should be denied, got {status_code}"
    )

    if verbose:
        print(f" Ownership check passed: {status_code}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_size_limit(verbose):
    """Check that an oversized request body is rejected with HTTP 413.

    Posts a message whose content alone is 11 MiB to ``/memories`` (the
    limit under test is 10MB) and expects a 413 response. A ``requests``
    transport error is tolerated: it is reported only when ``verbose`` is
    truthy and the test then ends without failing.
    """
    oversized_bytes = 11 * 1024 * 1024
    body = {
        "messages": [{"role": "user", "content": "x" * oversized_bytes}],
        "user_id": TEST_USER,
    }
    # NOTE(review): the declared Content-Length is the content size, not the
    # size of the serialized JSON body; requests normally derives this header
    # from the body itself -- confirm whether this explicit value is ever sent.
    request_headers = {**AUTH_HEADERS, "Content-Length": str(oversized_bytes)}

    try:
        response = requests.post(
            f"{BASE_URL}/memories",
            json=body,
            headers=request_headers,
            timeout=5,
        )
    except requests.exceptions.RequestException:
        # Deliberate best-effort: a transport-level failure while sending the
        # huge payload is treated as acceptable rather than as a test failure.
        if verbose:
            print(
                " Request size limit test: connection issue (expected for large payload)"
            )
        return

    assert response.status_code == 413, (
        f"Large request should return 413, got {response.status_code}"
    )

    if verbose:
        print(f" Request size limit enforced: {response.status_code}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
Loading…
Reference in a new issue