"""Base abstract class for AI router implementations."""

import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union


@dataclass
class RouterResponse:
    """Standardized response object for all AI providers.

    Concrete routers build one of these in ``AIRouter._parse_response``.
    ``str(response)`` returns only the generated text (``content``).
    """

    # Generated text returned by the provider.
    content: str
    # Model identifier that produced the response.
    model: str
    # Provider name (``AIRouter.provider`` is the router's lowercased class name).
    provider: str
    # Duration of the provider request in seconds, as measured by
    # ``AIRouter.call`` / ``AIRouter.acall``.
    latency: float = 0.0
    # Token accounting; left as ``None`` when the router does not fill it in.
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    total_tokens: Optional[int] = None
    # Request cost, if the concrete router computes it.
    # NOTE(review): units are defined by the subclass (presumably USD) -- confirm.
    cost: Optional[float] = None
    # Provider-specific reason generation stopped, if reported.
    finish_reason: Optional[str] = None
    # Free-form extra data; default_factory gives every instance its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Unparsed provider response, for callers needing provider-specific fields.
    raw_response: Optional[Any] = None

    def __str__(self) -> str:
        return self.content


class AIRouter(ABC):
    """Abstract base class for AI provider routers.

    Subclasses implement four hooks:

    * ``_prepare_request`` -- turn a prompt plus merged options into
      provider-specific request parameters.
    * ``_make_request`` / ``_make_async_request`` -- send those parameters
      and return the raw provider response.
    * ``_parse_response`` -- convert the raw response (plus the measured
      latency) into a :class:`RouterResponse`.

    :meth:`call` and :meth:`acall` wire these hooks together and time the
    request. Streaming is optional: the default :meth:`stream` and
    :meth:`astream` raise ``NotImplementedError``.
    """

    def __init__(
        self, model: str, api_key: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Initialize the router with model and configuration.

        Args:
            model: Model identifier for the provider.
            api_key: API key for authentication. Stored as given (possibly
                ``None``); this base class does not read the environment, so
                any environment-variable fallback is up to the subclass.
            **kwargs: Default provider-specific options. They are merged
                under the per-call keyword arguments of :meth:`call` and
                :meth:`acall`.
        """
        self.model = model
        self.api_key = api_key
        # Default per-call options; per-call kwargs take precedence over these.
        self.config = kwargs
        # e.g. class ``OpenAIRouter`` -> provider ``"openairouter"``.
        self.provider = type(self).__name__.lower()

    @abstractmethod
    def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        """Prepare provider-specific request parameters.

        Args:
            prompt: User prompt.
            **kwargs: Merged options (instance defaults overridden by
                per-call arguments).

        Returns:
            Dictionary of request parameters passed to ``_make_request`` /
            ``_make_async_request``.
        """
        pass

    @abstractmethod
    def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
        """Parse provider response into RouterResponse.

        Args:
            raw_response: Raw response from provider.
            latency: Request latency in seconds.

        Returns:
            Standardized RouterResponse.
        """
        pass

    @abstractmethod
    def _make_request(self, request_params: Dict[str, Any]) -> Any:
        """Make synchronous request to provider.

        Args:
            request_params: Provider-specific request parameters.

        Returns:
            Raw provider response.
        """
        pass

    @abstractmethod
    async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
        """Make asynchronous request to provider.

        Args:
            request_params: Provider-specific request parameters.

        Returns:
            Raw provider response.
        """
        pass

    def call(self, prompt: str, **kwargs: Any) -> RouterResponse:
        """Make a synchronous call to the AI provider.

        Args:
            prompt: User prompt.
            **kwargs: Per-call options (model, temperature, max_tokens, ...)
                that override the defaults given to ``__init__``.

        Returns:
            RouterResponse with the result. Its latency covers only the
            ``_make_request`` call, not request preparation or parsing.
        """
        # Per-call options override the instance defaults; self.config is
        # not mutated because a new dict is built.
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)

        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump (NTP/clock changes), which would give
        # wrong or even negative latencies.
        start = time.perf_counter()
        raw_response = self._make_request(request_params)
        latency = time.perf_counter() - start

        return self._parse_response(raw_response, latency)

    async def acall(self, prompt: str, **kwargs: Any) -> RouterResponse:
        """Make an asynchronous call to the AI provider.

        Args:
            prompt: User prompt.
            **kwargs: Per-call options (model, temperature, max_tokens, ...)
                that override the defaults given to ``__init__``.

        Returns:
            RouterResponse with the result. Its latency covers the awaited
            ``_make_async_request`` call, which can include time the event
            loop spent on other tasks.
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)

        # Monotonic clock for the same reason as in call().
        start = time.perf_counter()
        raw_response = await self._make_async_request(request_params)
        latency = time.perf_counter() - start

        return self._parse_response(raw_response, latency)

    def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
        """Stream responses from the AI provider.

        Args:
            prompt: User prompt.
            **kwargs: Additional parameters.

        Yields:
            RouterResponse chunks (in subclasses that support streaming).

        Raises:
            NotImplementedError: If streaming is not supported. This default
                is a plain function (no ``yield``), so it raises as soon as
                it is called.
        """
        raise NotImplementedError(f"{self.provider} does not support streaming yet")

    async def astream(
        self, prompt: str, **kwargs: Any
    ) -> AsyncGenerator[RouterResponse, None]:
        """Asynchronously stream responses from the AI provider.

        Args:
            prompt: User prompt.
            **kwargs: Additional parameters.

        Yields:
            RouterResponse chunks (in subclasses that support streaming).

        Raises:
            NotImplementedError: If async streaming is not supported. Because
                of the trailing ``yield`` this is an async generator, so the
                error surfaces when iteration starts, not when it is called.
        """
        raise NotImplementedError(f"{self.provider} does not support async streaming yet")
        yield  # Unreachable; makes this an async generator to match the type hint.