template-with-ai/router/rerank.py
2025-07-01 16:22:20 +05:30

245 lines
No EOL
7.6 KiB
Python

"""Reranking router implementation for Cohere Rerank models."""
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, field
import time
from .config import config
from .exceptions import (
ConfigurationError,
map_provider_error
)
@dataclass
class RerankDocument:
    """Document to be reranked.

    Only ``text`` is required; ``id`` and ``metadata`` are optional
    caller-side bookkeeping fields.
    """

    # The document content that is scored against the query.
    text: str
    # Optional caller-supplied identifier for the document.
    id: Optional[str] = None
    # Free-form extra data; each instance gets its own fresh dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class RerankResult:
    """Single reranked result."""

    # Original index of the document in the submitted documents list.
    index: int
    # Relevance score reported for this document by the rerank model.
    relevance_score: float
    # The document this result refers to.
    document: RerankDocument
@dataclass
class RerankResponse:
    """Response from reranking operation."""

    # Reranked results, in the order they were parsed from the provider response.
    results: List[RerankResult]
    # Name of the rerank model associated with this response.
    model: str
    # Provider identifier (e.g. "cohere").
    provider: str
    # Elapsed time of the provider call, in seconds.
    latency: float
    # Total number of documents submitted for reranking.
    num_documents: int
    # Cost of the request when cost tracking is enabled, otherwise None.
    cost: Optional[float] = None
    # Extra provider information; each instance gets its own fresh dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
    # The unmodified provider response object, if kept.
    raw_response: Optional[Any] = None
class CohereRerank:
    """Router for Cohere Rerank models.

    Wraps a synchronous and an asynchronous Cohere client and converts their
    rerank responses into ``RerankResponse`` objects.
    """

    def __init__(
        self,
        model: str = "rerank-english-v3.0",
        api_key: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """Initialize Cohere Rerank router.

        Args:
            model: Rerank model to use (rerank-3.5, rerank-english-v3.0,
                rerank-multilingual-v3.0)
            api_key: Cohere API key (optional if set in environment)
            **kwargs: Additional configuration, stored unchanged on
                ``self.config``

        Raises:
            ConfigurationError: If no API key can be found or the ``cohere``
                package is not installed.
        """
        # Fall back to the configured key when none was passed explicitly.
        api_key = api_key or config.get_api_key("cohere")
        if not api_key:
            raise ConfigurationError(
                "Cohere API key not found. Set COHERE_API_KEY environment variable.",
                provider="cohere"
            )

        self.model = model
        self.api_key = api_key
        self.config = kwargs

        # Import lazily so the package is only required when this router is used.
        try:
            import cohere
        except ImportError as exc:
            raise ConfigurationError(
                "cohere package not installed. Install with: pip install cohere",
                provider="cohere"
            ) from exc

        # For v5.15.0, use the standard Client
        self.client = cohere.Client(api_key)
        self.async_client = cohere.AsyncClient(api_key)

    def rerank(
        self,
        query: str,
        documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
        top_n: Optional[int] = None,
        **kwargs: Any
    ) -> RerankResponse:
        """Rerank documents based on relevance to query.

        Args:
            query: The search query
            documents: List of documents to rerank (strings, dicts, or
                RerankDocument objects)
            top_n: Number of top results to return (None returns all)
            **kwargs: Additional parameters (model, max_chunks_per_doc,
                return_documents, rank_fields)

        Returns:
            RerankResponse with reranked results

        Raises:
            ValueError: If a document is not a str, dict, or RerankDocument.
            Exception: Whatever ``map_provider_error`` produces for a failure
                raised while calling the provider or parsing its response.
        """
        # Materialize once so results can be mapped back by index afterwards.
        documents = list(documents)
        params = self._prepare_request(query, documents, top_n, **kwargs)
        try:
            # perf_counter is monotonic, so latency is immune to clock changes.
            start_time = time.perf_counter()
            response = self.client.rerank(**params)
            latency = time.perf_counter() - start_time
            return self._parse_response(
                response,
                latency,
                len(params["documents"]),
                model=params["model"],
                documents=documents,
            )
        except Exception as e:
            raise map_provider_error("cohere", e) from e

    async def arerank(
        self,
        query: str,
        documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
        top_n: Optional[int] = None,
        **kwargs: Any
    ) -> RerankResponse:
        """Asynchronously rerank documents based on relevance to query.

        Args:
            query: The search query
            documents: List of documents to rerank (strings, dicts, or
                RerankDocument objects)
            top_n: Number of top results to return (None returns all)
            **kwargs: Additional parameters (model, max_chunks_per_doc,
                return_documents, rank_fields)

        Returns:
            RerankResponse with reranked results

        Raises:
            ValueError: If a document is not a str, dict, or RerankDocument.
            Exception: Whatever ``map_provider_error`` produces for a failure
                raised while calling the provider or parsing its response.
        """
        # Materialize once so results can be mapped back by index afterwards.
        documents = list(documents)
        params = self._prepare_request(query, documents, top_n, **kwargs)
        try:
            # perf_counter is monotonic, so latency is immune to clock changes.
            start_time = time.perf_counter()
            response = await self.async_client.rerank(**params)
            latency = time.perf_counter() - start_time
            return self._parse_response(
                response,
                latency,
                len(params["documents"]),
                model=params["model"],
                documents=documents,
            )
        except Exception as e:
            raise map_provider_error("cohere", e) from e

    def _prepare_request(
        self,
        query: str,
        documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
        top_n: Optional[int] = None,
        **kwargs: Any
    ) -> Dict[str, Any]:
        """Prepare rerank request parameters.

        Strings and RerankDocument objects are converted to ``{"text": ...}``
        dicts; dict documents are forwarded unchanged.

        Args:
            query: The search query
            documents: Documents to rerank
            top_n: Number of results to return; omitted from the request
                when None
            **kwargs: Additional parameters. ``model`` overrides the instance
                model for this call; of the rest only ``max_chunks_per_doc``,
                ``return_documents`` and ``rank_fields`` are forwarded.

        Returns:
            Request parameters

        Raises:
            ValueError: If a document is not a str, dict, or RerankDocument.
        """
        # Convert documents to the format expected by Cohere
        formatted_docs: List[Dict[str, Any]] = []
        for i, doc in enumerate(documents):
            if isinstance(doc, str):
                formatted_docs.append({"text": doc})
            elif isinstance(doc, dict):
                # Assume dict has at least 'text' field
                formatted_docs.append(doc)
            elif isinstance(doc, RerankDocument):
                formatted_docs.append({"text": doc.text})
            else:
                raise ValueError(f"Invalid document type at index {i}: {type(doc)}")

        # An explicit ``model=None`` falls back to the instance model.
        params: Dict[str, Any] = {
            "model": kwargs.get("model") or self.model,
            "query": query,
            "documents": formatted_docs,
        }
        if top_n is not None:
            params["top_n"] = top_n

        # Add optional parameters for v5.15.0
        for key in ("max_chunks_per_doc", "return_documents", "rank_fields"):
            if key in kwargs:
                params[key] = kwargs[key]
        return params

    @staticmethod
    def _build_result_document(
        result: Any,
        documents: Optional[List[Any]]
    ) -> RerankDocument:
        """Build the RerankDocument for one provider result.

        Args:
            result: A single result item from the provider response
            documents: The documents originally submitted, or None

        Returns:
            The caller's own RerankDocument when one was submitted at
            ``result.index``; otherwise a new RerankDocument whose text comes
            from the response, falling back to the submitted text and then to
            an empty string.
        """
        index = getattr(result, "index", None)
        original = None
        if (
            documents is not None
            and isinstance(index, int)
            and 0 <= index < len(documents)
        ):
            original = documents[index]

        # Returning the caller's object keeps its id and metadata intact.
        if isinstance(original, RerankDocument):
            return original

        text = getattr(getattr(result, "document", None), "text", None)
        if text is None:
            # The response carried no document text; use what was submitted.
            if isinstance(original, str):
                text = original
            elif isinstance(original, dict) and isinstance(original.get("text"), str):
                text = original["text"]
            else:
                text = ""
        return RerankDocument(text=text)

    def _parse_response(
        self,
        raw_response: Any,
        latency: float,
        num_documents: int,
        model: Optional[str] = None,
        documents: Optional[List[Any]] = None
    ) -> RerankResponse:
        """Parse Cohere rerank response.

        Args:
            raw_response: Raw response from Cohere
            latency: Request latency
            num_documents: Total number of documents submitted
            model: Model actually used for the request; defaults to the
                instance model when None
            documents: Documents originally submitted, used to map each
                result back to its source document by index (optional)

        Returns:
            RerankResponse
        """
        # Report the model that was requested, not just the instance default.
        used_model = model or self.model

        # Parse results from v5.15.0 response format
        results = [
            RerankResult(
                index=result.index,
                relevance_score=result.relevance_score,
                document=self._build_result_document(result, documents),
            )
            for result in (getattr(raw_response, "results", None) or [])
        ]

        # Calculate cost
        cost = None
        if config.track_costs:
            # Reranking is charged per search (1 search = 1 query across N documents)
            cost = config.calculate_rerank_cost(used_model, 1)

        return RerankResponse(
            results=results,
            model=used_model,
            provider="cohere",
            latency=latency,
            num_documents=num_documents,
            cost=cost,
            metadata={
                "id": getattr(raw_response, "id", None),
                "meta": getattr(raw_response, "meta", None),
            },
            raw_response=raw_response
        )
# Convenience alias: lets callers refer to the Cohere router simply as ``Rerank``.
Rerank = CohereRerank