259 lines
No EOL
9.4 KiB
Python
259 lines
No EOL
9.4 KiB
Python
"""Cohere provider implementation."""
|
|
|
|
from typing import Any, Dict, Optional, AsyncGenerator, Generator
|
|
import time
|
|
|
|
from .base import AIRouter, RouterResponse
|
|
from .config import config
|
|
from .exceptions import (
|
|
ConfigurationError,
|
|
map_provider_error
|
|
)
|
|
|
|
|
|
class Cohere(AIRouter):
    """Router for Cohere chat models using the v5.x ``cohere.Client`` API."""

    def __init__(
        self,
        model: str = "command-r-plus",
        api_key: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """Initialize Cohere router.

        Args:
            model: Cohere model to use (command-r-plus, command-r, etc.)
            api_key: Cohere API key (optional if set in environment)
            **kwargs: Additional configuration

        Raises:
            ConfigurationError: If no API key is available or the ``cohere``
                package is not installed.
        """
        # Get API key from config if not provided
        if not api_key:
            api_key = config.get_api_key("cohere")

        if not api_key:
            raise ConfigurationError(
                "Cohere API key not found. Set COHERE_API_KEY environment variable.",
                provider="cohere"
            )

        super().__init__(model=model, api_key=api_key, **kwargs)

        # Initialize Cohere client eagerly so a missing dependency surfaces
        # at construction time rather than on the first request.
        try:
            import cohere
            # For v5.15.0, use the standard Client (not ClientV2)
            self.client = cohere.Client(api_key)
            self.async_client = cohere.AsyncClient(api_key)
        except ImportError:
            # `from None` suppresses the noisy "during handling of the above
            # exception" chained traceback — the message says it all.
            raise ConfigurationError(
                "cohere package not installed. Install with: pip install cohere",
                provider="cohere"
            ) from None

    def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        """Prepare Cohere request parameters.

        Translates generic router kwargs (``top_p``/``top_k``) into Cohere's
        parameter names (``p``/``k``) and fills defaults from ``config``.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Returns:
            Request parameters suitable for ``client.chat(**params)``
        """
        # Build request parameters for v5.15.0 API
        params: Dict[str, Any] = {
            "model": kwargs.get("model", self.model),
            "message": prompt,  # v5.15.0 uses 'message' not 'messages'
        }

        # Add optional parameters, falling back to config-level defaults.
        if "temperature" in kwargs:
            params["temperature"] = kwargs["temperature"]
        elif hasattr(config, "default_temperature"):
            params["temperature"] = config.default_temperature

        if "max_tokens" in kwargs:
            params["max_tokens"] = kwargs["max_tokens"]
        elif hasattr(config, "default_max_tokens") and config.default_max_tokens:
            params["max_tokens"] = config.default_max_tokens

        # Cohere uses 'p' instead of 'top_p'; an explicit 'p' wins.
        if "p" in kwargs:
            params["p"] = kwargs["p"]
        elif "top_p" in kwargs:
            params["p"] = kwargs["top_p"]
        elif hasattr(config, "default_top_p"):
            params["p"] = config.default_top_p

        # Cohere uses 'k' instead of 'top_k'; an explicit 'k' wins.
        if "k" in kwargs:
            params["k"] = kwargs["k"]
        elif "top_k" in kwargs:
            params["k"] = kwargs["top_k"]

        # Other Cohere-specific parameters for v5.15.0, passed through as-is.
        for key in ["chat_history", "preamble", "conversation_id", "prompt_truncation", "connectors", "search_queries_only", "documents", "tools", "tool_results", "stop_sequences", "seed"]:
            if key in kwargs:
                params[key] = kwargs[key]

        return params

    def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
        """Parse Cohere response into RouterResponse.

        Args:
            raw_response: Raw Cohere response
            latency: Request latency in seconds

        Returns:
            RouterResponse
        """
        # Extract text content - v5.15.0 has 'text' attribute directly
        content = getattr(raw_response, "text", "")

        # Get token counts from meta/usage
        input_tokens: Optional[int] = None
        output_tokens: Optional[int] = None

        # Try different token count locations based on v5.15.0 structure
        if hasattr(raw_response, "meta") and hasattr(raw_response.meta, "billed_units"):
            billed = raw_response.meta.billed_units
            input_tokens = getattr(billed, "input_tokens", None)
            output_tokens = getattr(billed, "output_tokens", None)
        elif hasattr(raw_response, "usage") and hasattr(raw_response.usage, "billed_units"):
            billed = raw_response.usage.billed_units
            input_tokens = getattr(billed, "input_tokens", None)
            output_tokens = getattr(billed, "output_tokens", None)

        # Use explicit None checks: a count of 0 is valid data and must not
        # be treated as "missing" (truthiness would drop cost/total_tokens).
        have_tokens = input_tokens is not None and output_tokens is not None

        # Calculate cost if token counts are available and tracking is on.
        cost = None
        if have_tokens and config.track_costs:
            cost = config.calculate_cost(self.model, input_tokens, output_tokens)

        # Extract finish reason (defaults to "stop" when absent)
        finish_reason = "stop"
        if hasattr(raw_response, "finish_reason"):
            finish_reason = raw_response.finish_reason

        return RouterResponse(
            content=content,
            model=self.model,
            provider="cohere",
            latency=latency,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=(input_tokens + output_tokens) if have_tokens else None,
            cost=cost,
            finish_reason=finish_reason,
            metadata={
                "id": getattr(raw_response, "id", None),
                "generation_id": getattr(raw_response, "generation_id", None),
                "citations": getattr(raw_response, "citations", None),
                "documents": getattr(raw_response, "documents", None),
                "search_results": getattr(raw_response, "search_results", None),
                "search_queries": getattr(raw_response, "search_queries", None),
                "tool_calls": getattr(raw_response, "tool_calls", None),
            },
            raw_response=raw_response
        )

    def _make_request(self, request_params: Dict[str, Any]) -> Any:
        """Make synchronous request to Cohere.

        Args:
            request_params: Request parameters

        Returns:
            Raw Cohere response

        Raises:
            RouterError subclass produced by ``map_provider_error``.
        """
        try:
            return self.client.chat(**request_params)
        except Exception as e:
            # Chain the original exception so callers can inspect __cause__.
            raise map_provider_error("cohere", e) from e

    async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
        """Make asynchronous request to Cohere.

        Args:
            request_params: Request parameters

        Returns:
            Raw Cohere response

        Raises:
            RouterError subclass produced by ``map_provider_error``.
        """
        try:
            return await self.async_client.chat(**request_params)
        except Exception as e:
            # Chain the original exception so callers can inspect __cause__.
            raise map_provider_error("cohere", e) from e

    def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
        """Stream responses from Cohere.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Yields:
            RouterResponse chunks (one per text-generation event; latency is
            cumulative time since the stream started)
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)

        try:
            start_time = time.time()
            stream = self.client.chat_stream(**request_params)

            for event in stream:
                # v5.15.0 uses event_type and text directly
                if hasattr(event, "event_type") and event.event_type == "text-generation":
                    content = getattr(event, "text", "")
                    if content:
                        yield RouterResponse(
                            content=content,
                            model=self.model,
                            provider="cohere",
                            latency=time.time() - start_time,
                            finish_reason=None,
                            metadata={"event_type": event.event_type},
                            raw_response=event
                        )
        except Exception as e:
            # Chain the original exception so callers can inspect __cause__.
            raise map_provider_error("cohere", e) from e

    async def astream(
        self, prompt: str, **kwargs: Any
    ) -> AsyncGenerator[RouterResponse, None]:
        """Asynchronously stream responses from Cohere.

        Args:
            prompt: User prompt
            **kwargs: Additional parameters

        Yields:
            RouterResponse chunks (one per text-generation event; latency is
            cumulative time since the stream started)
        """
        params = {**self.config, **kwargs}
        request_params = self._prepare_request(prompt, **params)

        try:
            start_time = time.time()
            stream = self.async_client.chat_stream(**request_params)

            async for event in stream:
                # v5.15.0 uses event_type and text directly
                if hasattr(event, "event_type") and event.event_type == "text-generation":
                    content = getattr(event, "text", "")
                    if content:
                        yield RouterResponse(
                            content=content,
                            model=self.model,
                            provider="cohere",
                            latency=time.time() - start_time,
                            finish_reason=None,
                            metadata={"event_type": event.event_type},
                            raw_response=event
                        )
        except Exception as e:
            # Chain the original exception so callers can inspect __cause__.
            raise map_provider_error("cohere", e) from e