Add OpenAI-compatible endpoint and improved login UI

- Add /v1/chat/completions and /chat/completions endpoints (OpenAI SDK compatible)
- Add streaming support with SSE for chat completions
- Add get_current_user_openai auth dependency supporting Bearer tokens and the X-API-Key header
- Add OpenAI-compatible request/response models (OpenAIChatCompletionRequest, etc.)
- Cherry-pick improved login UI from cloud branch (styled login screen, logout button)
Pratik Narola 2026-01-15 23:29:08 +05:30
parent a228780146
commit 2c1d73a1ec
4 changed files with 569 additions and 30 deletions
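
For context, the new endpoint can be driven by the official OpenAI Python SDK simply by pointing its base_url at this service. A minimal sketch, assuming the server listens on http://localhost:8000 (the address the old UI hard-coded) and using a hypothetical key "sk-pratik-example" from settings.api_key_mapping:

    # Hypothetical client usage; base_url and api_key are assumptions.
    from openai import OpenAI

    client = OpenAI(
        base_url="http://localhost:8000/v1",  # routes to /v1/chat/completions
        api_key="sk-pratik-example",          # sent as "Authorization: Bearer <key>"
    )

    resp = client.chat.completions.create(
        model="gpt-4",  # accepted, but the server answers with its configured default model
        messages=[{"role": "user", "content": "What do you remember about me?"}],
    )
    print(resp.choices[0].message.content)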

File: auth.py

@@ -1,7 +1,7 @@
 """Simple API key authentication for Mem0 Interface."""

 from typing import Optional
-from fastapi import HTTPException, Security, status
+from fastapi import HTTPException, Security, status, Header
 from fastapi.security import APIKeyHeader
 import structlog
@@ -19,7 +19,9 @@ class AuthService:
     def __init__(self):
         """Initialize auth service with API key to user mapping."""
         self.api_key_to_user = settings.api_key_mapping
-        logger.info(f"Auth service initialized with {len(self.api_key_to_user)} API keys")
+        logger.info(
+            f"Auth service initialized with {len(self.api_key_to_user)} API keys"
+        )

     def verify_api_key(self, api_key: str) -> str:
         """
@@ -37,8 +39,7 @@ class AuthService:
         if api_key not in self.api_key_to_user:
             logger.warning(f"Invalid API key attempted: {api_key[:10]}...")
             raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail="Invalid API key"
+                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
             )

         user_id = self.api_key_to_user[api_key]
@@ -68,7 +69,7 @@ class AuthService:
             )
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
-                detail=f"Access denied: You can only access your own memories"
+                detail=f"Access denied: You can only access your own memories",
             )

         return authenticated_user_id
@@ -91,9 +92,46 @@ async def get_current_user(api_key: str = Security(api_key_header)) -> str:
     return auth_service.verify_api_key(api_key)


+async def get_current_user_openai(
+    authorization: Optional[str] = Header(None),
+    x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
+) -> str:
+    """
+    FastAPI dependency for OpenAI-compatible authentication.
+
+    Supports both Authorization: Bearer and X-API-Key headers.
+
+    Args:
+        authorization: Authorization header (Bearer token)
+        x_api_key: X-API-Key header
+
+    Returns:
+        str: Authenticated user_id
+
+    Raises:
+        HTTPException: If no valid API key is provided
+    """
+    api_key = None
+
+    # Try Bearer token first (OpenAI standard)
+    if authorization and authorization.startswith("Bearer "):
+        api_key = authorization[7:]  # Remove "Bearer " prefix
+        logger.debug("Extracted API key from Authorization Bearer token")
+    # Fall back to X-API-Key header
+    elif x_api_key:
+        api_key = x_api_key
+        logger.debug("Extracted API key from X-API-Key header")
+    else:
+        logger.warning("No API key provided in Authorization or X-API-Key headers")
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Missing API key. Provide either 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header",
+        )
+
+    return auth_service.verify_api_key(api_key)
+
+
 async def verify_user_access(
-    api_key: str = Security(api_key_header),
-    user_id: Optional[str] = None
+    api_key: str = Security(api_key_header), user_id: Optional[str] = None
 ) -> str:
     """
     FastAPI dependency to verify user can access the requested user_id.
@@ -114,7 +152,7 @@ async def verify_user_access(
         )
         raise HTTPException(
             status_code=status.HTTP_403_FORBIDDEN,
-            detail="Access denied: You can only access your own memories"
+            detail="Access denied: You can only access your own memories",
         )

     return authenticated_user_id
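
Both header styles funnel into the same verify_api_key lookup, so they are interchangeable from a client's point of view. A quick sketch with httpx (key and port are hypothetical):

    # Both requests authenticate as the same user; only the header differs.
    import httpx

    payload = {"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}]}

    # OpenAI-standard Bearer token
    r1 = httpx.post(
        "http://localhost:8000/v1/chat/completions",
        json=payload,
        headers={"Authorization": "Bearer sk-pratik-example"},
    )

    # X-API-Key fallback, matching the existing non-OpenAI endpoints
    r2 = httpx.post(
        "http://localhost:8000/v1/chat/completions",
        json=payload,
        headers={"X-API-Key": "sk-pratik-example"},
    )

    assert r1.status_code == r2.status_code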

File: FastAPI application module (filename not shown)

@@ -9,8 +9,9 @@ from contextlib import asynccontextmanager
 from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security, Request
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 import structlog
+import asyncio
 from slowapi import Limiter, _rate_limit_exceeded_handler
 from slowapi.util import get_remote_address
 from slowapi.errors import RateLimitExceeded
@@ -41,9 +42,14 @@ from models import (
     ErrorResponse,
     GlobalStatsResponse,
     UserStatsResponse,
+    OpenAIChatCompletionRequest,
+    OpenAIChatCompletionResponse,
+    OpenAIChoice,
+    OpenAIChoiceMessage,
+    OpenAIUsage,
 )
 from mem0_manager import mem0_manager
-from auth import get_current_user, auth_service
+from auth import get_current_user, get_current_user_openai, auth_service

 # Configure structured logging
 structlog.configure(
@@ -330,6 +336,142 @@ async def chat_with_memory(
         )


+async def stream_openai_response(
+    completion_id: str, model: str, content: str, created: int
+):
+    """Generate SSE stream for OpenAI-compatible streaming by chunking the response."""
+    import uuid
+
+    # First chunk with role
+    chunk = {
+        "id": completion_id,
+        "object": "chat.completion.chunk",
+        "created": created,
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "delta": {"role": "assistant", "content": ""},
+                "finish_reason": None,
+            }
+        ],
+    }
+    yield f"data: {json.dumps(chunk)}\n\n"
+
+    # Stream content in chunks (3 words at a time for smooth effect)
+    words = content.split()
+    chunk_size = 3
+    for i in range(0, len(words), chunk_size):
+        word_chunk = " ".join(words[i : i + chunk_size])
+        if i + chunk_size < len(words):
+            word_chunk += " "
+        chunk = {
+            "id": completion_id,
+            "object": "chat.completion.chunk",
+            "created": created,
+            "model": model,
+            "choices": [
+                {"index": 0, "delta": {"content": word_chunk}, "finish_reason": None}
+            ],
+        }
+        yield f"data: {json.dumps(chunk)}\n\n"
+        await asyncio.sleep(0.05)
+
+    # Final chunk
+    chunk = {
+        "id": completion_id,
+        "object": "chat.completion.chunk",
+        "created": created,
+        "model": model,
+        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
+    }
+    yield f"data: {json.dumps(chunk)}\n\n"
+    yield "data: [DONE]\n\n"
@app.post("/v1/chat/completions")
@app.post("/chat/completions")
@limiter.limit("30/minute")
async def openai_chat_completions(
request: Request,
completion_request: OpenAIChatCompletionRequest,
authenticated_user: str = Depends(get_current_user_openai),
):
"""OpenAI-compatible chat completions endpoint with mem0 memory integration."""
try:
import uuid
user_id = authenticated_user
logger.info(
f"OpenAI chat completion for user: {user_id} (streaming={completion_request.stream})"
)
# Extract last user message
user_messages = [
m for m in completion_request.messages if m.get("role") == "user"
]
if not user_messages:
raise HTTPException(
status_code=400,
detail="No user messages provided. Include at least one message with role='user'.",
)
last_message = user_messages[-1].get("content", "")
context = (
completion_request.messages[:-1]
if len(completion_request.messages) > 1
else None
)
# Call chat_with_memory
result = await mem0_manager.chat_with_memory(
message=last_message,
user_id=user_id,
context=context,
)
completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created_time = int(time.time())
assistant_content = result.get("response", "")
if completion_request.stream:
return StreamingResponse(
stream_openai_response(
completion_id=completion_id,
model=settings.default_model,
content=assistant_content,
created=created_time,
),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
else:
return OpenAIChatCompletionResponse(
id=completion_id,
object="chat.completion",
created=created_time,
model=settings.default_model,
choices=[
OpenAIChoice(
index=0,
message=OpenAIChoiceMessage(
role="assistant", content=assistant_content
),
finish_reason="stop",
)
],
usage=OpenAIUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in OpenAI chat completions: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# Memory management endpoints - pure Mem0 passthroughs # Memory management endpoints - pure Mem0 passthroughs
@app.post("/memories") @app.post("/memories")
@limiter.limit("60/minute") # Memory operations - 60/min @limiter.limit("60/minute") # Memory operations - 60/min

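Note that the streaming path is replay rather than true incremental generation: chat_with_memory produces the complete reply first, and stream_openai_response chunks it out three words at a time. The OpenAI SDK's streaming iterator still consumes it unchanged, though usage is always reported as zeros. A sketch, with the same assumed base_url and key as above:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-pratik-example")

    stream = client.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": "Summarize my stored memories."}],
        stream=True,
    )
    for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="", flush=True)
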
File: models.py

@@ -227,3 +227,85 @@ class UserStatsResponse(BaseModel):
     avg_response_time_ms: float = Field(
         ..., description="Average response time for this user's requests"
     )
+
+
+# OpenAI-Compatible API Models
+class OpenAIMessage(BaseModel):
+    """OpenAI message format."""
+
+    role: str = Field(..., description="Message role (system, user, assistant)")
+    content: str = Field(..., description="Message content")
+
+
+class OpenAIChatCompletionRequest(BaseModel):
+    """OpenAI chat completion request format."""
+
+    model: str = Field(..., description="Model to use (will use configured default)")
+    messages: List[Dict[str, str]] = Field(..., description="List of messages")
+    temperature: Optional[float] = Field(0.7, description="Sampling temperature")
+    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
+    stream: Optional[bool] = Field(False, description="Whether to stream responses")
+    top_p: Optional[float] = Field(1.0, description="Nucleus sampling parameter")
+    n: Optional[int] = Field(1, description="Number of completions to generate")
+    stop: Optional[List[str]] = Field(None, description="Stop sequences")
+    presence_penalty: Optional[float] = Field(0, description="Presence penalty")
+    frequency_penalty: Optional[float] = Field(0, description="Frequency penalty")
+    user: Optional[str] = Field(
+        None, description="User identifier (ignored, uses API key)"
+    )
+
+
+class OpenAIUsage(BaseModel):
+    """Token usage information."""
+
+    prompt_tokens: int = Field(..., description="Tokens in the prompt")
+    completion_tokens: int = Field(..., description="Tokens in the completion")
+    total_tokens: int = Field(..., description="Total tokens used")
+
+
+class OpenAIChoiceMessage(BaseModel):
+    """Message in a choice."""
+
+    role: str = Field(..., description="Role of the message")
+    content: str = Field(..., description="Content of the message")
+
+
+class OpenAIChoice(BaseModel):
+    """Individual completion choice."""
+
+    index: int = Field(..., description="Choice index")
+    message: OpenAIChoiceMessage = Field(..., description="Message content")
+    finish_reason: str = Field(..., description="Reason for completion finish")
+
+
+class OpenAIChatCompletionResponse(BaseModel):
+    """OpenAI chat completion response format."""
+
+    id: str = Field(..., description="Unique completion ID")
+    object: str = Field(default="chat.completion", description="Object type")
+    created: int = Field(..., description="Unix timestamp of creation")
+    model: str = Field(..., description="Model used for completion")
+    choices: List[OpenAIChoice] = Field(..., description="List of completion choices")
+    usage: Optional[OpenAIUsage] = Field(None, description="Token usage information")
+
+
+# Streaming-specific models
+class OpenAIStreamDelta(BaseModel):
+    """Delta content in a streaming chunk."""
+
+    role: Optional[str] = Field(None, description="Role (only in first chunk)")
+    content: Optional[str] = Field(None, description="Incremental content")
+
+
+class OpenAIStreamChoice(BaseModel):
+    """Individual streaming choice."""
+
+    index: int = Field(..., description="Choice index")
+    delta: OpenAIStreamDelta = Field(..., description="Delta content")
+    finish_reason: Optional[str] = Field(
+        None, description="Reason for completion finish"
+    )
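
One quirk worth noting: OpenAIChatCompletionRequest types messages as List[Dict[str, str]] rather than List[OpenAIMessage], which is why the endpoint reads roles with m.get("role"). A minimal round-trip of the models in isolation (a sketch; assumes Pydantic v2's model_dump, swap for .dict() on v1):

    from models import (
        OpenAIChatCompletionRequest,
        OpenAIChatCompletionResponse,
        OpenAIChoice,
        OpenAIChoiceMessage,
        OpenAIUsage,
    )

    req = OpenAIChatCompletionRequest(
        model="gpt-4",
        messages=[{"role": "user", "content": "hello"}],
    )
    assert req.temperature == 0.7 and req.stream is False  # defaults apply

    resp = OpenAIChatCompletionResponse(
        id="chatcmpl-0123456789abcdef01234567",
        created=1736899200,
        model="gpt-4",
        choices=[
            OpenAIChoice(
                index=0,
                message=OpenAIChoiceMessage(role="assistant", content="hi"),
                finish_reason="stop",
            )
        ],
        usage=OpenAIUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
    )
    assert resp.model_dump()["object"] == "chat.completion"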

File: chat UI HTML (filename not shown)

@@ -18,12 +18,106 @@
     display: flex;
 }

+/* Login Screen */
+.login-screen {
+    display: flex;
+    align-items: center;
+    justify-content: center;
+    width: 100%;
+    height: 100vh;
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+}
+
+.login-screen.hidden {
+    display: none;
+}
+
+.login-box {
+    background: white;
+    padding: 40px;
+    border-radius: 12px;
+    box-shadow: 0 10px 40px rgba(0, 0, 0, 0.2);
+    width: 100%;
+    max-width: 400px;
+}
+
+.login-box h1 {
+    margin-bottom: 10px;
+    color: #333;
+    font-size: 28px;
+    text-align: center;
+}
+
+.login-box p {
+    color: #666;
+    font-size: 14px;
+    text-align: center;
+    margin-bottom: 30px;
+}
+
+.login-box input {
+    width: 100%;
+    padding: 14px;
+    border: 2px solid #e0e0e0;
+    border-radius: 8px;
+    font-size: 14px;
+    margin-bottom: 20px;
+    outline: none;
+    transition: border-color 0.3s;
+}
+
+.login-box input:focus {
+    border-color: #667eea;
+}
+
+.login-box button {
+    width: 100%;
+    padding: 14px;
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+    color: white;
+    border: none;
+    border-radius: 8px;
+    font-size: 16px;
+    font-weight: 600;
+    cursor: pointer;
+    transition: transform 0.2s, opacity 0.3s;
+}
+
+.login-box button:hover {
+    transform: translateY(-2px);
+}
+
+.login-box button:disabled {
+    opacity: 0.6;
+    cursor: not-allowed;
+    transform: none;
+}
+
+.login-error {
+    background: #ffe6e6;
+    border: 1px solid #ffcccc;
+    color: #cc0000;
+    padding: 12px;
+    border-radius: 8px;
+    margin-bottom: 20px;
+    font-size: 14px;
+    display: none;
+}
+
+.login-error.show {
+    display: block;
+}
+
 .container {
     display: flex;
     width: 100%;
     height: 100vh;
 }

+.container.hidden {
+    display: none;
+}
+
 /* Chat Section */
 .chat-section {
     flex: 1;
@@ -58,7 +152,12 @@
     font-size: 14px;
 }

-.clear-chat-btn {
+.header-buttons {
+    display: flex;
+    gap: 10px;
+}
+
+.clear-chat-btn, .logout-btn {
     background: #f8f9fa;
     color: #666;
     border: 1px solid #e0e0e0;
@@ -73,16 +172,27 @@
     transition: all 0.2s ease;
 }

-.clear-chat-btn:hover {
+.clear-chat-btn:hover, .logout-btn:hover {
     background: #e9ecef;
     border-color: #ced4da;
     color: #495057;
 }

-.clear-chat-btn:active {
+.clear-chat-btn:active, .logout-btn:active {
     background: #dee2e6;
 }

+.logout-btn {
+    background: #fff3cd;
+    border-color: #ffc107;
+    color: #856404;
+}
+
+.logout-btn:hover {
+    background: #ffe69c;
+    border-color: #ffb300;
+}
+
 .chat-messages {
     flex: 1;
     overflow-y: auto;
@@ -281,17 +391,42 @@
     </style>
 </head>
 <body>
-    <div class="container">
+    <!-- Login Screen -->
+    <div class="login-screen" id="loginScreen">
+        <div class="login-box">
+            <h1>🧠 Mem0 Chat</h1>
+            <p>Enter your API key to access your memory-powered assistant</p>
+            <div class="login-error" id="loginError"></div>
+            <input
+                type="password"
+                id="apiKeyInput"
+                placeholder="Enter your API key (e.g., sk-xxxxx)"
+                autocomplete="off"
+            />
+            <button id="loginButton">Connect</button>
+        </div>
+    </div>
+
+    <!-- Main Chat Interface (hidden initially) -->
+    <div class="container hidden" id="mainContainer">
         <!-- Chat Section -->
         <div class="chat-section">
            <div class="chat-header">
                 <div class="chat-header-content">
                     <h1>What can I help you with?</h1>
-                    <p>Chat with your memories - User: pratik</p>
+                    <p>Chat with your memories - User: <span id="currentUser">...</span></p>
                 </div>
-                <button class="clear-chat-btn" id="clearChatBtn" title="Clear chat history">
-                    🗑️ Clear Chat
-                </button>
+                <div class="header-buttons">
+                    <button class="clear-chat-btn" id="clearChatBtn" title="Clear chat history">
+                        🗑️ Clear Chat
+                    </button>
+                    <button class="logout-btn" id="logoutBtn" title="Logout">
+                        🚪 Logout
+                    </button>
+                </div>
             </div>

             <div class="chat-messages" id="chatMessages">
@@ -319,10 +454,20 @@
     <script>
         // Configuration
-        const API_BASE = 'http://localhost:8000';
-        const USER_ID = 'pratik';
+        const API_BASE = window.location.origin;
+
+        // State
+        let API_KEY = null;
+        let USER_ID = null;

         // DOM Elements
+        const loginScreen = document.getElementById('loginScreen');
+        const mainContainer = document.getElementById('mainContainer');
+        const apiKeyInput = document.getElementById('apiKeyInput');
+        const loginButton = document.getElementById('loginButton');
+        const loginError = document.getElementById('loginError');
+        const logoutBtn = document.getElementById('logoutBtn');
+        const currentUser = document.getElementById('currentUser');
         const chatMessages = document.getElementById('chatMessages');
         const messageInput = document.getElementById('messageInput');
         const sendButton = document.getElementById('sendButton');
@@ -336,19 +481,143 @@
         // Initialize
         document.addEventListener('DOMContentLoaded', function() {
-            loadChatHistory();
-            loadMemories();
+            // Check if already logged in
+            const savedApiKey = localStorage.getItem('apiKey');
+            const savedUserId = localStorage.getItem('userId');
+            if (savedApiKey && savedUserId) {
+                // Auto-login with saved credentials
+                API_KEY = savedApiKey;
+                USER_ID = savedUserId;
+                showMainInterface();
+            }

             // Event listeners
+            loginButton.addEventListener('click', handleLogin);
+            apiKeyInput.addEventListener('keydown', (e) => {
+                if (e.key === 'Enter') handleLogin();
+            });
+            logoutBtn.addEventListener('click', handleLogout);
             sendButton.addEventListener('click', sendMessage);
             messageInput.addEventListener('keydown', handleKeyDown);
             messageInput.addEventListener('input', autoResizeTextarea);
             refreshButton.addEventListener('click', loadMemories);
             clearChatBtn.addEventListener('click', clearChatWithConfirmation);
+        });
+
+        // Handle login
+        async function handleLogin() {
+            const apiKey = apiKeyInput.value.trim();
+            if (!apiKey) {
+                showLoginError('Please enter an API key');
+                return;
+            }
+
+            loginButton.disabled = true;
+            loginButton.textContent = 'Verifying...';
+            hideLoginError();
+
+            try {
+                // Verify API key by calling /health with auth
+                const response = await fetch(`${API_BASE}/health`, {
+                    headers: {
+                        'X-API-Key': apiKey
+                    }
+                });
+
+                if (!response.ok) {
+                    throw new Error('Invalid API key');
+                }
+
+                // Confirm the key works end-to-end with a minimal
+                // completion request against the new endpoint
+                const userResponse = await fetch(`${API_BASE}/v1/chat/completions`, {
+                    method: 'POST',
+                    headers: {
+                        'Content-Type': 'application/json',
+                        'X-API-Key': apiKey
+                    },
+                    body: JSON.stringify({
+                        model: 'gpt-4',
+                        messages: [{ role: 'user', content: 'test' }],
+                        stream: false
+                    })
+                });
+
+                if (!userResponse.ok) {
+                    throw new Error('Failed to verify user');
+                }
+
+                // API key is valid, extract user_id from auth_service mapping
+                // We'll store both and use a simple username extraction
+                API_KEY = apiKey;
+
+                // Try to extract username from API key (e.g., sk-alice -> alice)
+                if (apiKey.startsWith('sk-')) {
+                    const parts = apiKey.substring(3).split('-');
+                    USER_ID = parts[0]; // Get first part after sk-
+                } else {
+                    USER_ID = 'user';
+                }
+
+                // Save to localStorage
+                localStorage.setItem('apiKey', API_KEY);
+                localStorage.setItem('userId', USER_ID);
+
+                // Show main interface
+                showMainInterface();
+            } catch (error) {
+                console.error('Login error:', error);
+                showLoginError('Invalid API key. Please check and try again.');
+                loginButton.disabled = false;
+                loginButton.textContent = 'Connect';
+            }
+        }
+
+        // Show main interface
+        function showMainInterface() {
+            loginScreen.classList.add('hidden');
+            mainContainer.classList.remove('hidden');
+            currentUser.textContent = USER_ID;
+
+            // Load data
+            loadChatHistory();
+            loadMemories();

             // Initialize textarea height
             autoResizeTextarea();
-        });
+            messageInput.focus();
+        }
+
+        // Handle logout
+        function handleLogout() {
+            if (confirm('Are you sure you want to logout?')) {
+                // Clear credentials
+                localStorage.removeItem('apiKey');
+                localStorage.removeItem('userId');
+                API_KEY = null;
+                USER_ID = null;
+
+                // Show login screen
+                mainContainer.classList.add('hidden');
+                loginScreen.classList.remove('hidden');
+                apiKeyInput.value = '';
+                hideLoginError();
+            }
+        }
+
+        // Show login error
+        function showLoginError(message) {
+            loginError.textContent = message;
+            loginError.classList.add('show');
+        }
+
+        // Hide login error
+        function hideLoginError() {
+            loginError.classList.remove('show');
+        }

         // Load chat history from localStorage
         function loadChatHistory() {
@@ -419,6 +688,7 @@
                 method: 'POST',
                 headers: {
                     'Content-Type': 'application/json',
+                    'X-API-Key': API_KEY
                 },
                 body: JSON.stringify({
                     message: message,
@@ -459,8 +729,12 @@
         // Load memories from backend
         async function loadMemories() {
             try {
-                const response = await fetch(`${API_BASE}/memories/${USER_ID}?limit=50`);
+                const response = await fetch(`${API_BASE}/memories/${USER_ID}?limit=50`, {
+                    headers: {
+                        'X-API-Key': API_KEY
+                    }
+                });

                 if (!response.ok) {
                     throw new Error(`HTTP error! status: ${response.status}`);
@@ -512,8 +786,11 @@
             }

             try {
-                const response = await fetch(`${API_BASE}/memories/${memoryId}`, {
-                    method: 'DELETE'
+                const response = await fetch(`${API_BASE}/memories/${memoryId}?user_id=${USER_ID}`, {
+                    method: 'DELETE',
+                    headers: {
+                        'X-API-Key': API_KEY
+                    }
                 });

                 if (!response.ok) {