template-with-ai/router/cohere.py
2025-07-01 16:22:20 +05:30

259 lines
No EOL
9.4 KiB
Python

"""Cohere provider implementation."""
from typing import Any, Dict, Optional, AsyncGenerator, Generator
import time
from .base import AIRouter, RouterResponse
from .config import config
from .exceptions import (
ConfigurationError,
map_provider_error
)
class Cohere(AIRouter):
    """Router for Cohere chat models using the v5.x SDK (`Client.chat` API)."""

    def __init__(
        self,
        model: str = "command-r-plus",
        api_key: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """Initialize Cohere router.

        Args:
            model: Cohere model to use (command-r-plus, command-r, etc.)
            api_key: Cohere API key (optional if set in environment)
            **kwargs: Additional configuration forwarded to AIRouter

        Raises:
            ConfigurationError: If no API key can be resolved or the
                ``cohere`` package is not installed.
        """
        # Get API key from config if not provided
        if not api_key:
            api_key = config.get_api_key("cohere")
            if not api_key:
                raise ConfigurationError(
                    "Cohere API key not found. Set COHERE_API_KEY environment variable.",
                    provider="cohere"
                )
        super().__init__(model=model, api_key=api_key, **kwargs)
        # Initialize Cohere client lazily so the dependency stays optional
        try:
            import cohere
            # For v5.15.0, use the standard Client
            self.client = cohere.Client(api_key)
            self.async_client = cohere.AsyncClient(api_key)
        except ImportError as e:
            # Chain the original ImportError for easier debugging
            raise ConfigurationError(
                "cohere package not installed. Install with: pip install cohere",
                provider="cohere"
            ) from e

    def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        """Prepare Cohere request parameters.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters (per-call overrides)

        Returns:
            Request parameters suitable for ``Client.chat(**params)``
        """
        # Build request parameters for v5.15.0 API
        params: Dict[str, Any] = {
            "model": kwargs.get("model", self.model),
            "message": prompt,  # v5.15.0 uses 'message' not 'messages'
        }
        # Optional sampling parameters: explicit kwargs win over config defaults
        if "temperature" in kwargs:
            params["temperature"] = kwargs["temperature"]
        elif hasattr(config, "default_temperature"):
            params["temperature"] = config.default_temperature
        if "max_tokens" in kwargs:
            params["max_tokens"] = kwargs["max_tokens"]
        elif hasattr(config, "default_max_tokens") and config.default_max_tokens:
            params["max_tokens"] = config.default_max_tokens
        # Cohere uses 'p' instead of 'top_p'
        if "p" in kwargs:
            params["p"] = kwargs["p"]
        elif "top_p" in kwargs:
            params["p"] = kwargs["top_p"]
        elif hasattr(config, "default_top_p"):
            params["p"] = config.default_top_p
        # Cohere uses 'k' instead of 'top_k'
        if "k" in kwargs:
            params["k"] = kwargs["k"]
        elif "top_k" in kwargs:
            params["k"] = kwargs["top_k"]
        # Other Cohere-specific parameters for v5.15.0, passed through verbatim
        for key in (
            "chat_history", "preamble", "conversation_id", "prompt_truncation",
            "connectors", "search_queries_only", "documents", "tools",
            "tool_results", "stop_sequences", "seed",
        ):
            if key in kwargs:
                params[key] = kwargs[key]
        return params

    def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
        """Parse Cohere response into RouterResponse.

        Args:
            raw_response: Raw Cohere response
            latency: Request latency in seconds

        Returns:
            RouterResponse with content, token counts, cost and metadata
        """
        # Extract text content - v5.15.0 has 'text' attribute directly
        content = getattr(raw_response, "text", "")
        # Get token counts from meta/usage
        input_tokens: Optional[int] = None
        output_tokens: Optional[int] = None
        # Try different token count locations based on v5.15.0 structure
        if hasattr(raw_response, "meta") and hasattr(raw_response.meta, "billed_units"):
            billed = raw_response.meta.billed_units
            input_tokens = getattr(billed, "input_tokens", None)
            output_tokens = getattr(billed, "output_tokens", None)
        elif hasattr(raw_response, "usage") and hasattr(raw_response.usage, "billed_units"):
            billed = raw_response.usage.billed_units
            input_tokens = getattr(billed, "input_tokens", None)
            output_tokens = getattr(billed, "output_tokens", None)
        # Compute totals/cost only when both counts are present.
        # NOTE: use explicit None checks — a legitimate count of 0 is falsy
        # and would otherwise silently skip cost/total computation.
        have_counts = input_tokens is not None and output_tokens is not None
        cost = None
        if have_counts and config.track_costs:
            cost = config.calculate_cost(self.model, input_tokens, output_tokens)
        # Extract finish reason (default to "stop" when absent)
        finish_reason = "stop"
        if hasattr(raw_response, "finish_reason"):
            finish_reason = raw_response.finish_reason
        return RouterResponse(
            content=content,
            model=self.model,
            provider="cohere",
            latency=latency,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=(input_tokens + output_tokens) if have_counts else None,
            cost=cost,
            finish_reason=finish_reason,
            metadata={
                "id": getattr(raw_response, "id", None),
                "generation_id": getattr(raw_response, "generation_id", None),
                "citations": getattr(raw_response, "citations", None),
                "documents": getattr(raw_response, "documents", None),
                "search_results": getattr(raw_response, "search_results", None),
                "search_queries": getattr(raw_response, "search_queries", None),
                "tool_calls": getattr(raw_response, "tool_calls", None),
            },
            raw_response=raw_response
        )

    def _make_request(self, request_params: Dict[str, Any]) -> Any:
        """Make synchronous request to Cohere.

        Args:
            request_params: Request parameters

        Returns:
            Raw Cohere response

        Raises:
            Provider-mapped error via ``map_provider_error`` on any failure.
        """
        try:
            return self.client.chat(**request_params)
        except Exception as e:
            # Preserve the original exception as __cause__ for debugging
            raise map_provider_error("cohere", e) from e

    async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
        """Make asynchronous request to Cohere.

        Args:
            request_params: Request parameters

        Returns:
            Raw Cohere response

        Raises:
            Provider-mapped error via ``map_provider_error`` on any failure.
        """
        try:
            return await self.async_client.chat(**request_params)
        except Exception as e:
            raise map_provider_error("cohere", e) from e

    def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
        """Stream responses from Cohere.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters (merged over self.config)

        Yields:
            RouterResponse chunks, one per text-generation event; latency on
            each chunk is cumulative time since the stream started.
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)
        try:
            start_time = time.time()
            stream = self.client.chat_stream(**request_params)
            for event in stream:
                # v5.15.0 uses event_type and text directly
                if hasattr(event, "event_type") and event.event_type == "text-generation":
                    content = getattr(event, "text", "")
                    if content:
                        yield RouterResponse(
                            content=content,
                            model=self.model,
                            provider="cohere",
                            latency=time.time() - start_time,
                            finish_reason=None,
                            metadata={"event_type": event.event_type},
                            raw_response=event
                        )
        except Exception as e:
            raise map_provider_error("cohere", e) from e

    async def astream(
        self, prompt: str, **kwargs: Any
    ) -> AsyncGenerator[RouterResponse, None]:
        """Asynchronously stream responses from Cohere.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters (merged over self.config)

        Yields:
            RouterResponse chunks, one per text-generation event; latency on
            each chunk is cumulative time since the stream started.
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)
        try:
            start_time = time.time()
            stream = self.async_client.chat_stream(**request_params)
            async for event in stream:
                # v5.15.0 uses event_type and text directly
                if hasattr(event, "event_type") and event.event_type == "text-generation":
                    content = getattr(event, "text", "")
                    if content:
                        yield RouterResponse(
                            content=content,
                            model=self.model,
                            provider="cohere",
                            latency=time.time() - start_time,
                            finish_reason=None,
                            metadata={"event_type": event.event_type},
                            raw_response=event
                        )
        except Exception as e:
            raise map_provider_error("cohere", e) from e