"""Configuration management for AI router."""
|
|
|
|
import os
|
|
from typing import Optional, Dict, Any, List
|
|
from dataclasses import dataclass, field
|
|
|
|
|
|
@dataclass
class RouterConfig:
    """Configuration for the AI router.

    Holds provider API keys (read from environment variables by default),
    Ollama connection settings, default generation parameters, retry policy,
    and per-model pricing tables used for optional cost tracking.
    """

    # API keys -- resolved from the environment at instance creation time.
    openai_api_key: Optional[str] = field(default_factory=lambda: os.getenv("OPENAI_API_KEY"))
    gemini_api_key: Optional[str] = field(default_factory=lambda: os.getenv("GEMINI_API_KEY"))
    google_api_key: Optional[str] = field(default_factory=lambda: os.getenv("GOOGLE_API_KEY"))
    cohere_api_key: Optional[str] = field(default_factory=lambda: os.getenv("COHERE_API_KEY"))

    # Ollama configuration (local server, no API key required).
    ollama_base_url: str = field(default_factory=lambda: os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"))

    # Available Ollama embedding models.
    ollama_embedding_models: List[str] = field(default_factory=lambda: [
        "mxbai-embed-large:latest",
        "nomic-embed-text:latest",
    ])

    # Default model settings.
    default_temperature: float = 0.7
    default_max_tokens: Optional[int] = None
    default_top_p: float = 1.0
    default_timeout: int = 160  # seconds

    # Retry configuration.
    max_retries: int = 3
    retry_delay: float = 1.0    # initial delay in seconds
    retry_backoff: float = 2.0  # multiplier applied between retries

    # Cost tracking toggle; when False, every calculate_* method returns None.
    track_costs: bool = False

    # Cost per million input tokens (prices in USD).
    cost_per_1m_input_tokens: Dict[str, float] = field(default_factory=lambda: {
        # OpenAI Compatible Models
        "gpt-4o": 2.5,
        "azure/gpt-4.1-new": 2.0,
        "o3": 2.0,
        "o4-mini": 1.1,
        "claude-opus-4": 15.0,
        "claude-sonnet-4": 3.0,

        # Gemini Models
        "gemini-2.5-pro": 1.25,
        "gemini-2.5-flash": 0.3,
        "gemini-2.0-flash-001": 0.3,  # Alias for current default

        # Cohere Models
        "command-a": 2.5,
        "command-r-plus": 2.5,
    })

    # Cost per million output tokens (prices in USD).
    cost_per_1m_output_tokens: Dict[str, float] = field(default_factory=lambda: {
        # OpenAI Compatible Models
        "gpt-4o": 10.0,
        "azure/gpt-4.1-new": 8.0,
        "o3": 8.0,
        "o4-mini": 4.4,
        "claude-opus-4": 75.0,
        "claude-sonnet-4": 15.0,

        # Gemini Models
        "gemini-2.5-pro": 10.0,
        "gemini-2.5-flash": 2.5,
        "gemini-2.0-flash-001": 2.5,  # Alias for current default

        # Cohere Models
        "command-a": 10.0,
        "command-r-plus": 10.0,
    })

    # Cached input pricing for specific models (per million tokens).
    # Only models listed here get a discount when cached_input=True.
    cost_per_1m_cached_input_tokens: Dict[str, float] = field(default_factory=lambda: {
        "azure/gpt-4.1-new": 0.5,
        "o3": 0.5,
        "o4-mini": 0.28,
    })

    # Reranking costs (per 1k searches).
    cost_per_1k_rerank_searches: Dict[str, float] = field(default_factory=lambda: {
        "rerank-3.5": 2.0,
        "rerank-english-v3.0": 2.0,
        "rerank-multilingual-v3.0": 2.0,
    })

    # Embedding costs (per million tokens).
    cost_per_1m_embed_tokens: Dict[str, float] = field(default_factory=lambda: {
        "embed-english-v3.0": 0.12,
        "embed-multilingual-v3.0": 0.12,
        "embed-english-light-v3.0": 0.12,
        "text-embedding-004": 0.12,  # Google's embedding model
        # Ollama models are free (local)
        "mxbai-embed-large:latest": 0.0,
        "nomic-embed-text:latest": 0.0,
        "nomic-embed-text:137m-v1.5-fp16": 0.0,
    })

    # Image embedding costs (per million image tokens).
    cost_per_1m_embed_image_tokens: Dict[str, float] = field(default_factory=lambda: {
        "embed-english-v3.0": 0.47,
        "embed-multilingual-v3.0": 0.47,
    })

    # Logging flags.
    log_requests: bool = False
    log_responses: bool = False
    log_errors: bool = True

    def get_api_key(self, provider: str) -> Optional[str]:
        """Get API key for a specific provider.

        Args:
            provider: Provider name, case-insensitive
                (openai, gemini, google, cohere).

        Returns:
            API key if available, otherwise None (also for unknown providers).
        """
        provider = provider.lower()
        if provider == "openai":
            return self.openai_api_key
        if provider in ("gemini", "google"):
            # Either env var may hold the key; GEMINI_API_KEY takes precedence.
            return self.gemini_api_key or self.google_api_key
        if provider == "cohere":
            return self.cohere_api_key
        return None

    def calculate_cost(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int,
        cached_input: bool = False
    ) -> Optional[float]:
        """Calculate cost for a request.

        Args:
            model: Model identifier.
            input_tokens: Number of input tokens.
            output_tokens: Number of output tokens.
            cached_input: Whether input tokens are cached (discounted pricing
                applies only to models listed in cost_per_1m_cached_input_tokens;
                other models fall back to regular input pricing).

        Returns:
            Total cost in USD rounded to 6 decimal places, or None if cost
            tracking is disabled or pricing data is unavailable for the model.
        """
        if not self.track_costs:
            return None

        # Prefer the discounted cached-input rate when requested and available;
        # otherwise fall back to the regular input rate (single lookup each).
        input_cost_per_1m = None
        if cached_input:
            input_cost_per_1m = self.cost_per_1m_cached_input_tokens.get(model)
        if input_cost_per_1m is None:
            input_cost_per_1m = self.cost_per_1m_input_tokens.get(model)

        output_cost_per_1m = self.cost_per_1m_output_tokens.get(model)

        if input_cost_per_1m is None or output_cost_per_1m is None:
            return None

        input_cost = (input_tokens / 1_000_000) * input_cost_per_1m
        output_cost = (output_tokens / 1_000_000) * output_cost_per_1m

        return round(input_cost + output_cost, 6)

    def calculate_rerank_cost(
        self,
        model: str,
        num_searches: int
    ) -> Optional[float]:
        """Calculate cost for reranking.

        Args:
            model: Rerank model identifier.
            num_searches: Number of searches performed.

        Returns:
            Total cost in USD rounded to 6 decimal places, or None if cost
            tracking is disabled or pricing data is unavailable for the model.
        """
        if not self.track_costs:
            return None

        cost_per_1k = self.cost_per_1k_rerank_searches.get(model)
        if cost_per_1k is None:
            return None

        return round((num_searches / 1000) * cost_per_1k, 6)

    def calculate_embed_cost(
        self,
        model: str,
        num_tokens: int,
        is_image: bool = False
    ) -> Optional[float]:
        """Calculate cost for embeddings.

        Args:
            model: Embedding model identifier.
            num_tokens: Number of tokens to embed.
            is_image: Whether these are image tokens (uses the image pricing
                table instead of the text one).

        Returns:
            Total cost in USD rounded to 6 decimal places, or None if cost
            tracking is disabled or pricing data is unavailable for the model.
        """
        if not self.track_costs:
            return None

        if is_image:
            cost_per_1m = self.cost_per_1m_embed_image_tokens.get(model)
        else:
            cost_per_1m = self.cost_per_1m_embed_tokens.get(model)

        if cost_per_1m is None:
            return None

        return round((num_tokens / 1_000_000) * cost_per_1m, 6)

    @classmethod
    def from_env(cls) -> "RouterConfig":
        """Create config from environment variables.

        All field default factories read the environment, so a plain
        instantiation picks up the current env vars.

        Returns:
            RouterConfig instance.
        """
        return cls()

    @staticmethod
    def _mask(secret: Optional[str]) -> Optional[str]:
        """Return '***' for a set secret, None for an unset one."""
        return "***" if secret else None

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to a dictionary safe for logging/display.

        API keys are masked; pricing tables are intentionally omitted.
        Fix: now also includes the Ollama settings, which were previously
        missing from the serialized view.

        Returns:
            Dictionary representation.
        """
        return {
            "openai_api_key": self._mask(self.openai_api_key),
            "gemini_api_key": self._mask(self.gemini_api_key),
            "google_api_key": self._mask(self.google_api_key),
            "cohere_api_key": self._mask(self.cohere_api_key),
            "ollama_base_url": self.ollama_base_url,
            "ollama_embedding_models": list(self.ollama_embedding_models),
            "default_temperature": self.default_temperature,
            "default_max_tokens": self.default_max_tokens,
            "default_top_p": self.default_top_p,
            "default_timeout": self.default_timeout,
            "max_retries": self.max_retries,
            "retry_delay": self.retry_delay,
            "retry_backoff": self.retry_backoff,
            "track_costs": self.track_costs,
            "log_requests": self.log_requests,
            "log_responses": self.log_responses,
            "log_errors": self.log_errors,
        }
|
|
|
|
|
|
# Global config instance, created from the environment at import time.
# Modules that import `config` share this single instance.
config = RouterConfig.from_env()