"""Google Gemini provider implementation."""

from typing import Any, Dict, Optional, AsyncGenerator, Generator
import time

from .base import AIRouter, RouterResponse
from .config import config
from .exceptions import (
    AuthenticationError,
    ConfigurationError,
    map_provider_error
)


class Gemini(AIRouter):
    """Router for Google Gemini models."""

    def __init__(
        self,
        model: str = "gemini-2.0-flash-001",
        api_key: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """Initialize Gemini router.

        Args:
            model: Gemini model to use (gemini-2.0-flash-001, gemini-1.5-pro, etc.)
            api_key: Google API key (optional if set in environment)
            **kwargs: Additional configuration
        """
        # Get API key from config if not provided
        if not api_key:
            api_key = config.get_api_key("gemini")

        if not api_key:
            raise ConfigurationError(
                "Gemini API key not found. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.",
                provider="gemini"
            )

        super().__init__(model=model, api_key=api_key, **kwargs)

        # Initialize Gemini client using new google.genai library
        try:
            from google import genai
            from google.genai import types
            self.genai = genai
            self.types = types
            self.client = genai.Client(api_key=api_key)
        except ImportError:
            raise ConfigurationError(
                "google-genai package not installed. Install with: pip install google-genai",
                provider="gemini"
            )

    def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        """Prepare Gemini request parameters.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Returns:
            Request parameters
        """
        # Build config using new API structure
        config_params = {}

        if "temperature" in kwargs:
            config_params["temperature"] = kwargs["temperature"]
        elif hasattr(config, "default_temperature"):
            config_params["temperature"] = config.default_temperature

        if "max_tokens" in kwargs:
            config_params["max_output_tokens"] = kwargs["max_tokens"]
        elif "max_output_tokens" in kwargs:
            config_params["max_output_tokens"] = kwargs["max_output_tokens"]

        if "top_p" in kwargs:
            config_params["top_p"] = kwargs["top_p"]
        elif hasattr(config, "default_top_p"):
            config_params["top_p"] = config.default_top_p

        if "top_k" in kwargs:
            config_params["top_k"] = kwargs["top_k"]

        # Add safety settings if provided
        if "safety_settings" in kwargs:
            config_params["safety_settings"] = kwargs["safety_settings"]

        return {
            "model": kwargs.get("model", self.model),
            "contents": prompt,
            "config": self.types.GenerateContentConfig(**config_params) if config_params else None
        }

    def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
        """Parse Gemini response into RouterResponse.

        Args:
            raw_response: Raw Gemini response
            latency: Request latency

        Returns:
            RouterResponse
        """
        # Extract text content
        content = raw_response.text if hasattr(raw_response, "text") else ""

        # Try to get token counts from usage_metadata
        input_tokens = None
        output_tokens = None
        total_tokens = None

        if hasattr(raw_response, "usage_metadata"):
            usage = raw_response.usage_metadata
            # Prefer prompt_token_count; fall back to cached_content_token_count if it is missing
            input_tokens = getattr(usage, "prompt_token_count", None) or getattr(usage, "cached_content_token_count", None)
            output_tokens = getattr(usage, "candidates_token_count", None)
            total_tokens = getattr(usage, "total_token_count", None)
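            # Illustrative values only (made up): a typical non-cached request reports
            # something like prompt_token_count=12, candidates_token_count=48, total_token_count=60.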

        # Calculate cost if available
        cost = None
        if input_tokens and output_tokens and config.track_costs:
            cost = config.calculate_cost(self.model, input_tokens, output_tokens)

        # Extract finish reason
        finish_reason = "stop"
        if hasattr(raw_response, "candidates") and raw_response.candidates:
            candidate = raw_response.candidates[0]
            finish_reason = getattr(candidate, "finish_reason", "stop")

        return RouterResponse(
            content=content,
            model=raw_response.model_version if hasattr(raw_response, "model_version") else self.model,
            provider="gemini",
            latency=latency,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens or ((input_tokens + output_tokens) if input_tokens and output_tokens else None),
            cost=cost,
            finish_reason=finish_reason,
            metadata={
                "prompt_feedback": getattr(raw_response, "prompt_feedback", None),
                "safety_ratings": getattr(raw_response.candidates[0], "safety_ratings", None) if hasattr(raw_response, "candidates") and raw_response.candidates else None
            },
            raw_response=raw_response
        )

    def _make_request(self, request_params: Dict[str, Any]) -> Any:
        """Make synchronous request to Gemini.

        Args:
            request_params: Request parameters

        Returns:
            Raw Gemini response
        """
        try:
            response = self.client.models.generate_content(**request_params)
            return response
        except Exception as e:
            raise map_provider_error("gemini", e)

    async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
        """Make asynchronous request to Gemini.

        Args:
            request_params: Request parameters

        Returns:
            Raw Gemini response
        """
        try:
            response = await self.client.aio.models.generate_content(**request_params)
            return response
        except Exception as e:
            raise map_provider_error("gemini", e)

    def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
        """Stream responses from Gemini.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Yields:
            RouterResponse chunks
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)

        try:
            start_time = time.time()
            stream = self.client.models.generate_content_stream(**request_params)
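            # Each chunk exposes only its incremental text; the latency reported per
            # chunk below is the cumulative time since the stream was opened.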

            for chunk in stream:
                if chunk.text:
                    yield RouterResponse(
                        content=chunk.text,
                        model=self.model,
                        provider="gemini",
                        latency=time.time() - start_time,
                        finish_reason=None,
                        metadata={},
                        raw_response=chunk
                    )
        except Exception as e:
            raise map_provider_error("gemini", e)

    async def astream(
        self, prompt: str, **kwargs: Any
    ) -> AsyncGenerator[RouterResponse, None]:
        """Asynchronously stream responses from Gemini.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Yields:
            RouterResponse chunks
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)

        try:
            start_time = time.time()
            stream = await self.client.aio.models.generate_content_stream(**request_params)

            async for chunk in stream:
                if chunk.text:
                    yield RouterResponse(
                        content=chunk.text,
                        model=self.model,
                        provider="gemini",
                        latency=time.time() - start_time,
                        finish_reason=None,
                        metadata={},
                        raw_response=chunk
                    )
        except Exception as e:
            raise map_provider_error("gemini", e)
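

# Asynchronous usage sketch (illustrative only; assumes an asyncio entry point
# and that GEMINI_API_KEY is set):
#
#     import asyncio
#
#     async def _demo() -> None:
#         async for chunk in Gemini().astream("Hello"):
#             print(chunk.content, end="")
#
#     asyncio.run(_demo())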