template-with-ai/router/config.py
2025-07-01 16:22:20 +05:30

259 lines
No EOL
8.3 KiB
Python

"""Configuration management for AI router."""
import os
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
@dataclass
class RouterConfig:
    """Configuration for AI router.

    Bundles provider credentials (pulled from environment variables at
    construction time), default generation parameters, retry policy,
    per-model pricing tables, and logging switches. All cost helpers
    return ``None`` when cost tracking is disabled or when no pricing
    data exists for the requested model.
    """
    # --- Provider credentials (read from the environment per instance) ---
    openai_api_key: Optional[str] = field(default_factory=lambda: os.getenv("OPENAI_API_KEY"))
    gemini_api_key: Optional[str] = field(default_factory=lambda: os.getenv("GEMINI_API_KEY"))
    google_api_key: Optional[str] = field(default_factory=lambda: os.getenv("GOOGLE_API_KEY"))
    cohere_api_key: Optional[str] = field(default_factory=lambda: os.getenv("COHERE_API_KEY"))
    # --- Ollama (local) server location ---
    ollama_base_url: str = field(default_factory=lambda: os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"))
    # Embedding models known to be served by the local Ollama instance.
    ollama_embedding_models: List[str] = field(default_factory=lambda: [
        "mxbai-embed-large:latest",
        "nomic-embed-text:latest",
    ])
    # --- Default sampling / request parameters ---
    default_temperature: float = 0.7
    default_max_tokens: Optional[int] = None
    default_top_p: float = 1.0
    default_timeout: int = 160  # request timeout, seconds
    # --- Retry policy (exponential backoff) ---
    max_retries: int = 3
    retry_delay: float = 1.0
    retry_backoff: float = 2.0
    # --- Cost tracking toggle; when False, all cost helpers return None ---
    track_costs: bool = False
    # USD price per million input tokens, keyed by model identifier.
    cost_per_1m_input_tokens: Dict[str, float] = field(default_factory=lambda: {
        # OpenAI Compatible Models
        "gpt-4o": 2.5,
        "azure/gpt-4.1-new": 2.0,
        "o3": 2.0,
        "o4-mini": 1.1,
        "claude-opus-4": 15.0,
        "claude-sonnet-4": 3.0,
        # Gemini Models
        "gemini-2.5-pro": 1.25,
        "gemini-2.5-flash": 0.3,
        "gemini-2.0-flash-001": 0.3,  # alias for the current default
        # Cohere Models
        "command-a": 2.5,
        "command-r-plus": 2.5,
    })
    # USD price per million output tokens, keyed by model identifier.
    cost_per_1m_output_tokens: Dict[str, float] = field(default_factory=lambda: {
        # OpenAI Compatible Models
        "gpt-4o": 10.0,
        "azure/gpt-4.1-new": 8.0,
        "o3": 8.0,
        "o4-mini": 4.4,
        "claude-opus-4": 75.0,
        "claude-sonnet-4": 15.0,
        # Gemini Models
        "gemini-2.5-pro": 10.0,
        "gemini-2.5-flash": 2.5,
        "gemini-2.0-flash-001": 2.5,  # alias for the current default
        # Cohere Models
        "command-a": 10.0,
        "command-r-plus": 10.0,
    })
    # Discounted USD price per million *cached* input tokens (only for
    # models that support prompt caching).
    cost_per_1m_cached_input_tokens: Dict[str, float] = field(default_factory=lambda: {
        "azure/gpt-4.1-new": 0.5,
        "o3": 0.5,
        "o4-mini": 0.28,
    })
    # USD price per thousand rerank searches.
    cost_per_1k_rerank_searches: Dict[str, float] = field(default_factory=lambda: {
        "rerank-3.5": 2.0,
        "rerank-english-v3.0": 2.0,
        "rerank-multilingual-v3.0": 2.0,
    })
    # USD price per million embedding (text) tokens.
    cost_per_1m_embed_tokens: Dict[str, float] = field(default_factory=lambda: {
        "embed-english-v3.0": 0.12,
        "embed-multilingual-v3.0": 0.12,
        "embed-english-light-v3.0": 0.12,
        "text-embedding-004": 0.12,  # Google's embedding model
        # Local Ollama models incur no API cost.
        "mxbai-embed-large:latest": 0.0,
        "nomic-embed-text:latest": 0.0,
        "nomic-embed-text:137m-v1.5-fp16": 0.0,
    })
    # USD price per million *image* embedding tokens.
    cost_per_1m_embed_image_tokens: Dict[str, float] = field(default_factory=lambda: {
        "embed-english-v3.0": 0.47,
        "embed-multilingual-v3.0": 0.47,
    })
    # --- Logging switches ---
    log_requests: bool = False
    log_responses: bool = False
    log_errors: bool = True

    def get_api_key(self, provider: str) -> Optional[str]:
        """Look up the configured API key for *provider*.

        Args:
            provider: Provider name, case-insensitive ("openai",
                "gemini", "google", "cohere").

        Returns:
            The matching API key, or None when the provider is unknown
            or no key is configured.
        """
        name = provider.lower()
        if name == "openai":
            return self.openai_api_key
        if name in ["gemini", "google"]:
            # GEMINI_API_KEY takes precedence over GOOGLE_API_KEY.
            return self.gemini_api_key or self.google_api_key
        if name == "cohere":
            return self.cohere_api_key
        return None

    def calculate_cost(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int,
        cached_input: bool = False
    ) -> Optional[float]:
        """Compute the USD cost of a completion request.

        Args:
            model: Model identifier.
            input_tokens: Number of input tokens consumed.
            output_tokens: Number of output tokens produced.
            cached_input: If True, bill input tokens at the cached rate
                when the model has one; otherwise the standard rate applies.

        Returns:
            Total cost in USD rounded to 6 decimals, or None when cost
            tracking is disabled or pricing is unknown for *model*.
        """
        if not self.track_costs:
            return None
        # Pick the discounted cached-input rate only when requested AND
        # the model actually has a cached price; fall back otherwise.
        if cached_input and model in self.cost_per_1m_cached_input_tokens:
            input_rate = self.cost_per_1m_cached_input_tokens.get(model)
        else:
            input_rate = self.cost_per_1m_input_tokens.get(model)
        output_rate = self.cost_per_1m_output_tokens.get(model)
        if input_rate is None or output_rate is None:
            return None
        cost_in = (input_tokens / 1_000_000) * input_rate
        cost_out = (output_tokens / 1_000_000) * output_rate
        return round(cost_in + cost_out, 6)

    def calculate_rerank_cost(
        self,
        model: str,
        num_searches: int
    ) -> Optional[float]:
        """Compute the USD cost of rerank searches.

        Args:
            model: Rerank model identifier.
            num_searches: Number of searches performed.

        Returns:
            Total cost in USD rounded to 6 decimals, or None when cost
            tracking is disabled or pricing is unknown for *model*.
        """
        if not self.track_costs:
            return None
        per_1k = self.cost_per_1k_rerank_searches.get(model)
        if per_1k is None:
            return None
        return round((num_searches / 1000) * per_1k, 6)

    def calculate_embed_cost(
        self,
        model: str,
        num_tokens: int,
        is_image: bool = False
    ) -> Optional[float]:
        """Compute the USD cost of an embedding request.

        Args:
            model: Embedding model identifier.
            num_tokens: Number of tokens embedded.
            is_image: Whether the tokens are image tokens (billed from a
                separate price table).

        Returns:
            Total cost in USD rounded to 6 decimals, or None when cost
            tracking is disabled or pricing is unknown for *model*.
        """
        if not self.track_costs:
            return None
        # Image and text embeddings are priced from different tables.
        table = (
            self.cost_per_1m_embed_image_tokens
            if is_image
            else self.cost_per_1m_embed_tokens
        )
        rate = table.get(model)
        if rate is None:
            return None
        return round((num_tokens / 1_000_000) * rate, 6)

    @classmethod
    def from_env(cls) -> "RouterConfig":
        """Build a config whose credentials come from environment variables.

        Returns:
            A new RouterConfig instance (field factories read the
            environment at construction time).
        """
        return cls()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the config to a plain dictionary.

        API keys are redacted to "***" (or None when absent), so the
        result is safe to log or display.

        Returns:
            Dictionary representation with secrets masked.
        """
        def _mask(secret: Optional[str]) -> Optional[str]:
            # Never expose raw credentials in serialized output.
            return "***" if secret else None

        return {
            "openai_api_key": _mask(self.openai_api_key),
            "gemini_api_key": _mask(self.gemini_api_key),
            "google_api_key": _mask(self.google_api_key),
            "cohere_api_key": _mask(self.cohere_api_key),
            "default_temperature": self.default_temperature,
            "default_max_tokens": self.default_max_tokens,
            "default_top_p": self.default_top_p,
            "default_timeout": self.default_timeout,
            "max_retries": self.max_retries,
            "retry_delay": self.retry_delay,
            "retry_backoff": self.retry_backoff,
            "track_costs": self.track_costs,
            "log_requests": self.log_requests,
            "log_responses": self.log_responses,
            "log_errors": self.log_errors,
        }
# Module-level singleton config, built once at import time; field factories
# read provider API keys and the Ollama URL from environment variables.
config: RouterConfig = RouterConfig.from_env()