# template-with-ai/router/base.py
# Last modified: 2025-07-01 16:22:20 +05:30 (178 lines, 5.5 KiB, no trailing newline)
"""Base abstract class for AI router implementations."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union
import time
@dataclass
class RouterResponse:
    """Standardized response object for all AI providers.

    Wraps a provider-specific payload in one uniform shape so callers can
    handle every backend's answer identically.
    """

    # Required: the generated text plus which model/provider produced it.
    content: str
    model: str
    provider: str
    # Timing and usage accounting. Token/cost fields stay None when the
    # provider does not report them.
    latency: float = 0.0
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    total_tokens: Optional[int] = None
    cost: Optional[float] = None
    finish_reason: Optional[str] = None
    # Free-form provider extras; default_factory gives every instance its
    # own dict (a shared mutable default would leak state across responses).
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Untouched provider response, retained for debugging/advanced use.
    raw_response: Optional[Any] = None

    def __str__(self) -> str:
        # str(response) yields just the generated text.
        return self.content
class AIRouter(ABC):
    """Abstract base class for AI provider routers.

    Concrete subclasses implement the four provider-specific hooks
    (``_prepare_request``, ``_make_request``, ``_make_async_request``,
    ``_parse_response``); this base class supplies the shared sync/async
    call orchestration and latency measurement.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the router with model and configuration.

        Args:
            model: Model identifier for the provider.
            api_key: API key for authentication (optional if set in the
                environment).
            **kwargs: Additional provider-specific configuration, merged
                into every request made through this router.
        """
        self.model = model
        self.api_key = api_key
        self.config = kwargs
        # Derive the provider label from the subclass name,
        # e.g. ``OpenAIRouter`` -> ``openairouter``.
        self.provider = self.__class__.__name__.lower()

    @abstractmethod
    def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        """Prepare provider-specific request parameters.

        Args:
            prompt: User prompt.
            **kwargs: Additional parameters.

        Returns:
            Dictionary of request parameters.
        """

    @abstractmethod
    def _parse_response(self, raw_response: Any, latency: float) -> "RouterResponse":
        """Parse a provider response into a RouterResponse.

        Args:
            raw_response: Raw response from the provider.
            latency: Request latency in seconds.

        Returns:
            Standardized RouterResponse.
        """

    @abstractmethod
    def _make_request(self, request_params: Dict[str, Any]) -> Any:
        """Make a synchronous request to the provider.

        Args:
            request_params: Provider-specific request parameters.

        Returns:
            Raw provider response.
        """

    @abstractmethod
    async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
        """Make an asynchronous request to the provider.

        Args:
            request_params: Provider-specific request parameters.

        Returns:
            Raw provider response.
        """

    def _build_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        """Merge default config with per-call overrides and prepare the request.

        Shared by :meth:`call` and :meth:`acall` so the merge semantics
        (call-specific kwargs win over constructor config) stay in one place.
        """
        params = {**self.config, **kwargs}
        return self._prepare_request(prompt, **params)

    def call(self, prompt: str, **kwargs: Any) -> "RouterResponse":
        """Make a synchronous call to the AI provider.

        Args:
            prompt: User prompt.
            **kwargs: Additional parameters (model, temperature,
                max_tokens, etc.).

        Returns:
            RouterResponse with the result.
        """
        request_params = self._build_request(prompt, **kwargs)
        # perf_counter is monotonic; time.time() can jump backwards on
        # system clock adjustments and report negative latencies.
        start_time = time.perf_counter()
        raw_response = self._make_request(request_params)
        latency = time.perf_counter() - start_time
        return self._parse_response(raw_response, latency)

    async def acall(self, prompt: str, **kwargs: Any) -> "RouterResponse":
        """Make an asynchronous call to the AI provider.

        Args:
            prompt: User prompt.
            **kwargs: Additional parameters (model, temperature,
                max_tokens, etc.).

        Returns:
            RouterResponse with the result.
        """
        request_params = self._build_request(prompt, **kwargs)
        # Monotonic clock for latency — see call().
        start_time = time.perf_counter()
        raw_response = await self._make_async_request(request_params)
        latency = time.perf_counter() - start_time
        return self._parse_response(raw_response, latency)

    def stream(
        self, prompt: str, **kwargs: Any
    ) -> Generator["RouterResponse", None, None]:
        """Stream responses from the AI provider.

        Args:
            prompt: User prompt.
            **kwargs: Additional parameters.

        Yields:
            RouterResponse chunks.

        Raises:
            NotImplementedError: If streaming is not supported.

        NOTE(review): this base implementation contains no ``yield``, so the
        error surfaces at call time; :meth:`astream` raises on first
        iteration instead. Subclasses that support streaming override this.
        """
        raise NotImplementedError(f"{self.provider} does not support streaming yet")

    async def astream(
        self, prompt: str, **kwargs: Any
    ) -> AsyncGenerator["RouterResponse", None]:
        """Asynchronously stream responses from the AI provider.

        Args:
            prompt: User prompt.
            **kwargs: Additional parameters.

        Yields:
            RouterResponse chunks.

        Raises:
            NotImplementedError: If async streaming is not supported
                (raised on first iteration, not at call time).
        """
        raise NotImplementedError(f"{self.provider} does not support async streaming yet")
        # Unreachable, but required so this function is an async generator
        # (matching the declared AsyncGenerator return type).
        yield