"""Cohere provider implementation."""

import time
from typing import Any, AsyncGenerator, Dict, Generator, Optional

from .base import AIRouter, RouterResponse
from .config import config
from .exceptions import ConfigurationError, map_provider_error

# Cohere-specific keyword arguments (v5.15.0 chat API) that are copied
# verbatim into the request whenever the caller supplies them.
_PASSTHROUGH_PARAMS = (
    "chat_history",
    "preamble",
    "conversation_id",
    "prompt_truncation",
    "connectors",
    "search_queries_only",
    "documents",
    "tools",
    "tool_results",
    "stop_sequences",
    "seed",
)


class Cohere(AIRouter):
    """Router for Cohere models."""

    def __init__(
        self,
        model: str = "command-r-plus",
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize Cohere router.

        Args:
            model: Cohere model to use (command-r-plus, command-r, etc.)
            api_key: Cohere API key. When omitted (or empty), it is looked up
                via ``config.get_api_key("cohere")``.
            **kwargs: Additional configuration forwarded to ``AIRouter``.

        Raises:
            ConfigurationError: If no API key can be found, or if the
                ``cohere`` package is not installed.
        """
        # Fall back to the configured key when none is passed explicitly.
        if not api_key:
            api_key = config.get_api_key("cohere")
        if not api_key:
            raise ConfigurationError(
                "Cohere API key not found. Set COHERE_API_KEY environment variable.",
                provider="cohere",
            )

        super().__init__(model=model, api_key=api_key, **kwargs)

        # Import lazily so the package is only required when this provider
        # is actually used. Only the import itself is guarded.
        try:
            import cohere
        except ImportError as err:
            raise ConfigurationError(
                "cohere package not installed. Install with: pip install cohere",
                provider="cohere",
            ) from err

        # For v5.15.0, use the standard Client
        self.client = cohere.Client(api_key)
        self.async_client = cohere.AsyncClient(api_key)

    def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        """Prepare Cohere request parameters.

        Explicit keyword arguments take precedence over configured defaults.
        A configured default of ``None`` means "unset" and is not sent.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters. ``top_p`` / ``top_k`` are accepted
                as aliases for Cohere's ``p`` / ``k``.

        Returns:
            Request parameters for ``client.chat`` / ``client.chat_stream``.
        """
        params: Dict[str, Any] = {
            "model": kwargs.get("model", self.model),
            "message": prompt,  # v5.15.0 uses 'message' not 'messages'
        }

        if "temperature" in kwargs:
            params["temperature"] = kwargs["temperature"]
        else:
            default_temperature = getattr(config, "default_temperature", None)
            if default_temperature is not None:
                params["temperature"] = default_temperature

        if "max_tokens" in kwargs:
            params["max_tokens"] = kwargs["max_tokens"]
        elif getattr(config, "default_max_tokens", None):
            params["max_tokens"] = config.default_max_tokens

        # Cohere uses 'p' instead of 'top_p'
        if "p" in kwargs:
            params["p"] = kwargs["p"]
        elif "top_p" in kwargs:
            params["p"] = kwargs["top_p"]
        else:
            default_top_p = getattr(config, "default_top_p", None)
            if default_top_p is not None:
                params["p"] = default_top_p

        # Cohere uses 'k' instead of 'top_k'
        if "k" in kwargs:
            params["k"] = kwargs["k"]
        elif "top_k" in kwargs:
            params["k"] = kwargs["top_k"]

        params.update(
            {key: kwargs[key] for key in _PASSTHROUGH_PARAMS if key in kwargs}
        )
        return params

    @staticmethod
    def _billed_units(raw_response: Any) -> Any:
        """Return the first non-None ``billed_units`` found under ``meta`` or ``usage``.

        Returns None when neither location carries billing information.
        """
        for container_name in ("meta", "usage"):
            container = getattr(raw_response, container_name, None)
            billed = getattr(container, "billed_units", None)
            if billed is not None:
                return billed
        return None

    def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
        """Parse Cohere response into RouterResponse.

        Args:
            raw_response: Raw Cohere response
            latency: Request latency

        Returns:
            RouterResponse
        """
        # v5.15.0 exposes the generated text directly as 'text'.
        content = getattr(raw_response, "text", "") or ""

        billed = self._billed_units(raw_response)
        input_tokens = getattr(billed, "input_tokens", None)
        output_tokens = getattr(billed, "output_tokens", None)

        # Use explicit None checks: a count of 0 is a valid value and must not
        # suppress the total or the cost.
        have_counts = input_tokens is not None and output_tokens is not None
        total_tokens = input_tokens + output_tokens if have_counts else None

        cost = None
        if have_counts and config.track_costs:
            cost = config.calculate_cost(self.model, input_tokens, output_tokens)

        # Default to "stop" when the response has no (or an empty) finish reason.
        finish_reason = getattr(raw_response, "finish_reason", None) or "stop"

        return RouterResponse(
            content=content,
            model=self.model,
            provider="cohere",
            latency=latency,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
            cost=cost,
            finish_reason=finish_reason,
            metadata={
                "id": getattr(raw_response, "id", None),
                "generation_id": getattr(raw_response, "generation_id", None),
                "citations": getattr(raw_response, "citations", None),
                "documents": getattr(raw_response, "documents", None),
                "search_results": getattr(raw_response, "search_results", None),
                "search_queries": getattr(raw_response, "search_queries", None),
                "tool_calls": getattr(raw_response, "tool_calls", None),
            },
            raw_response=raw_response,
        )

    def _make_request(self, request_params: Dict[str, Any]) -> Any:
        """Make synchronous request to Cohere.

        Args:
            request_params: Request parameters

        Returns:
            Raw Cohere response

        Raises:
            Whatever ``map_provider_error`` maps the underlying client error to.
        """
        try:
            return self.client.chat(**request_params)
        except Exception as e:
            raise map_provider_error("cohere", e) from e

    async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
        """Make asynchronous request to Cohere.

        Args:
            request_params: Request parameters

        Returns:
            Raw Cohere response

        Raises:
            Whatever ``map_provider_error`` maps the underlying client error to.
        """
        try:
            return await self.async_client.chat(**request_params)
        except Exception as e:
            raise map_provider_error("cohere", e) from e

    def _stream_chunk(self, event: Any, start_time: float) -> Optional[RouterResponse]:
        """Convert one streamed chat event into a RouterResponse chunk.

        Only ``text-generation`` events with non-empty text produce a chunk;
        every other event returns None so the caller can skip it.

        Args:
            event: A single event from ``chat_stream``.
            start_time: ``time.perf_counter()`` value taken when streaming began.
        """
        # v5.15.0 uses event_type and text directly
        if getattr(event, "event_type", None) != "text-generation":
            return None
        text = getattr(event, "text", "")
        if not text:
            return None
        return RouterResponse(
            content=text,
            model=self.model,
            provider="cohere",
            latency=time.perf_counter() - start_time,
            finish_reason=None,
            metadata={"event_type": event.event_type},
            raw_response=event,
        )

    def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
        """Stream responses from Cohere.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters (override ``self.config`` entries)

        Yields:
            RouterResponse chunks, one per non-empty text-generation event.
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)

        try:
            # Monotonic clock: latency must not jump with wall-clock changes.
            start_time = time.perf_counter()
            for event in self.client.chat_stream(**request_params):
                chunk = self._stream_chunk(event, start_time)
                if chunk is not None:
                    yield chunk
        except Exception as e:
            raise map_provider_error("cohere", e) from e

    async def astream(
        self, prompt: str, **kwargs: Any
    ) -> AsyncGenerator[RouterResponse, None]:
        """Asynchronously stream responses from Cohere.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters (override ``self.config`` entries)

        Yields:
            RouterResponse chunks, one per non-empty text-generation event.
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)

        try:
            # Monotonic clock: latency must not jump with wall-clock changes.
            start_time = time.perf_counter()
            async for event in self.async_client.chat_stream(**request_params):
                chunk = self._stream_chunk(event, start_time)
                if chunk is not None:
                    yield chunk
        except Exception as e:
            raise map_provider_error("cohere", e) from e