"""Base abstract class for AI router implementations."""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union
|
|
import time
|
|
|
|
|
|
@dataclass
|
|
class RouterResponse:
|
|
"""Standardized response object for all AI providers."""
|
|
content: str
|
|
model: str
|
|
provider: str
|
|
latency: float = 0.0
|
|
input_tokens: Optional[int] = None
|
|
output_tokens: Optional[int] = None
|
|
total_tokens: Optional[int] = None
|
|
cost: Optional[float] = None
|
|
finish_reason: Optional[str] = None
|
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
raw_response: Optional[Any] = None
|
|
|
|
def __str__(self) -> str:
|
|
return self.content
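
    # Usage note (illustrative; the values below are hypothetical): callers can
    # inspect individual fields, while str()/print() return only the generated
    # text via __str__:
    #
    #     response = RouterResponse(
    #         content="Hello!",
    #         model="example-model",
    #         provider="exampleprovider",
    #         latency=0.42,
    #         input_tokens=12,
    #         output_tokens=3,
    #         total_tokens=15,
    #     )
    #     print(response)   # -> "Hello!"
    #     response.cost     # optional fields default to None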


class AIRouter(ABC):
    """Abstract base class for AI provider routers."""

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """Initialize the router with model and configuration.

        Args:
            model: Model identifier for the provider
            api_key: API key for authentication (optional if set in environment)
            **kwargs: Additional provider-specific configuration
        """
        self.model = model
        self.api_key = api_key
        self.config = kwargs
        self.provider = self.__class__.__name__.lower()

    @abstractmethod
    def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        """Prepare provider-specific request parameters.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Returns:
            Dictionary of request parameters
        """
        pass

    @abstractmethod
    def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
        """Parse provider response into RouterResponse.

        Args:
            raw_response: Raw response from provider
            latency: Request latency in seconds

        Returns:
            Standardized RouterResponse
        """
        pass

    @abstractmethod
    def _make_request(self, request_params: Dict[str, Any]) -> Any:
        """Make synchronous request to provider.

        Args:
            request_params: Provider-specific request parameters

        Returns:
            Raw provider response
        """
        pass

    @abstractmethod
    async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
        """Make asynchronous request to provider.

        Args:
            request_params: Provider-specific request parameters

        Returns:
            Raw provider response
        """
        pass

    def call(self, prompt: str, **kwargs: Any) -> RouterResponse:
        """Make a synchronous call to the AI provider.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters (model, temperature, max_tokens, etc.)

        Returns:
            RouterResponse with the result
        """
        # Merge default config with call-specific parameters
        params = {**self.config, **kwargs}

        # Prepare request
        request_params = self._prepare_request(prompt, **params)

        # Make request and measure latency
        start_time = time.time()
        raw_response = self._make_request(request_params)
        latency = time.time() - start_time

        # Parse and return response
        return self._parse_response(raw_response, latency)

    async def acall(self, prompt: str, **kwargs: Any) -> RouterResponse:
        """Make an asynchronous call to the AI provider.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters (model, temperature, max_tokens, etc.)

        Returns:
            RouterResponse with the result
        """
        # Merge default config with call-specific parameters
        params = {**self.config, **kwargs}

        # Prepare request
        request_params = self._prepare_request(prompt, **params)

        # Make request and measure latency
        start_time = time.time()
        raw_response = await self._make_async_request(request_params)
        latency = time.time() - start_time

        # Parse and return response
        return self._parse_response(raw_response, latency)

    def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
        """Stream responses from the AI provider.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Yields:
            RouterResponse chunks

        Raises:
            NotImplementedError: If streaming is not supported
        """
        raise NotImplementedError(f"{self.provider} does not support streaming yet")

    async def astream(
        self, prompt: str, **kwargs: Any
    ) -> AsyncGenerator[RouterResponse, None]:
        """Asynchronously stream responses from the AI provider.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Yields:
            RouterResponse chunks

        Raises:
            NotImplementedError: If async streaming is not supported
        """
        raise NotImplementedError(f"{self.provider} does not support async streaming yet")
        yield  # Unreachable, but required so this compiles as an async generator
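

if __name__ == "__main__":
    # Illustrative sketch only -- EchoRouter is a hypothetical subclass, not
    # part of this module's API. It shows how a concrete provider wires the
    # four abstract hooks together; a real implementation would perform an
    # HTTP/SDK request in _make_request / _make_async_request.
    class EchoRouter(AIRouter):
        def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
            # A real provider would build its API payload here.
            return {"model": kwargs.get("model", self.model), "prompt": prompt}

        def _make_request(self, request_params: Dict[str, Any]) -> Any:
            # Echo the prompt back instead of calling an external service.
            return {"text": request_params["prompt"], "model": request_params["model"]}

        async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
            return self._make_request(request_params)

        def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
            return RouterResponse(
                content=raw_response["text"],
                model=raw_response["model"],
                provider=self.provider,
                latency=latency,
                finish_reason="stop",
                raw_response=raw_response,
            )

    demo = EchoRouter(model="echo-1")
    print(demo.call("Hello, router!"))  # prints the echoed prompt via __str__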