template-with-ai/router/gemini.py
"""Google Gemini provider implementation."""
from typing import Any, Dict, Optional, AsyncGenerator, Generator
import time
from .base import AIRouter, RouterResponse
from .config import config
from .exceptions import (
    AuthenticationError,
    ConfigurationError,
    map_provider_error,
)


class Gemini(AIRouter):
    """Router for Google Gemini models."""

    def __init__(
        self,
        model: str = "gemini-2.0-flash-001",
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Gemini router.

        Args:
            model: Gemini model to use ("gemini-2.0-flash-001",
                "gemini-1.5-pro", etc.)
            api_key: Google API key (optional if set in the environment)
            **kwargs: Additional configuration
        """
        # Fall back to the configured key when none is passed explicitly.
        if not api_key:
            api_key = config.get_api_key("gemini")
        if not api_key:
            raise ConfigurationError(
                "Gemini API key not found. Set the GEMINI_API_KEY or "
                "GOOGLE_API_KEY environment variable.",
                provider="gemini",
            )
        super().__init__(model=model, api_key=api_key, **kwargs)
        # Initialize the Gemini client via the new google-genai SDK.
        try:
            from google import genai
            from google.genai import types

            self.genai = genai
            self.types = types
            self.client = genai.Client(api_key=api_key)
        except ImportError as e:
            raise ConfigurationError(
                "google-genai package not installed. Install with: pip install google-genai",
                provider="gemini",
            ) from e
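
    # Illustrative construction (a sketch; assumes GEMINI_API_KEY or
    # GOOGLE_API_KEY is exported, since __init__ falls back to
    # config.get_api_key("gemini") when no key is passed):
    #
    #     router = Gemini()                                       # key from env/config
    #     router = Gemini(model="gemini-1.5-pro", api_key="...")  # explicit key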

    def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        """Prepare Gemini request parameters.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Returns:
            Request parameters
        """
        # Build the generation config using the new API structure.
        config_params: Dict[str, Any] = {}
        if "temperature" in kwargs:
            config_params["temperature"] = kwargs["temperature"]
        elif hasattr(config, "default_temperature"):
            config_params["temperature"] = config.default_temperature
        # Accept both the generic "max_tokens" and Gemini's native name.
        if "max_tokens" in kwargs:
            config_params["max_output_tokens"] = kwargs["max_tokens"]
        elif "max_output_tokens" in kwargs:
            config_params["max_output_tokens"] = kwargs["max_output_tokens"]
        if "top_p" in kwargs:
            config_params["top_p"] = kwargs["top_p"]
        elif hasattr(config, "default_top_p"):
            config_params["top_p"] = config.default_top_p
        if "top_k" in kwargs:
            config_params["top_k"] = kwargs["top_k"]
        # Add safety settings if provided.
        if "safety_settings" in kwargs:
            config_params["safety_settings"] = kwargs["safety_settings"]
        return {
            "model": kwargs.get("model", self.model),
            "contents": prompt,
            "config": self.types.GenerateContentConfig(**config_params)
            if config_params
            else None,
        }
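
    # Sketch of the mapping performed above (illustrative values only):
    #
    #     router._prepare_request("Hi", temperature=0.2, max_tokens=64, top_p=0.9)
    #     # -> {
    #     #        "model": "gemini-2.0-flash-001",
    #     #        "contents": "Hi",
    #     #        "config": types.GenerateContentConfig(
    #     #            temperature=0.2, max_output_tokens=64, top_p=0.9
    #     #        ),
    #     #    }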

    def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
        """Parse a Gemini response into a RouterResponse.

        Args:
            raw_response: Raw Gemini response
            latency: Request latency in seconds

        Returns:
            RouterResponse
        """
        # Extract the text content.
        content = raw_response.text if hasattr(raw_response, "text") else ""
        # Try to get token counts from usage_metadata.
        input_tokens = None
        output_tokens = None
        total_tokens = None
        if hasattr(raw_response, "usage_metadata"):
            usage = raw_response.usage_metadata
            # prompt_token_count is the input size; cached_content_token_count
            # counts cached-context tokens and is not a fallback for it.
            input_tokens = getattr(usage, "prompt_token_count", None)
            output_tokens = getattr(usage, "candidates_token_count", None)
            total_tokens = getattr(usage, "total_token_count", None)
        # Calculate the cost when tracking is enabled and counts are available.
        cost = None
        if input_tokens and output_tokens and config.track_costs:
            cost = config.calculate_cost(self.model, input_tokens, output_tokens)
        # Extract the finish reason, defaulting to "stop" when absent or unset.
        finish_reason = "stop"
        if hasattr(raw_response, "candidates") and raw_response.candidates:
            candidate = raw_response.candidates[0]
            finish_reason = getattr(candidate, "finish_reason", None) or "stop"
        # Pull safety ratings from the first candidate when present.
        candidates = getattr(raw_response, "candidates", None)
        safety_ratings = (
            getattr(candidates[0], "safety_ratings", None) if candidates else None
        )
        return RouterResponse(
            content=content,
            model=getattr(raw_response, "model_version", None) or self.model,
            provider="gemini",
            latency=latency,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens
            or ((input_tokens + output_tokens) if input_tokens and output_tokens else None),
            cost=cost,
            finish_reason=finish_reason,
            metadata={
                "prompt_feedback": getattr(raw_response, "prompt_feedback", None),
                "safety_ratings": safety_ratings,
            },
            raw_response=raw_response,
        )
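
    # Token fields come straight from the SDK's usage metadata:
    #
    #     usage_metadata.prompt_token_count     -> input_tokens
    #     usage_metadata.candidates_token_count -> output_tokens
    #     usage_metadata.total_token_count      -> total_tokens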

    def _make_request(self, request_params: Dict[str, Any]) -> Any:
        """Make a synchronous request to Gemini.

        Args:
            request_params: Request parameters

        Returns:
            Raw Gemini response
        """
        try:
            return self.client.models.generate_content(**request_params)
        except Exception as e:
            # Translate SDK errors into the router's exception hierarchy.
            raise map_provider_error("gemini", e) from e

    async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
        """Make an asynchronous request to Gemini.

        Args:
            request_params: Request parameters

        Returns:
            Raw Gemini response
        """
        try:
            return await self.client.aio.models.generate_content(**request_params)
        except Exception as e:
            raise map_provider_error("gemini", e) from e
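
    # Both wrappers call the google-genai SDK directly; the sync and async
    # clients mirror each other:
    #
    #     client.models.generate_content(model=..., contents=..., config=...)
    #     await client.aio.models.generate_content(model=..., contents=..., config=...)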

    def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
        """Stream responses from Gemini.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Yields:
            RouterResponse chunks
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)
        try:
            start_time = time.time()
            stream = self.client.models.generate_content_stream(**request_params)
            for chunk in stream:
                # Skip chunks that carry no text (e.g. metadata-only chunks).
                if chunk.text:
                    yield RouterResponse(
                        content=chunk.text,
                        model=self.model,
                        provider="gemini",
                        latency=time.time() - start_time,
                        finish_reason=None,
                        metadata={},
                        raw_response=chunk,
                    )
        except Exception as e:
            raise map_provider_error("gemini", e) from e
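
    # Illustrative synchronous streaming (a sketch):
    #
    #     router = Gemini()
    #     for chunk in router.stream("Tell me a short story"):
    #         print(chunk.content, end="", flush=True)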

    async def astream(
        self, prompt: str, **kwargs: Any
    ) -> AsyncGenerator[RouterResponse, None]:
        """Asynchronously stream responses from Gemini.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Yields:
            RouterResponse chunks
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)
        try:
            start_time = time.time()
            stream = await self.client.aio.models.generate_content_stream(**request_params)
            async for chunk in stream:
                # Skip chunks that carry no text (e.g. metadata-only chunks).
                if chunk.text:
                    yield RouterResponse(
                        content=chunk.text,
                        model=self.model,
                        provider="gemini",
                        latency=time.time() - start_time,
                        finish_reason=None,
                        metadata={},
                        raw_response=chunk,
                    )
        except Exception as e:
            raise map_provider_error("gemini", e) from e
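

# Minimal smoke test (a sketch; assumes GEMINI_API_KEY or GOOGLE_API_KEY is
# set, as required by __init__ above, and that the module is run with package
# context, e.g. `python -m router.gemini`).
if __name__ == "__main__":
    router = Gemini()
    for chunk in router.stream("Say hello in one short sentence."):
        print(chunk.content, end="", flush=True)
    print()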