"""Google Gemini provider implementation.""" from typing import Any, Dict, Optional, AsyncGenerator, Generator import time from .base import AIRouter, RouterResponse from .config import config from .exceptions import ( AuthenticationError, ConfigurationError, map_provider_error ) class Gemini(AIRouter): """Router for Google Gemini models.""" def __init__( self, model: str = "gemini-2.0-flash-001", api_key: Optional[str] = None, **kwargs: Any ) -> None: """Initialize Gemini router. Args: model: Gemini model to use (gemini-2.0-flash-001, gemini-1.5-pro, etc.) api_key: Google API key (optional if set in environment) **kwargs: Additional configuration """ # Get API key from config if not provided if not api_key: api_key = config.get_api_key("gemini") if not api_key: raise ConfigurationError( "Gemini API key not found. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.", provider="gemini" ) super().__init__(model=model, api_key=api_key, **kwargs) # Initialize Gemini client using new google.genai library try: from google import genai from google.genai import types self.genai = genai self.types = types self.client = genai.Client(api_key=api_key) except ImportError: raise ConfigurationError( "google-genai package not installed. Install with: pip install google-genai", provider="gemini" ) def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]: """Prepare Gemini request parameters. Args: prompt: User prompt **kwargs: Additional parameters Returns: Request parameters """ # Build config using new API structure config_params = {} if "temperature" in kwargs: config_params["temperature"] = kwargs["temperature"] elif hasattr(config, "default_temperature"): config_params["temperature"] = config.default_temperature if "max_tokens" in kwargs: config_params["max_output_tokens"] = kwargs["max_tokens"] elif "max_output_tokens" in kwargs: config_params["max_output_tokens"] = kwargs["max_output_tokens"] if "top_p" in kwargs: config_params["top_p"] = kwargs["top_p"] elif hasattr(config, "default_top_p"): config_params["top_p"] = config.default_top_p if "top_k" in kwargs: config_params["top_k"] = kwargs["top_k"] # Add safety settings if provided if "safety_settings" in kwargs: config_params["safety_settings"] = kwargs["safety_settings"] return { "model": kwargs.get("model", self.model), "contents": prompt, "config": self.types.GenerateContentConfig(**config_params) if config_params else None } def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse: """Parse Gemini response into RouterResponse. Args: raw_response: Raw Gemini response latency: Request latency Returns: RouterResponse """ # Extract text content content = raw_response.text if hasattr(raw_response, "text") else "" # Try to get token counts from usage_metadata input_tokens = None output_tokens = None total_tokens = None if hasattr(raw_response, "usage_metadata"): usage = raw_response.usage_metadata # Handle both old and new attribute names input_tokens = getattr(usage, "prompt_token_count", None) or getattr(usage, "cached_content_token_count", None) output_tokens = getattr(usage, "candidates_token_count", None) total_tokens = getattr(usage, "total_token_count", None) # Calculate cost if available cost = None if input_tokens and output_tokens and config.track_costs: cost = config.calculate_cost(self.model, input_tokens, output_tokens) # Extract finish reason finish_reason = "stop" if hasattr(raw_response, "candidates") and raw_response.candidates: candidate = raw_response.candidates[0] finish_reason = getattr(candidate, "finish_reason", "stop") return RouterResponse( content=content, model=raw_response.model_version if hasattr(raw_response, "model_version") else self.model, provider="gemini", latency=latency, input_tokens=input_tokens, output_tokens=output_tokens, total_tokens=total_tokens or ((input_tokens + output_tokens) if input_tokens and output_tokens else None), cost=cost, finish_reason=finish_reason, metadata={ "prompt_feedback": getattr(raw_response, "prompt_feedback", None), "safety_ratings": getattr(raw_response.candidates[0], "safety_ratings", None) if hasattr(raw_response, "candidates") and raw_response.candidates else None }, raw_response=raw_response ) def _make_request(self, request_params: Dict[str, Any]) -> Any: """Make synchronous request to Gemini. Args: request_params: Request parameters Returns: Raw Gemini response """ try: response = self.client.models.generate_content(**request_params) return response except Exception as e: raise map_provider_error("gemini", e) async def _make_async_request(self, request_params: Dict[str, Any]) -> Any: """Make asynchronous request to Gemini. Args: request_params: Request parameters Returns: Raw Gemini response """ try: response = await self.client.aio.models.generate_content(**request_params) return response except Exception as e: raise map_provider_error("gemini", e) def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]: """Stream responses from Gemini. Args: prompt: User prompt **kwargs: Additional parameters Yields: RouterResponse chunks """ params = {**self.config, **kwargs} request_params = self._prepare_request(prompt, **params) try: start_time = time.time() stream = self.client.models.generate_content_stream(**request_params) for chunk in stream: if chunk.text: yield RouterResponse( content=chunk.text, model=self.model, provider="gemini", latency=time.time() - start_time, finish_reason=None, metadata={}, raw_response=chunk ) except Exception as e: raise map_provider_error("gemini", e) async def astream( self, prompt: str, **kwargs: Any ) -> AsyncGenerator[RouterResponse, None]: """Asynchronously stream responses from Gemini. Args: prompt: User prompt **kwargs: Additional parameters Yields: RouterResponse chunks """ params = {**self.config, **kwargs} request_params = self._prepare_request(prompt, **params) try: start_time = time.time() stream = await self.client.aio.models.generate_content_stream(**request_params) async for chunk in stream: if chunk.text: yield RouterResponse( content=chunk.text, model=self.model, provider="gemini", latency=time.time() - start_time, finish_reason=None, metadata={}, raw_response=chunk ) except Exception as e: raise map_provider_error("gemini", e)