280 lines
No EOL
10 KiB
Python
280 lines
No EOL
10 KiB
Python
"""OpenAI-compatible provider implementation."""
|
|
|
|
from typing import Any, Dict, Optional, AsyncGenerator, Generator
|
|
import time
|
|
|
|
from .base import AIRouter, RouterResponse
|
|
from .config import config
|
|
from .exceptions import (
|
|
ConfigurationError,
|
|
map_provider_error
|
|
)
|
|
|
|
|
|
class OpenAICompatible(AIRouter):
|
|
"""Router for OpenAI and compatible APIs (OpenAI, Azure OpenAI, etc.)."""
|
|
|
|
def __init__(
|
|
self,
|
|
model: str = "gpt-3.5-turbo",
|
|
api_key: Optional[str] = None,
|
|
base_url: Optional[str] = None,
|
|
organization: Optional[str] = None,
|
|
**kwargs: Any
|
|
) -> None:
|
|
"""Initialize OpenAI-compatible router.
|
|
|
|
Args:
|
|
model: Model to use (gpt-4, gpt-3.5-turbo, etc.)
|
|
api_key: API key (optional if set in environment)
|
|
base_url: Base URL for API (for Azure or custom endpoints)
|
|
organization: Organization ID for OpenAI
|
|
**kwargs: Additional configuration
|
|
"""
|
|
# Get API key from config if not provided
|
|
if not api_key:
|
|
api_key = config.get_api_key("openai")
|
|
|
|
if not api_key:
|
|
raise ConfigurationError(
|
|
"OpenAI API key not found. Set OPENAI_API_KEY environment variable.",
|
|
provider="openai"
|
|
)
|
|
|
|
super().__init__(model=model, api_key=api_key, **kwargs)
|
|
|
|
# Store additional config
|
|
self.base_url = base_url
|
|
self.organization = organization
|
|
|
|
# Initialize OpenAI client
|
|
try:
|
|
from openai import OpenAI, AsyncOpenAI
|
|
|
|
# Build client kwargs with proper types
|
|
client_kwargs: Dict[str, Any] = {
|
|
"api_key": api_key,
|
|
}
|
|
if base_url:
|
|
client_kwargs["base_url"] = base_url
|
|
if organization:
|
|
client_kwargs["organization"] = organization
|
|
|
|
# Add any additional client configuration from kwargs
|
|
# Note: Only pass through valid OpenAI client parameters
|
|
valid_client_params = ["timeout", "max_retries", "default_headers", "default_query", "http_client"]
|
|
for param in valid_client_params:
|
|
if param in kwargs:
|
|
client_kwargs[param] = kwargs.pop(param)
|
|
|
|
self.client = OpenAI(**client_kwargs)
|
|
self.async_client = AsyncOpenAI(**client_kwargs)
|
|
except ImportError:
|
|
raise ConfigurationError(
|
|
"openai package not installed. Install with: pip install openai",
|
|
provider="openai"
|
|
)
|
|
|
|
def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
|
|
"""Prepare OpenAI request parameters.
|
|
|
|
Args:
|
|
prompt: User prompt
|
|
**kwargs: Additional parameters
|
|
|
|
Returns:
|
|
Request parameters
|
|
"""
|
|
# Build messages
|
|
messages = kwargs.get("messages", [
|
|
{"role": "user", "content": prompt}
|
|
])
|
|
|
|
# Build request parameters
|
|
params = {
|
|
"model": kwargs.get("model", self.model),
|
|
"messages": messages,
|
|
}
|
|
|
|
# Add optional parameters
|
|
if "temperature" in kwargs:
|
|
params["temperature"] = kwargs["temperature"]
|
|
elif hasattr(config, "default_temperature"):
|
|
params["temperature"] = config.default_temperature
|
|
|
|
if "max_tokens" in kwargs:
|
|
params["max_tokens"] = kwargs["max_tokens"]
|
|
elif hasattr(config, "default_max_tokens") and config.default_max_tokens:
|
|
params["max_tokens"] = config.default_max_tokens
|
|
|
|
if "top_p" in kwargs:
|
|
params["top_p"] = kwargs["top_p"]
|
|
elif hasattr(config, "default_top_p"):
|
|
params["top_p"] = config.default_top_p
|
|
|
|
# Other OpenAI-specific parameters
|
|
for key in ["n", "stop", "presence_penalty", "frequency_penalty", "logit_bias", "user", "seed", "tools", "tool_choice", "response_format", "logprobs", "top_logprobs", "parallel_tool_calls"]:
|
|
if key in kwargs:
|
|
params[key] = kwargs[key]
|
|
|
|
return params
|
|
|
|
def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
|
|
"""Parse OpenAI response into RouterResponse.
|
|
|
|
Args:
|
|
raw_response: Raw OpenAI response
|
|
latency: Request latency
|
|
|
|
Returns:
|
|
RouterResponse
|
|
"""
|
|
# Extract first choice
|
|
choice = raw_response.choices[0]
|
|
content = choice.message.content or ""
|
|
|
|
# Get token counts
|
|
usage = raw_response.usage
|
|
input_tokens = usage.prompt_tokens if usage else None
|
|
output_tokens = usage.completion_tokens if usage else None
|
|
total_tokens = usage.total_tokens if usage else None
|
|
|
|
# Calculate cost if available
|
|
cost = None
|
|
if input_tokens and output_tokens and config.track_costs:
|
|
cost = config.calculate_cost(raw_response.model, input_tokens, output_tokens)
|
|
|
|
return RouterResponse(
|
|
content=content,
|
|
model=raw_response.model,
|
|
provider="openai",
|
|
latency=latency,
|
|
input_tokens=input_tokens,
|
|
output_tokens=output_tokens,
|
|
total_tokens=total_tokens,
|
|
cost=cost,
|
|
finish_reason=choice.finish_reason,
|
|
metadata={
|
|
"id": raw_response.id,
|
|
"created": raw_response.created,
|
|
"system_fingerprint": getattr(raw_response, "system_fingerprint", None),
|
|
"tool_calls": getattr(choice.message, "tool_calls", None),
|
|
"function_call": getattr(choice.message, "function_call", None),
|
|
"logprobs": getattr(choice, "logprobs", None),
|
|
},
|
|
raw_response=raw_response
|
|
)
|
|
|
|
def _make_request(self, request_params: Dict[str, Any]) -> Any:
|
|
"""Make synchronous request to OpenAI.
|
|
|
|
Args:
|
|
request_params: Request parameters
|
|
|
|
Returns:
|
|
Raw OpenAI response
|
|
"""
|
|
try:
|
|
response = self.client.chat.completions.create(**request_params)
|
|
return response
|
|
except Exception as e:
|
|
raise map_provider_error("openai", e)
|
|
|
|
async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
|
|
"""Make asynchronous request to OpenAI.
|
|
|
|
Args:
|
|
request_params: Request parameters
|
|
|
|
Returns:
|
|
Raw OpenAI response
|
|
"""
|
|
try:
|
|
response = await self.async_client.chat.completions.create(**request_params)
|
|
return response
|
|
except Exception as e:
|
|
raise map_provider_error("openai", e)
|
|
|
|
def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
|
|
"""Stream responses from OpenAI.
|
|
|
|
Args:
|
|
prompt: User prompt
|
|
**kwargs: Additional parameters
|
|
|
|
Yields:
|
|
RouterResponse chunks
|
|
"""
|
|
params = {**self.config, **kwargs}
|
|
request_params = self._prepare_request(prompt, **params)
|
|
request_params["stream"] = True
|
|
|
|
try:
|
|
start_time = time.time()
|
|
stream = self.client.chat.completions.create(**request_params)
|
|
|
|
for chunk in stream:
|
|
if chunk.choices and len(chunk.choices) > 0:
|
|
choice = chunk.choices[0]
|
|
content = getattr(choice.delta, "content", None)
|
|
if content:
|
|
yield RouterResponse(
|
|
content=content,
|
|
model=chunk.model,
|
|
provider="openai",
|
|
latency=time.time() - start_time,
|
|
finish_reason=getattr(choice, "finish_reason", None),
|
|
metadata={
|
|
"chunk_id": chunk.id,
|
|
"tool_calls": getattr(choice.delta, "tool_calls", None),
|
|
"function_call": getattr(choice.delta, "function_call", None),
|
|
},
|
|
raw_response=chunk
|
|
)
|
|
except Exception as e:
|
|
raise map_provider_error("openai", e)
|
|
|
|
async def astream(
|
|
self, prompt: str, **kwargs: Any
|
|
) -> AsyncGenerator[RouterResponse, None]:
|
|
"""Asynchronously stream responses from OpenAI.
|
|
|
|
Args:
|
|
prompt: User prompt
|
|
**kwargs: Additional parameters
|
|
|
|
Yields:
|
|
RouterResponse chunks
|
|
"""
|
|
params = {**self.config, **kwargs}
|
|
request_params = self._prepare_request(prompt, **params)
|
|
request_params["stream"] = True
|
|
|
|
try:
|
|
start_time = time.time()
|
|
stream = await self.async_client.chat.completions.create(**request_params)
|
|
|
|
async for chunk in stream:
|
|
if chunk.choices and len(chunk.choices) > 0:
|
|
choice = chunk.choices[0]
|
|
content = getattr(choice.delta, "content", None)
|
|
if content:
|
|
yield RouterResponse(
|
|
content=content,
|
|
model=chunk.model,
|
|
provider="openai",
|
|
latency=time.time() - start_time,
|
|
finish_reason=getattr(choice, "finish_reason", None),
|
|
metadata={
|
|
"chunk_id": chunk.id,
|
|
"tool_calls": getattr(choice.delta, "tool_calls", None),
|
|
"function_call": getattr(choice.delta, "function_call", None),
|
|
},
|
|
raw_response=chunk
|
|
)
|
|
except Exception as e:
|
|
raise map_provider_error("openai", e)
|
|
|
|
|
|
# Convenience alias: lets callers import the router under the familiar
# provider name (``OpenAI``) instead of ``OpenAICompatible``.
OpenAI = OpenAICompatible