"""OpenAI-compatible provider implementation.""" from typing import Any, Dict, Optional, AsyncGenerator, Generator import time from .base import AIRouter, RouterResponse from .config import config from .exceptions import ( ConfigurationError, map_provider_error ) class OpenAICompatible(AIRouter): """Router for OpenAI and compatible APIs (OpenAI, Azure OpenAI, etc.).""" def __init__( self, model: str = "gpt-3.5-turbo", api_key: Optional[str] = None, base_url: Optional[str] = None, organization: Optional[str] = None, **kwargs: Any ) -> None: """Initialize OpenAI-compatible router. Args: model: Model to use (gpt-4, gpt-3.5-turbo, etc.) api_key: API key (optional if set in environment) base_url: Base URL for API (for Azure or custom endpoints) organization: Organization ID for OpenAI **kwargs: Additional configuration """ # Get API key from config if not provided if not api_key: api_key = config.get_api_key("openai") if not api_key: raise ConfigurationError( "OpenAI API key not found. Set OPENAI_API_KEY environment variable.", provider="openai" ) super().__init__(model=model, api_key=api_key, **kwargs) # Store additional config self.base_url = base_url self.organization = organization # Initialize OpenAI client try: from openai import OpenAI, AsyncOpenAI # Build client kwargs with proper types client_kwargs: Dict[str, Any] = { "api_key": api_key, } if base_url: client_kwargs["base_url"] = base_url if organization: client_kwargs["organization"] = organization # Add any additional client configuration from kwargs # Note: Only pass through valid OpenAI client parameters valid_client_params = ["timeout", "max_retries", "default_headers", "default_query", "http_client"] for param in valid_client_params: if param in kwargs: client_kwargs[param] = kwargs.pop(param) self.client = OpenAI(**client_kwargs) self.async_client = AsyncOpenAI(**client_kwargs) except ImportError: raise ConfigurationError( "openai package not installed. Install with: pip install openai", provider="openai" ) def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]: """Prepare OpenAI request parameters. Args: prompt: User prompt **kwargs: Additional parameters Returns: Request parameters """ # Build messages messages = kwargs.get("messages", [ {"role": "user", "content": prompt} ]) # Build request parameters params = { "model": kwargs.get("model", self.model), "messages": messages, } # Add optional parameters if "temperature" in kwargs: params["temperature"] = kwargs["temperature"] elif hasattr(config, "default_temperature"): params["temperature"] = config.default_temperature if "max_tokens" in kwargs: params["max_tokens"] = kwargs["max_tokens"] elif hasattr(config, "default_max_tokens") and config.default_max_tokens: params["max_tokens"] = config.default_max_tokens if "top_p" in kwargs: params["top_p"] = kwargs["top_p"] elif hasattr(config, "default_top_p"): params["top_p"] = config.default_top_p # Other OpenAI-specific parameters for key in ["n", "stop", "presence_penalty", "frequency_penalty", "logit_bias", "user", "seed", "tools", "tool_choice", "response_format", "logprobs", "top_logprobs", "parallel_tool_calls"]: if key in kwargs: params[key] = kwargs[key] return params def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse: """Parse OpenAI response into RouterResponse. 
    def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
        """Parse an OpenAI response into a RouterResponse.

        Args:
            raw_response: Raw OpenAI response
            latency: Request latency

        Returns:
            RouterResponse
        """
        # Extract the first choice.
        choice = raw_response.choices[0]
        content = choice.message.content or ""

        # Token counts (may be absent on some compatible backends).
        usage = raw_response.usage
        input_tokens = usage.prompt_tokens if usage else None
        output_tokens = usage.completion_tokens if usage else None
        total_tokens = usage.total_tokens if usage else None

        # Calculate cost when token counts are available and tracking is on.
        cost = None
        if input_tokens is not None and output_tokens is not None and config.track_costs:
            cost = config.calculate_cost(raw_response.model, input_tokens, output_tokens)

        return RouterResponse(
            content=content,
            model=raw_response.model,
            provider="openai",
            latency=latency,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
            cost=cost,
            finish_reason=choice.finish_reason,
            metadata={
                "id": raw_response.id,
                "created": raw_response.created,
                "system_fingerprint": getattr(raw_response, "system_fingerprint", None),
                "tool_calls": getattr(choice.message, "tool_calls", None),
                "function_call": getattr(choice.message, "function_call", None),
                "logprobs": getattr(choice, "logprobs", None),
            },
            raw_response=raw_response,
        )

    def _make_request(self, request_params: Dict[str, Any]) -> Any:
        """Make a synchronous request to OpenAI.

        Args:
            request_params: Request parameters

        Returns:
            Raw OpenAI response
        """
        try:
            return self.client.chat.completions.create(**request_params)
        except Exception as e:
            raise map_provider_error("openai", e) from e

    async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
        """Make an asynchronous request to OpenAI.

        Args:
            request_params: Request parameters

        Returns:
            Raw OpenAI response
        """
        try:
            return await self.async_client.chat.completions.create(**request_params)
        except Exception as e:
            raise map_provider_error("openai", e) from e

    def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
        """Stream responses from OpenAI.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Yields:
            RouterResponse chunks
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)
        request_params["stream"] = True

        try:
            start_time = time.time()
            stream = self.client.chat.completions.create(**request_params)
            for chunk in stream:
                if chunk.choices:
                    choice = chunk.choices[0]
                    content = getattr(choice.delta, "content", None)
                    if content:
                        yield RouterResponse(
                            content=content,
                            model=chunk.model,
                            provider="openai",
                            latency=time.time() - start_time,
                            finish_reason=getattr(choice, "finish_reason", None),
                            metadata={
                                "chunk_id": chunk.id,
                                "tool_calls": getattr(choice.delta, "tool_calls", None),
                                "function_call": getattr(choice.delta, "function_call", None),
                            },
                            raw_response=chunk,
                        )
        except Exception as e:
            raise map_provider_error("openai", e) from e
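    # Usage sketch (sync streaming). "gpt-4o-mini" is a placeholder model
    # name; an API key is assumed to be available via OPENAI_API_KEY:
    #
    #     router = OpenAICompatible(model="gpt-4o-mini")
    #     for chunk in router.stream("Tell me a joke"):
    #         print(chunk.content, end="", flush=True)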
    async def astream(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> AsyncGenerator[RouterResponse, None]:
        """Asynchronously stream responses from OpenAI.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Yields:
            RouterResponse chunks
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)
        request_params["stream"] = True

        try:
            start_time = time.time()
            stream = await self.async_client.chat.completions.create(**request_params)
            async for chunk in stream:
                if chunk.choices:
                    choice = chunk.choices[0]
                    content = getattr(choice.delta, "content", None)
                    if content:
                        yield RouterResponse(
                            content=content,
                            model=chunk.model,
                            provider="openai",
                            latency=time.time() - start_time,
                            finish_reason=getattr(choice, "finish_reason", None),
                            metadata={
                                "chunk_id": chunk.id,
                                "tool_calls": getattr(choice.delta, "tool_calls", None),
                                "function_call": getattr(choice.delta, "function_call", None),
                            },
                            raw_response=chunk,
                        )
        except Exception as e:
            raise map_provider_error("openai", e) from e


# Convenience alias
OpenAI = OpenAICompatible
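# Usage sketch (async streaming). The model name is a placeholder and
# OPENAI_API_KEY is assumed to be set in the environment:
#
#     import asyncio
#
#     async def demo() -> None:
#         router = OpenAICompatible(model="gpt-4o-mini")
#         async for chunk in router.astream("Summarize this module"):
#             print(chunk.content, end="", flush=True)
#
#     asyncio.run(demo())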