"""Custom exceptions for AI router."""

import builtins
from typing import Optional, Any


class RouterError(Exception):
    """Base exception for all router errors.

    Attributes:
        provider: Name of the provider that raised the error, or None.
        original_error: Underlying exception from the provider, or None.
        details: Any extra keyword arguments given to the constructor.
    """

    def __init__(
        self,
        message: str,
        provider: Optional[str] = None,
        original_error: Optional[Exception] = None,
        **kwargs: Any
    ) -> None:
        """Initialize router error.

        Args:
            message: Error message
            provider: Provider that raised the error
            original_error: Original exception from provider
            **kwargs: Additional error details (stored in ``self.details``)
        """
        super().__init__(message)
        self.provider = provider
        self.original_error = original_error
        self.details = kwargs

    def __str__(self) -> str:
        """Return the message, prefixed with ``[provider]`` when a provider
        is set and suffixed with ``(Original: ...)`` when an original error
        is attached."""
        base_msg = super().__str__()
        if self.provider:
            base_msg = f"[{self.provider}] {base_msg}"
        if self.original_error:
            base_msg += f" (Original: {self.original_error})"
        return base_msg


class AuthenticationError(RouterError):
    """Raised when authentication fails."""
    pass


class RateLimitError(RouterError):
    """Raised when rate limit is exceeded.

    Attributes:
        retry_after: Seconds to wait before retrying, or None if unknown.
    """

    def __init__(
        self,
        message: str,
        retry_after: Optional[float] = None,
        **kwargs: Any
    ) -> None:
        """Initialize rate limit error.

        Args:
            message: Error message
            retry_after: Seconds to wait before retry
            **kwargs: Additional error details, forwarded to RouterError
                (e.g. ``provider``, ``original_error``)
        """
        super().__init__(message, **kwargs)
        self.retry_after = retry_after


class ModelNotFoundError(RouterError):
    """Raised when requested model is not available.

    Attributes:
        model: The requested model name ("" when unknown).
        available_models: Models known to be available (empty list if none
            were supplied).
    """

    def __init__(
        self,
        message: str,
        model: str = "",
        available_models: Optional[list] = None,
        **kwargs: Any
    ) -> None:
        """Initialize model not found error.

        Args:
            message: Error message
            model: Requested model. Defaults to "" so the exception can be
                rebuilt by ``copy``/``pickle``: BaseException reconstructs
                instances as ``cls(*self.args)`` with ``args == (message,)``
                and then restores attributes (including ``model``) from the
                instance ``__dict__``.
            available_models: List of available models
            **kwargs: Additional error details, forwarded to RouterError
        """
        super().__init__(message, **kwargs)
        self.model = model
        # A fresh list per instance; never share a mutable default.
        self.available_models = available_models or []


class InvalidRequestError(RouterError):
    """Raised when request parameters are invalid."""
    pass


# NOTE: this name shadows the built-in ``TimeoutError`` inside this module
# (and in any module that star-imports it). Use ``builtins.TimeoutError``
# when the built-in is meant.
class TimeoutError(RouterError):
    """Raised when request times out."""
    pass


class ContentFilterError(RouterError):
    """Raised when content is blocked by safety filters."""
    pass


class QuotaExceededError(RouterError):
    """Raised when API quota is exceeded."""
    pass


class ProviderError(RouterError):
    """Raised for provider-specific errors."""
    pass


class ConfigurationError(RouterError):
    """Raised for configuration errors."""
    pass


def map_provider_error(provider: str, error: Exception) -> RouterError:
    """Map provider-specific errors to router errors.

    Classification is a case-insensitive substring match on ``str(error)``.
    Categories are tried in a fixed order and the first match wins:
    authentication, rate limit, model not found, timeout, content filter.
    Anything unmatched becomes a ProviderError. The returned error always
    carries ``provider`` and keeps ``error`` as ``original_error``.

    Args:
        provider: Provider name
        error: Original exception

    Returns:
        Appropriate RouterError subclass (returned, not raised)
    """
    error_message = str(error)
    error_type = type(error).__name__
    # Lower-case once; every pattern below is matched case-insensitively.
    lowered = error_message.lower()

    # Common patterns across providers
    if any(x in lowered for x in ("unauthorized", "invalid api key", "authentication")):
        return AuthenticationError(
            f"Authentication failed: {error_message}",
            provider=provider,
            original_error=error,
        )

    # NOTE: "quota exceeded" is reported as RateLimitError, not
    # QuotaExceededError; kept this way so callers catching RateLimitError
    # keep working.
    if any(x in lowered for x in ("rate limit", "too many requests", "quota exceeded")):
        return RateLimitError(
            f"Rate limit exceeded: {error_message}",
            provider=provider,
            original_error=error,
        )

    if any(x in lowered for x in ("model not found", "invalid model", "unknown model")):
        return ModelNotFoundError(
            f"Model not found: {error_message}",
            model="",  # model name is not extracted from the message
            provider=provider,
            original_error=error,
        )

    # ``TimeoutError`` here is the router class defined above. The built-in
    # one (e.g. ``asyncio.TimeoutError`` on Python 3.11+) is frequently
    # raised with an empty message, so also match it by type.
    if isinstance(error, builtins.TimeoutError) or any(
        x in lowered for x in ("timeout", "timed out")
    ):
        return TimeoutError(
            f"Request timed out: {error_message}",
            provider=provider,
            original_error=error,
        )

    if any(x in lowered for x in ("content filter", "safety", "blocked")):
        return ContentFilterError(
            f"Content blocked by safety filters: {error_message}",
            provider=provider,
            original_error=error,
        )

    # Default to provider error
    return ProviderError(
        f"{error_type}: {error_message}",
        provider=provider,
        original_error=error,
    )