template-with-ai/router/openai_compatible.py
2025-07-01 17:07:02 +05:30

280 lines
No EOL
10 KiB
Python

"""OpenAI-compatible provider implementation."""
from typing import Any, Dict, Optional, AsyncGenerator, Generator
import time
from .base import AIRouter, RouterResponse
from .config import config
from .exceptions import (
ConfigurationError,
map_provider_error
)
class OpenAICompatible(AIRouter):
    """Router for OpenAI and compatible APIs (OpenAI, Azure OpenAI, etc.)."""

    # Constructor kwargs that belong to the OpenAI client object itself,
    # not to the router configuration.
    _CLIENT_PARAMS = (
        "timeout",
        "max_retries",
        "default_headers",
        "default_query",
        "http_client",
    )

    def __init__(
        self,
        model: str = "gpt-4o",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        organization: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """Initialize OpenAI-compatible router.

        Args:
            model: Model to use (gpt-4, gpt-3.5-turbo, etc.)
            api_key: API key (optional if set in environment)
            base_url: Base URL for API (for Azure or custom endpoints)
            organization: Organization ID for OpenAI
            **kwargs: Additional configuration. Client-level options
                (timeout, max_retries, default_headers, default_query,
                http_client) are routed to the OpenAI client; everything
                else is stored as router config.

        Raises:
            ConfigurationError: If no API key can be resolved or the
                ``openai`` package is not installed.
        """
        # Resolve the API key from the environment-backed config if not given.
        if not api_key:
            api_key = config.get_api_key("openai")
            if not api_key:
                raise ConfigurationError(
                    "OpenAI API key not found. Set OPENAI_API_KEY environment variable.",
                    provider="openai"
                )
        # Pull client-only options out of kwargs BEFORE handing kwargs to the
        # base router, so client transport settings are not duplicated into
        # the router config. (Previously these were popped after super().__init__,
        # leaking them into self.config.)
        client_only: Dict[str, Any] = {
            param: kwargs.pop(param)
            for param in self._CLIENT_PARAMS
            if param in kwargs
        }
        super().__init__(model=model, api_key=api_key, **kwargs)
        # Store additional endpoint config for introspection.
        self.base_url = base_url
        self.organization = organization
        # Import lazily so the package is an optional dependency.
        try:
            from openai import OpenAI, AsyncOpenAI
        except ImportError as e:
            raise ConfigurationError(
                "openai package not installed. Install with: pip install openai",
                provider="openai"
            ) from e
        # Build client kwargs with proper types.
        client_kwargs: Dict[str, Any] = {"api_key": api_key, **client_only}
        if base_url:
            client_kwargs["base_url"] = base_url
        if organization:
            client_kwargs["organization"] = organization
        self.client = OpenAI(**client_kwargs)
        self.async_client = AsyncOpenAI(**client_kwargs)

    def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        """Build the chat-completions request payload.

        Args:
            prompt: User prompt; ignored when ``messages`` is supplied in kwargs.
            **kwargs: Per-call overrides (model, temperature, max_tokens, ...).

        Returns:
            Keyword arguments suitable for ``chat.completions.create``.
        """
        # A caller-supplied message history wins over a bare prompt.
        messages = kwargs.get("messages", [{"role": "user", "content": prompt}])
        params: Dict[str, Any] = {
            "model": kwargs.get("model", self.model),
            "messages": messages,
        }
        # Sampling parameters: explicit kwargs beat configured defaults.
        if "temperature" in kwargs:
            params["temperature"] = kwargs["temperature"]
        elif hasattr(config, "default_temperature"):
            params["temperature"] = config.default_temperature
        if "max_tokens" in kwargs:
            params["max_tokens"] = kwargs["max_tokens"]
        elif hasattr(config, "default_max_tokens") and config.default_max_tokens:
            params["max_tokens"] = config.default_max_tokens
        if "top_p" in kwargs:
            params["top_p"] = kwargs["top_p"]
        elif hasattr(config, "default_top_p"):
            params["top_p"] = config.default_top_p
        # Pass the remaining OpenAI-specific options through verbatim.
        for key in (
            "n", "stop", "presence_penalty", "frequency_penalty", "logit_bias",
            "user", "seed", "tools", "tool_choice", "response_format",
            "logprobs", "top_logprobs", "parallel_tool_calls",
        ):
            if key in kwargs:
                params[key] = kwargs[key]
        return params

    def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
        """Convert a raw OpenAI completion into a RouterResponse.

        Args:
            raw_response: Raw OpenAI chat-completion object.
            latency: Request latency in seconds.

        Returns:
            Normalized RouterResponse.
        """
        # Only the first choice is surfaced; extras remain in raw_response.
        choice = raw_response.choices[0]
        content = choice.message.content or ""
        # Token accounting (usage may be absent on some responses).
        usage = raw_response.usage
        input_tokens = usage.prompt_tokens if usage else None
        output_tokens = usage.completion_tokens if usage else None
        total_tokens = usage.total_tokens if usage else None
        # Compare against None explicitly: a legitimate zero-token count is
        # falsy and previously skipped cost calculation entirely.
        cost = None
        if input_tokens is not None and output_tokens is not None and config.track_costs:
            cost = config.calculate_cost(raw_response.model, input_tokens, output_tokens)
        return RouterResponse(
            content=content,
            model=raw_response.model,
            provider="openai",
            latency=latency,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
            cost=cost,
            finish_reason=choice.finish_reason,
            metadata={
                "id": raw_response.id,
                "created": raw_response.created,
                "system_fingerprint": getattr(raw_response, "system_fingerprint", None),
                "tool_calls": getattr(choice.message, "tool_calls", None),
                "function_call": getattr(choice.message, "function_call", None),
                "logprobs": getattr(choice, "logprobs", None),
            },
            raw_response=raw_response
        )

    def _make_request(self, request_params: Dict[str, Any]) -> Any:
        """Make a synchronous request to OpenAI.

        Args:
            request_params: Request parameters for ``chat.completions.create``.

        Returns:
            Raw OpenAI response.

        Raises:
            RouterError subclass mapped from the provider error.
        """
        try:
            return self.client.chat.completions.create(**request_params)
        except Exception as e:
            # Chain the original so the provider traceback is preserved.
            raise map_provider_error("openai", e) from e

    async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
        """Make an asynchronous request to OpenAI.

        Args:
            request_params: Request parameters for ``chat.completions.create``.

        Returns:
            Raw OpenAI response.

        Raises:
            RouterError subclass mapped from the provider error.
        """
        try:
            return await self.async_client.chat.completions.create(**request_params)
        except Exception as e:
            raise map_provider_error("openai", e) from e

    def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
        """Stream responses from OpenAI.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters (override instance config)

        Yields:
            RouterResponse chunks carrying incremental content.
        """
        # Per-call kwargs override the stored router config.
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)
        request_params["stream"] = True
        try:
            start_time = time.time()
            stream = self.client.chat.completions.create(**request_params)
            for chunk in stream:
                if chunk.choices and len(chunk.choices) > 0:
                    choice = chunk.choices[0]
                    content = getattr(choice.delta, "content", None)
                    # Empty deltas (role-only / finish markers) are skipped.
                    if content:
                        yield RouterResponse(
                            content=content,
                            model=chunk.model,
                            provider="openai",
                            latency=time.time() - start_time,
                            finish_reason=getattr(choice, "finish_reason", None),
                            metadata={
                                "chunk_id": chunk.id,
                                "tool_calls": getattr(choice.delta, "tool_calls", None),
                                "function_call": getattr(choice.delta, "function_call", None),
                            },
                            raw_response=chunk
                        )
        except Exception as e:
            raise map_provider_error("openai", e) from e

    async def astream(
        self, prompt: str, **kwargs: Any
    ) -> AsyncGenerator[RouterResponse, None]:
        """Asynchronously stream responses from OpenAI.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters (override instance config)

        Yields:
            RouterResponse chunks carrying incremental content.
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)
        request_params["stream"] = True
        try:
            start_time = time.time()
            stream = await self.async_client.chat.completions.create(**request_params)
            async for chunk in stream:
                if chunk.choices and len(chunk.choices) > 0:
                    choice = chunk.choices[0]
                    content = getattr(choice.delta, "content", None)
                    if content:
                        yield RouterResponse(
                            content=content,
                            model=chunk.model,
                            provider="openai",
                            latency=time.time() - start_time,
                            finish_reason=getattr(choice, "finish_reason", None),
                            metadata={
                                "chunk_id": chunk.id,
                                "tool_calls": getattr(choice.delta, "tool_calls", None),
                                "function_call": getattr(choice.delta, "function_call", None),
                            },
                            raw_response=chunk
                        )
        except Exception as e:
            raise map_provider_error("openai", e) from e
# Convenience alias so callers can write `from .openai_compatible import OpenAI`.
# NOTE: this does not clash with the SDK's `openai.OpenAI` — __init__ imports
# that name locally inside the method, not at module scope.
OpenAI = OpenAICompatible