188 lines
No EOL
6.9 KiB
Python
188 lines
No EOL
6.9 KiB
Python
"""Test reranking with real Cohere API."""
|
|
|
|
import pytest
|
|
from router.rerank import CohereRerank, RerankDocument, RerankResponse
|
|
|
|
|
|
@pytest.mark.integration
class TestCohereRerank:
    """Integration tests exercising ``CohereRerank`` through its public API.

    Most tests take a ``cohere_reranker`` fixture, which is not defined in
    this module (presumably supplied by a conftest -- confirm).
    """

    def test_basic_reranking(self, cohere_reranker):
        """Rerank five short texts; check metadata, ordering and the top hit."""
        corpus = [
            "Python is a programming language",
            "The weather is nice today",
            "Machine learning with Python",
            "Coffee is a popular beverage",
            "Deep learning frameworks in Python"
        ]
        search_text = "Python programming"

        response = cohere_reranker.rerank(search_text, corpus)

        assert isinstance(response, RerankResponse)
        assert response.provider == "cohere"
        assert response.num_documents == 5
        assert len(response.results) > 0

        # Scores must already be in descending order as returned.
        scores = [hit.relevance_score for hit in response.results]
        descending = sorted(scores, reverse=True)
        assert scores == descending, "Results should be sorted by relevance"

        # The best-ranked entry is expected to mention Python.
        best = response.results[0]
        assert "Python" in best.document.text

    def test_rerank_with_top_n(self, cohere_reranker):
        """Requesting ``top_n=3`` yields three results with valid indices."""
        corpus = [
            "Document 1", "Document 2", "Document 3",
            "Document 4", "Document 5"
        ]

        response = cohere_reranker.rerank("test", corpus, top_n=3)

        assert len(response.results) == 3
        # Every returned index must point back into the five inputs.
        for hit in response.results:
            assert 0 <= hit.index < 5

    def test_rerank_with_document_objects(self, cohere_reranker):
        """``RerankDocument`` instances are accepted as input documents."""
        corpus = [
            RerankDocument(text="Python tutorial", id="doc1"),
            RerankDocument(text="Java guide", id="doc2"),
            RerankDocument(text="Python cookbook", id="doc3")
        ]

        response = cohere_reranker.rerank("Python", corpus)

        python_texts = ("Python tutorial", "Python cookbook")
        assert response.results[0].document.text in python_texts
        assert len(response.results) == 3

    def test_rerank_mixed_document_types(self, cohere_reranker):
        """A str, a dict and a ``RerankDocument`` can be mixed in one call."""
        corpus = [
            "Plain string document",
            {"text": "Dict with text field"},
            RerankDocument(text="RerankDocument object")
        ]

        response = cohere_reranker.rerank("test", corpus)

        assert len(response.results) == 3
        # Whatever the input form, each result's document exposes ``text``.
        for hit in response.results:
            assert hasattr(hit.document, 'text')

    @pytest.mark.asyncio
    async def test_async_reranking(self, cohere_reranker):
        """The awaitable ``arerank`` honours ``top_n`` and orders by score."""
        corpus = [
            "Async programming in Python",
            "Synchronous vs asynchronous",
            "Python asyncio tutorial",
            "Database operations",
        ]
        search_text = "async Python"

        response = await cohere_reranker.arerank(search_text, corpus, top_n=2)

        assert isinstance(response, RerankResponse)
        assert len(response.results) == 2
        leader, runner_up = response.results[0], response.results[1]
        assert leader.relevance_score >= runner_up.relevance_score

    def test_single_document_reranking(self, cohere_reranker):
        """A one-element input comes back as a single result at index 0."""
        response = cohere_reranker.rerank("test", ["single document"])

        assert len(response.results) == 1
        only_hit = response.results[0]
        assert only_hit.index == 0
        assert only_hit.document.text == "single document"

    def test_empty_query(self, cohere_reranker):
        """An empty query string is rejected with ``ValueError``."""
        corpus = ["doc1", "doc2"]

        # The error message is matched as well as the exception type.
        with pytest.raises(ValueError, match="Query cannot be empty"):
            cohere_reranker.rerank("", corpus)

    @pytest.mark.parametrize("model", [
        "rerank-english-v3.0",
        "rerank-multilingual-v3.0"
    ])
    def test_different_models(self, model):
        """Each parametrized model name is echoed back on the response."""
        # Built directly rather than via the fixture so the model can vary.
        reranker = CohereRerank(model=model)
        greetings = ["Hello world", "Bonjour monde", "Hola mundo"]

        response = reranker.rerank("greeting", greetings)

        assert response.model == model
        assert len(response.results) == 3

    def test_relevance_scores(self, cohere_reranker):
        """Every result carries a numeric relevance score within [0, 1]."""
        corpus = [
            "Exact match for Python programming",
            "Somewhat related to Python",
            "Completely unrelated topic"
        ]

        response = cohere_reranker.rerank("Python programming", corpus)
        hits = response.results

        # Presence and type of the score attribute.
        assert all(hasattr(hit, 'relevance_score') for hit in hits)
        assert all(isinstance(hit.relevance_score, (int, float)) for hit in hits)

        # Scores are expected to fall between 0 and 1 (typical range).
        assert all(0 <= hit.relevance_score <= 1 for hit in hits)

    def test_cost_tracking(self, cohere_reranker):
        """With ``config.track_costs`` on, a reported cost is a positive float.

        The global flag is restored afterwards even if the test fails. A
        ``None`` cost is tolerated and skips the checks.
        """
        from router.config import config

        previous_setting = config.track_costs
        config.track_costs = True

        try:
            response = cohere_reranker.rerank("test", ["doc1", "doc2"])

            if response.cost is not None:
                assert response.cost > 0
                assert isinstance(response.cost, float)
        finally:
            # Always put the shared config back the way it was found.
            config.track_costs = previous_setting
|
|
|
|
|
|
@pytest.mark.integration
class TestRerankEdgeCases:
    """Integration tests for unusual inputs: long queries, odd text, sizes.

    All tests take a ``cohere_reranker`` fixture, which is not defined in
    this module (presumably supplied by a conftest -- confirm).
    """

    def test_very_long_query(self, cohere_reranker):
        """A query built from 100 repetitions of a word still returns a response."""
        repeated_query = "Python " * 100  # Very long query
        corpus = ["Python doc", "Java doc"]

        response = cohere_reranker.rerank(repeated_query, corpus)

        assert isinstance(response, RerankResponse)

    def test_special_characters(self, cohere_reranker):
        """Punctuation, accented letters and emoji do not drop any document."""
        corpus = [
            "Document with @#$% special chars",
            "Normal document",
            "Document with émojis 🐍"
        ]

        response = cohere_reranker.rerank("special", corpus)

        assert len(response.results) == 3

    @pytest.mark.parametrize("num_docs", [1, 10, 50])
    def test_various_document_counts(self, cohere_reranker, num_docs):
        """Result count and ``num_documents`` both equal the input size."""
        corpus = [f"Document {i}" for i in range(num_docs)]

        response = cohere_reranker.rerank("Document", corpus)

        returned = len(response.results)
        assert returned == num_docs
        assert response.num_documents == num_docs