"""Reranking router implementation for Cohere Rerank models."""

from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, field
import time

from .config import config
from .exceptions import (
    ConfigurationError,
    map_provider_error
)


@dataclass
class RerankDocument:
    """Document to be reranked."""

    text: str
    id: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class RerankResult:
    """Single reranked result."""

    index: int  # Original index in the documents list
    relevance_score: float
    document: RerankDocument


@dataclass
class RerankResponse:
    """Response from reranking operation."""

    results: List[RerankResult]
    model: str  # Model the request was sent with (honours per-call overrides)
    provider: str
    latency: float  # Duration of the API call, in seconds
    num_documents: int  # Number of documents submitted (not the number returned)
    cost: Optional[float] = None  # None unless config.track_costs is enabled
    metadata: Dict[str, Any] = field(default_factory=dict)
    raw_response: Optional[Any] = None


class CohereRerank:
    """Router for Cohere Rerank models.

    Holds a synchronous ``cohere.Client`` and an asynchronous
    ``cohere.AsyncClient`` and converts their rerank responses into
    :class:`RerankResponse` objects.
    """

    def __init__(
        self,
        model: str = "rerank-english-v3.0",
        api_key: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """Initialize Cohere Rerank router.

        Args:
            model: Rerank model to use (rerank-3.5, rerank-english-v3.0,
                rerank-multilingual-v3.0)
            api_key: Cohere API key (optional if set in environment)
            **kwargs: Additional configuration, stored on ``self.config``

        Raises:
            ConfigurationError: If no API key can be resolved or the
                ``cohere`` package is not installed.
        """
        # Get API key from config if not provided
        if not api_key:
            api_key = config.get_api_key("cohere")
        if not api_key:
            raise ConfigurationError(
                "Cohere API key not found. Set COHERE_API_KEY environment variable.",
                provider="cohere"
            )

        self.model = model
        self.api_key = api_key
        self.config = kwargs

        # Imported here rather than at module level, so this module can be
        # imported even when the cohere package is absent.
        try:
            import cohere

            # For v5.15.0, use the standard Client
            self.client = cohere.Client(api_key)
            self.async_client = cohere.AsyncClient(api_key)
        except ImportError as e:
            raise ConfigurationError(
                "cohere package not installed. Install with: pip install cohere",
                provider="cohere"
            ) from e

    def rerank(
        self,
        query: str,
        documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
        top_n: Optional[int] = None,
        **kwargs: Any
    ) -> RerankResponse:
        """Rerank documents based on relevance to query.

        Args:
            query: The search query
            documents: Documents to rerank (strings, dicts, or RerankDocument
                objects); any iterable of these is accepted
            top_n: Number of top results to return (None returns all)
            **kwargs: Additional parameters (model, max_chunks_per_doc,
                rank_fields, return_documents)

        Returns:
            RerankResponse with reranked results

        Raises:
            ValueError: If the query is empty or a document has an
                unsupported type.

        Errors raised while calling Cohere or parsing its response are
        converted with ``map_provider_error("cohere", ...)`` and re-raised.
        """
        # Materialize once so the exact sequence that is sent is also the one
        # used to look up originals (text fallback, id/metadata) by index.
        docs = list(documents)
        params = self._prepare_request(query, docs, top_n, **kwargs)

        try:
            start_time = time.perf_counter()
            response = self.client.rerank(**params)
            latency = time.perf_counter() - start_time
            return self._parse_response(
                response,
                latency,
                len(params["documents"]),
                docs,
                model=params["model"],
            )
        except Exception as e:
            raise map_provider_error("cohere", e)

    async def arerank(
        self,
        query: str,
        documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
        top_n: Optional[int] = None,
        **kwargs: Any
    ) -> RerankResponse:
        """Asynchronously rerank documents based on relevance to query.

        Args:
            query: The search query
            documents: Documents to rerank (strings, dicts, or RerankDocument
                objects); any iterable of these is accepted
            top_n: Number of top results to return (None returns all)
            **kwargs: Additional parameters (model, max_chunks_per_doc,
                rank_fields, return_documents)

        Returns:
            RerankResponse with reranked results

        Raises:
            ValueError: If the query is empty or a document has an
                unsupported type.

        Errors raised while calling Cohere or parsing its response are
        converted with ``map_provider_error("cohere", ...)`` and re-raised.
        """
        # See rerank(): send and index into the same materialized sequence.
        docs = list(documents)
        params = self._prepare_request(query, docs, top_n, **kwargs)

        try:
            start_time = time.perf_counter()
            response = await self.async_client.rerank(**params)
            latency = time.perf_counter() - start_time
            return self._parse_response(
                response,
                latency,
                len(params["documents"]),
                docs,
                model=params["model"],
            )
        except Exception as e:
            raise map_provider_error("cohere", e)

    def _prepare_request(
        self,
        query: str,
        documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
        top_n: Optional[int] = None,
        **kwargs: Any
    ) -> Dict[str, Any]:
        """Prepare rerank request parameters.

        Args:
            query: The search query
            documents: Documents to rerank
            top_n: Number of results to return
            **kwargs: Additional parameters; ``model`` overrides ``self.model``,
                ``max_chunks_per_doc`` and ``rank_fields`` are forwarded, and
                ``return_documents=False`` disables echoed documents

        Returns:
            Keyword arguments for the Cohere ``rerank`` call

        Raises:
            ValueError: If the query is empty/whitespace or a document is not
                a str, dict, or RerankDocument.
        """
        # Validate query is not empty
        if not query or not query.strip():
            raise ValueError("Query cannot be empty")

        # Convert documents to the format expected by Cohere
        formatted_docs: List[Dict[str, Any]] = []
        for i, doc in enumerate(documents):
            if isinstance(doc, str):
                formatted_docs.append({"text": doc})
            elif isinstance(doc, RerankDocument):
                # Only the text is sent; id/metadata are re-attached from the
                # original document when the response is parsed.
                formatted_docs.append({"text": doc.text})
            elif isinstance(doc, dict):
                # Passed through unchanged; assumed to have at least a 'text' field
                formatted_docs.append(doc)
            else:
                raise ValueError(f"Invalid document type at index {i}: {type(doc)}")

        # Build request parameters
        params: Dict[str, Any] = {
            "model": kwargs.get("model", self.model),
            "query": query,
            "documents": formatted_docs,
            "return_documents": True,  # Always return documents for parsing
        }

        if top_n is not None:
            params["top_n"] = top_n

        # Add optional parameters for v5.15.0
        for key in ("max_chunks_per_doc", "rank_fields"):
            if key in kwargs:
                params[key] = kwargs[key]

        # Allow override of return_documents only when explicitly set to False
        if kwargs.get("return_documents") is False:
            params["return_documents"] = False

        return params

    def _parse_response(
        self,
        raw_response: Any,
        latency: float,
        num_documents: int,
        original_documents: Optional[List[Any]] = None,
        model: Optional[str] = None,
    ) -> RerankResponse:
        """Parse Cohere rerank response.

        Args:
            raw_response: Raw response from Cohere
            latency: Request latency, in seconds
            num_documents: Total number of documents submitted
            original_documents: Original documents, used as a fallback source
                of text and as the source of each result's id/metadata
            model: Model the request was sent with; defaults to ``self.model``

        Returns:
            RerankResponse
        """
        # Report (and price) the model that was actually requested, which may
        # differ from self.model when a per-call "model" kwarg was given.
        used_model = model or self.model
        results: List[RerankResult] = []

        # Parse results from v5.15.0 response format
        for result in getattr(raw_response, "results", None) or []:
            original = self._original_document(original_documents, result.index)
            text = self._returned_text(getattr(result, "document", None))
            if not text and original is not None:
                # e.g. return_documents=False: recover the text from our input
                text = self._original_text(original)

            results.append(RerankResult(
                index=result.index,
                relevance_score=result.relevance_score,
                document=self._build_document(text or "", original),
            ))

        # Calculate cost
        cost = None
        if config.track_costs:
            # Reranking is charged per search (1 search = 1 query across N documents)
            cost = config.calculate_rerank_cost(used_model, 1)

        return RerankResponse(
            results=results,
            model=used_model,
            provider="cohere",
            latency=latency,
            num_documents=num_documents,
            cost=cost,
            metadata={
                "id": getattr(raw_response, "id", None),
                "meta": getattr(raw_response, "meta", None),
            },
            raw_response=raw_response
        )

    @staticmethod
    def _original_document(original_documents: Optional[List[Any]], index: int) -> Any:
        """Return the caller's document at ``index``, or None if unavailable."""
        if original_documents and 0 <= index < len(original_documents):
            return original_documents[index]
        return None

    @staticmethod
    def _returned_text(document: Any) -> str:
        """Extract text from a document echoed back by Cohere, or return ''."""
        if document is None:
            return ""
        if hasattr(document, "text"):
            return document.text or ""
        if isinstance(document, dict):
            return document.get("text") or ""
        return ""

    @staticmethod
    def _original_text(document: Any) -> str:
        """Extract text from a caller-supplied document (str, dict, or object)."""
        if isinstance(document, str):
            return document
        if isinstance(document, dict):
            return document.get("text") or ""
        return getattr(document, "text", "") or ""

    @staticmethod
    def _build_document(text: str, original: Any) -> RerankDocument:
        """Build a result document, carrying over id/metadata from the input."""
        if isinstance(original, RerankDocument):
            # Shallow-copy metadata so mutating a result cannot alter the input.
            return RerankDocument(
                text=text,
                id=original.id,
                metadata=dict(original.metadata),
            )
        if isinstance(original, dict):
            metadata = original.get("metadata")
            return RerankDocument(
                text=text,
                id=original.get("id"),
                metadata=dict(metadata) if isinstance(metadata, dict) else {},
            )
        return RerankDocument(text=text)


# Convenience alias
Rerank = CohereRerank