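"""Reranking router implementation for Cohere Rerank models.

Provides provider-neutral dataclasses for rerank inputs and outputs
(``RerankDocument``, ``RerankResult``, ``RerankResponse``) and the
``CohereRerank`` router that calls the Cohere SDK.
"""
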
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

from .config import config
from .exceptions import (
    ConfigurationError,
    map_provider_error,
)


@dataclass
class RerankDocument:
    """A single candidate document submitted for reranking.

    Attributes:
        text: Document body that is sent to the provider and scored
            against the query.
        id: Optional caller-supplied identifier for the document.
        metadata: Arbitrary caller data attached to the document; every
            instance gets its own fresh dict unless one is passed in.
    """

    text: str
    # Field name kept for API compatibility even though it mirrors a builtin.
    id: Optional[str] = field(default=None)
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class RerankResult:
    """One scored document within a rerank response.

    Attributes:
        index: Position of the document in the list the caller originally
            passed in (not its position in the ranked output).
        relevance_score: Relevance score reported by the provider for this
            document.
        document: The document the score refers to.
    """

    index: int
    relevance_score: float
    document: RerankDocument


@dataclass
class RerankResponse:
    """Provider-neutral outcome of a single rerank call.

    Attributes:
        results: Scored documents, in the order the provider returned them.
        model: Name of the model recorded for the call.
        provider: Provider identifier string (for example ``"cohere"``).
        latency: Duration of the provider call, in seconds.
        num_documents: Number of documents submitted in the request.
        cost: Estimated cost of the call, or ``None`` when not computed.
        metadata: Extra provider-specific details; a fresh dict per
            instance unless one is passed in.
        raw_response: The unmodified provider response object, if kept.
    """

    results: List[RerankResult]
    model: str
    provider: str
    latency: float
    num_documents: int
    cost: Optional[float] = field(default=None)
    metadata: Dict[str, Any] = field(default_factory=dict)
    raw_response: Optional[Any] = field(default=None)


class CohereRerank:
    """Router for Cohere Rerank models.

    Converts caller documents into a Cohere rerank request, invokes the
    synchronous or asynchronous Cohere client, and wraps the reply in a
    provider-neutral ``RerankResponse``.
    """

    # Cohere bills one search unit per query over up to 100 documents
    # (long documents may be chunked and count as several documents).
    _DOCS_PER_SEARCH_UNIT = 100

    def __init__(
        self,
        model: str = "rerank-english-v3.0",
        api_key: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """Initialize the Cohere Rerank router.

        Args:
            model: Default rerank model (e.g. rerank-3.5, rerank-english-v3.0,
                rerank-multilingual-v3.0). Can be overridden per call by
                passing ``model=...`` to ``rerank``/``arerank``.
            api_key: Cohere API key. When omitted, it is looked up with
                ``config.get_api_key("cohere")``.
            **kwargs: Additional configuration, stored on ``self.config``.

        Raises:
            ConfigurationError: If no API key can be found or the ``cohere``
                package is not installed.
        """
        if not api_key:
            api_key = config.get_api_key("cohere")

        if not api_key:
            raise ConfigurationError(
                "Cohere API key not found. Set COHERE_API_KEY environment variable.",
                provider="cohere"
            )

        self.model = model
        self.api_key = api_key
        self.config = kwargs

        try:
            import cohere
        except ImportError as exc:
            raise ConfigurationError(
                "cohere package not installed. Install with: pip install cohere",
                provider="cohere"
            ) from exc

        # Targets the cohere v5.15 SDK's standard Client / AsyncClient.
        self.client = cohere.Client(api_key)
        self.async_client = cohere.AsyncClient(api_key)

    def rerank(
        self,
        query: str,
        documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
        top_n: Optional[int] = None,
        **kwargs: Any
    ) -> RerankResponse:
        """Rerank documents by relevance to ``query``.

        Args:
            query: The search query; must not be empty or blank.
            documents: Documents to rerank (strings, dicts, or
                ``RerankDocument`` objects).
            top_n: Number of top results to return (``None`` returns all).
            **kwargs: ``model`` overrides the default model;
                ``max_chunks_per_doc`` and ``rank_fields`` are forwarded;
                ``return_documents=False`` disables returned documents.

        Returns:
            RerankResponse with the reranked results.

        Raises:
            ValueError: If the query is empty or a document has an
                unsupported type.
            Exception: Whatever ``map_provider_error`` produces when the
                Cohere client call fails.
        """
        # Materialize once: documents are iterated to build the request and
        # indexed again when mapping results back to the caller's inputs.
        documents = list(documents)
        params = self._prepare_request(query, documents, top_n, **kwargs)

        start_time = time.perf_counter()
        try:
            response = self.client.rerank(**params)
        except Exception as e:
            raise map_provider_error("cohere", e)
        latency = time.perf_counter() - start_time

        return self._parse_response(
            response,
            latency,
            len(params["documents"]),
            documents,
            model=params["model"],
        )

    async def arerank(
        self,
        query: str,
        documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
        top_n: Optional[int] = None,
        **kwargs: Any
    ) -> RerankResponse:
        """Asynchronously rerank documents by relevance to ``query``.

        Accepts the same arguments, returns the same result, and raises the
        same errors as ``rerank``, but awaits the Cohere async client.

        Args:
            query: The search query; must not be empty or blank.
            documents: Documents to rerank.
            top_n: Number of top results to return (``None`` returns all).
            **kwargs: Additional parameters (see ``rerank``).

        Returns:
            RerankResponse with the reranked results.
        """
        documents = list(documents)
        params = self._prepare_request(query, documents, top_n, **kwargs)

        start_time = time.perf_counter()
        try:
            response = await self.async_client.rerank(**params)
        except Exception as e:
            raise map_provider_error("cohere", e)
        latency = time.perf_counter() - start_time

        return self._parse_response(
            response,
            latency,
            len(params["documents"]),
            documents,
            model=params["model"],
        )

    def _prepare_request(
        self,
        query: str,
        documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
        top_n: Optional[int] = None,
        **kwargs: Any
    ) -> Dict[str, Any]:
        """Validate inputs and build keyword arguments for ``client.rerank``.

        Args:
            query: The search query; must contain non-whitespace characters.
            documents: Strings, dicts, or ``RerankDocument`` objects.
            top_n: Number of results to request; left out of the request
                when ``None``.
            **kwargs: ``model`` overrides the default model;
                ``max_chunks_per_doc`` and ``rank_fields`` are forwarded
                unchanged; ``return_documents=False`` disables returned
                documents. Other keys are ignored.

        Returns:
            Keyword arguments for the Cohere client's ``rerank`` method.

        Raises:
            ValueError: If the query is empty/blank or a document has an
                unsupported type.
        """
        if not query or not query.strip():
            raise ValueError("Query cannot be empty")

        formatted_docs: List[Any] = []
        for i, doc in enumerate(documents):
            if isinstance(doc, str):
                formatted_docs.append({"text": doc})
            elif isinstance(doc, RerankDocument):
                # Only the text is sent; id/metadata stay local and are
                # reattached to the results by _parse_response.
                formatted_docs.append({"text": doc.text})
            elif isinstance(doc, dict):
                # Forwarded as-is. NOTE(review): assumed to carry a "text" key
                # (or the fields named in rank_fields); not validated here.
                formatted_docs.append(doc)
            else:
                raise ValueError(f"Invalid document type at index {i}: {type(doc)}")

        params: Dict[str, Any] = {
            "model": kwargs.get("model", self.model),
            "query": query,
            "documents": formatted_docs,
            # Lets _parse_response read text straight from the reply; it
            # falls back to the caller's documents when text is absent.
            "return_documents": True,
        }

        if top_n is not None:
            params["top_n"] = top_n

        for key in ("max_chunks_per_doc", "rank_fields"):
            if key in kwargs:
                params[key] = kwargs[key]

        # Only an explicit False disables returned documents.
        if kwargs.get("return_documents") is False:
            params["return_documents"] = False

        return params

    def _parse_response(
        self,
        raw_response: Any,
        latency: float,
        num_documents: int,
        original_documents: Optional[List[Any]] = None,
        model: Optional[str] = None,
    ) -> RerankResponse:
        """Convert a Cohere rerank reply into a ``RerankResponse``.

        Args:
            raw_response: Object returned by the Cohere client's ``rerank``.
            latency: Duration of the provider call, in seconds.
            num_documents: Number of documents sent in the request.
            original_documents: Documents as passed by the caller; used to
                recover text the reply omits and to carry over ``id`` and
                ``metadata`` from ``RerankDocument`` inputs.
            model: Model actually used for the request (defaults to
                ``self.model``); reported in the response and used for cost.

        Returns:
            RerankResponse with one ``RerankResult`` per returned result.
        """
        used_model = model or self.model
        originals = original_documents or []

        results: List[RerankResult] = []
        for item in getattr(raw_response, "results", None) or []:
            index = item.index
            original = originals[index] if 0 <= index < len(originals) else None

            text = self._returned_text(item)
            if not text and original is not None:
                text = self._original_text(original)

            if isinstance(original, RerankDocument):
                # Keep the caller's identifier and metadata with the result.
                document = RerankDocument(
                    text=text,
                    id=original.id,
                    metadata=dict(original.metadata or {}),
                )
            else:
                document = RerankDocument(text=text)

            results.append(RerankResult(
                index=index,
                relevance_score=item.relevance_score,
                document=document,
            ))

        cost = None
        if config.track_costs:
            # NOTE(review): assumes calculate_rerank_cost's second argument is
            # a number of billed searches (the original passed 1 per call).
            cost = config.calculate_rerank_cost(
                used_model, self._search_units(raw_response, num_documents)
            )

        return RerankResponse(
            results=results,
            model=used_model,
            provider="cohere",
            latency=latency,
            num_documents=num_documents,
            cost=cost,
            metadata={
                "id": getattr(raw_response, "id", None),
                "meta": getattr(raw_response, "meta", None),
            },
            raw_response=raw_response
        )

    def _search_units(self, raw_response: Any, num_documents: int) -> int:
        """Return the number of search units to bill for one rerank call.

        Prefers ``meta.billed_units.search_units`` from the reply when it is
        a positive number; otherwise estimates one unit per started batch of
        ``_DOCS_PER_SEARCH_UNIT`` documents (a lower bound if Cohere chunks
        long documents).
        """
        meta = getattr(raw_response, "meta", None)
        billed = getattr(getattr(meta, "billed_units", None), "search_units", None)
        if isinstance(billed, (int, float)) and not isinstance(billed, bool) and billed > 0:
            return int(-(-billed // 1))  # ceil without importing math
        return max(1, -(-num_documents // self._DOCS_PER_SEARCH_UNIT))

    @staticmethod
    def _returned_text(item: Any) -> str:
        """Return the document text echoed back in a result item, or ""."""
        document = getattr(item, "document", None)
        if document is None:
            return ""
        if hasattr(document, "text"):
            return document.text or ""
        if isinstance(document, dict):
            return document.get("text") or ""
        return ""

    @staticmethod
    def _original_text(original: Any) -> str:
        """Return the text of a caller-supplied document, or "" if none."""
        if isinstance(original, str):
            return original
        if isinstance(original, dict):
            return original.get("text") or ""
        return getattr(original, "text", "") or ""


# Short public name; ``Rerank`` and ``CohereRerank`` are the same class object.
Rerank = CohereRerank