"""Configuration management for AI router.""" import os from typing import Optional, Dict, Any, List from dataclasses import dataclass, field @dataclass class RouterConfig: """Configuration for AI router.""" # API Keys openai_api_key: Optional[str] = field(default_factory=lambda: os.getenv("OPENAI_API_KEY")) gemini_api_key: Optional[str] = field(default_factory=lambda: os.getenv("GEMINI_API_KEY")) google_api_key: Optional[str] = field(default_factory=lambda: os.getenv("GOOGLE_API_KEY")) cohere_api_key: Optional[str] = field(default_factory=lambda: os.getenv("COHERE_API_KEY")) # Ollama configuration ollama_base_url: str = field(default_factory=lambda: os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")) # Available Ollama embedding models ollama_embedding_models: List[str] = field(default_factory=lambda: [ "mxbai-embed-large:latest", "nomic-embed-text:latest", ]) # Default model settings default_temperature: float = 0.7 default_max_tokens: Optional[int] = None default_top_p: float = 1.0 default_timeout: int = 160 # seconds # Retry configuration max_retries: int = 3 retry_delay: float = 1.0 retry_backoff: float = 2.0 # Cost tracking track_costs: bool = False # Cost per million tokens (prices in USD) cost_per_1m_input_tokens: Dict[str, float] = field(default_factory=lambda: { # OpenAI Compatible Models "gpt-4o": 2.5, "azure/gpt-4.1-new": 2.0, "o3": 2.0, "o4-mini": 1.1, "claude-opus-4": 15.0, "claude-sonnet-4": 3.0, # Gemini Models "gemini-2.5-pro": 1.25, "gemini-2.5-flash": 0.3, "gemini-2.0-flash-001": 0.3, # Alias for current default # Cohere Models "command-a": 2.5, "command-r-plus": 2.5, }) cost_per_1m_output_tokens: Dict[str, float] = field(default_factory=lambda: { # OpenAI Compatible Models "gpt-4o": 10.0, "azure/gpt-4.1-new": 8.0, "o3": 8.0, "o4-mini": 4.4, "claude-opus-4": 75.0, "claude-sonnet-4": 15.0, # Gemini Models "gemini-2.5-pro": 10.0, "gemini-2.5-flash": 2.5, "gemini-2.0-flash-001": 2.5, # Alias for current default # Cohere Models "command-a": 10.0, "command-r-plus": 10.0, }) # Cached input pricing for specific models (per million tokens) cost_per_1m_cached_input_tokens: Dict[str, float] = field(default_factory=lambda: { "azure/gpt-4.1-new": 0.5, "o3": 0.5, "o4-mini": 0.28, }) # Reranking costs (per 1k searches) cost_per_1k_rerank_searches: Dict[str, float] = field(default_factory=lambda: { "rerank-3.5": 2.0, "rerank-english-v3.0": 2.0, "rerank-multilingual-v3.0": 2.0, }) # Embedding costs (per million tokens) cost_per_1m_embed_tokens: Dict[str, float] = field(default_factory=lambda: { "embed-english-v3.0": 0.12, "embed-multilingual-v3.0": 0.12, "embed-english-light-v3.0": 0.12, "text-embedding-004": 0.12, # Google's embedding model # Ollama models are free (local) "mxbai-embed-large:latest": 0.0, "nomic-embed-text:latest": 0.0, "nomic-embed-text:137m-v1.5-fp16": 0.0, }) # Image embedding costs (per million image tokens) cost_per_1m_embed_image_tokens: Dict[str, float] = field(default_factory=lambda: { "embed-english-v3.0": 0.47, "embed-multilingual-v3.0": 0.47, }) # Logging log_requests: bool = False log_responses: bool = False log_errors: bool = True def get_api_key(self, provider: str) -> Optional[str]: """Get API key for a specific provider. Args: provider: Provider name (openai, gemini, cohere, etc.) Returns: API key if available """ provider = provider.lower() if provider == "openai": return self.openai_api_key elif provider in ["gemini", "google"]: return self.gemini_api_key or self.google_api_key elif provider == "cohere": return self.cohere_api_key return None def calculate_cost( self, model: str, input_tokens: int, output_tokens: int, cached_input: bool = False ) -> Optional[float]: """Calculate cost for a request. Args: model: Model identifier input_tokens: Number of input tokens output_tokens: Number of output tokens cached_input: Whether input tokens are cached (for compatible models) Returns: Total cost in USD, or None if cost data not available """ if not self.track_costs: return None # Check if using cached pricing if cached_input and model in self.cost_per_1m_cached_input_tokens: input_cost_per_1m = self.cost_per_1m_cached_input_tokens.get(model) else: input_cost_per_1m = self.cost_per_1m_input_tokens.get(model) output_cost_per_1m = self.cost_per_1m_output_tokens.get(model) if input_cost_per_1m is None or output_cost_per_1m is None: return None input_cost = (input_tokens / 1_000_000) * input_cost_per_1m output_cost = (output_tokens / 1_000_000) * output_cost_per_1m return round(input_cost + output_cost, 6) def calculate_rerank_cost( self, model: str, num_searches: int ) -> Optional[float]: """Calculate cost for reranking. Args: model: Rerank model identifier num_searches: Number of searches performed Returns: Total cost in USD, or None if cost data not available """ if not self.track_costs: return None cost_per_1k = self.cost_per_1k_rerank_searches.get(model) if cost_per_1k is None: return None return round((num_searches / 1000) * cost_per_1k, 6) def calculate_embed_cost( self, model: str, num_tokens: int, is_image: bool = False ) -> Optional[float]: """Calculate cost for embeddings. Args: model: Embedding model identifier num_tokens: Number of tokens to embed is_image: Whether these are image tokens Returns: Total cost in USD, or None if cost data not available """ if not self.track_costs: return None if is_image: cost_per_1m = self.cost_per_1m_embed_image_tokens.get(model) else: cost_per_1m = self.cost_per_1m_embed_tokens.get(model) if cost_per_1m is None: return None return round((num_tokens / 1_000_000) * cost_per_1m, 6) @classmethod def from_env(cls) -> "RouterConfig": """Create config from environment variables. Returns: RouterConfig instance """ return cls() def to_dict(self) -> Dict[str, Any]: """Convert config to dictionary. Returns: Dictionary representation """ return { "openai_api_key": "***" if self.openai_api_key else None, "gemini_api_key": "***" if self.gemini_api_key else None, "google_api_key": "***" if self.google_api_key else None, "cohere_api_key": "***" if self.cohere_api_key else None, "default_temperature": self.default_temperature, "default_max_tokens": self.default_max_tokens, "default_top_p": self.default_top_p, "default_timeout": self.default_timeout, "max_retries": self.max_retries, "retry_delay": self.retry_delay, "retry_backoff": self.retry_backoff, "track_costs": self.track_costs, "log_requests": self.log_requests, "log_responses": self.log_responses, "log_errors": self.log_errors, } # Global config instance config = RouterConfig.from_env()