# template-with-ai/tests/test_rerank.py
# 2025-07-01 17:07:02 +05:30
#
# 188 lines
# No EOL
# 6.9 KiB
# Python

"""Test reranking with real Cohere API."""
import pytest
from router.rerank import CohereRerank, RerankDocument, RerankResponse
@pytest.mark.integration
class TestCohereRerank:
    """Integration tests for the Cohere reranker's sync and async entry points.

    Most tests receive a reranker through the ``cohere_reranker`` fixture,
    which is defined outside this file.
    """

    def test_basic_reranking(self, cohere_reranker):
        """Rerank plain strings and check metadata, ordering and the top hit."""
        documents = [
            "Python is a programming language",
            "The weather is nice today",
            "Machine learning with Python",
            "Coffee is a popular beverage",
            "Deep learning frameworks in Python",
        ]
        query = "Python programming"

        response = cohere_reranker.rerank(query, documents)

        assert isinstance(response, RerankResponse)
        assert response.provider == "cohere"
        # Derive the expected count from the input so the two cannot drift apart.
        assert response.num_documents == len(documents)
        assert len(response.results) > 0

        # Check results are sorted by relevance (best first).
        scores = [r.relevance_score for r in response.results]
        assert scores == sorted(scores, reverse=True), "Results should be sorted by relevance"

        # The most relevant document should be about Python.
        top_result = response.results[0]
        assert "Python" in top_result.document.text

    def test_rerank_with_top_n(self, cohere_reranker):
        """Limit the output with ``top_n`` and check the returned indices."""
        documents = [
            "Document 1", "Document 2", "Document 3",
            "Document 4", "Document 5",
        ]

        response = cohere_reranker.rerank("test", documents, top_n=3)

        assert len(response.results) == 3
        # Every returned index must point back into the submitted list.
        assert all(0 <= r.index < len(documents) for r in response.results)

    def test_rerank_with_document_objects(self, cohere_reranker):
        """Rerank ``RerankDocument`` instances instead of plain strings."""
        documents = [
            RerankDocument(text="Python tutorial", id="doc1"),
            RerankDocument(text="Java guide", id="doc2"),
            RerankDocument(text="Python cookbook", id="doc3"),
        ]

        response = cohere_reranker.rerank("Python", documents)

        # Either Python document is an acceptable top hit.
        assert response.results[0].document.text in ["Python tutorial", "Python cookbook"]
        assert len(response.results) == 3

    def test_rerank_mixed_document_types(self, cohere_reranker):
        """Rerank a list mixing a string, a dict and a ``RerankDocument``."""
        documents = [
            "Plain string document",
            {"text": "Dict with text field"},
            RerankDocument(text="RerankDocument object"),
        ]

        response = cohere_reranker.rerank("test", documents)

        assert len(response.results) == 3
        # Whatever the input form, each result's document must expose ``text``.
        assert all(hasattr(r.document, 'text') for r in response.results)

    @pytest.mark.asyncio
    async def test_async_reranking(self, cohere_reranker):
        """Rerank through the awaitable ``arerank`` method with ``top_n=2``."""
        documents = [
            "Async programming in Python",
            "Synchronous vs asynchronous",
            "Python asyncio tutorial",
            "Database operations",
        ]
        query = "async Python"

        response = await cohere_reranker.arerank(query, documents, top_n=2)

        assert isinstance(response, RerankResponse)
        assert len(response.results) == 2
        assert response.results[0].relevance_score >= response.results[1].relevance_score

    def test_single_document_reranking(self, cohere_reranker):
        """Rerank a one-element list and check it comes back unchanged."""
        response = cohere_reranker.rerank("test", ["single document"])

        assert len(response.results) == 1
        assert response.results[0].index == 0
        assert response.results[0].document.text == "single document"

    def test_empty_query(self, cohere_reranker):
        """Expect ``ValueError`` when the query is the empty string."""
        documents = ["doc1", "doc2"]

        # Empty query should raise ValueError.
        with pytest.raises(ValueError, match="Query cannot be empty"):
            cohere_reranker.rerank("", documents)

    @pytest.mark.parametrize("model", [
        "rerank-english-v3.0",
        "rerank-multilingual-v3.0",
    ])
    def test_different_models(self, model):
        """Build a reranker per model name and check the response echoes it."""
        reranker = CohereRerank(model=model)
        documents = ["Hello world", "Bonjour monde", "Hola mundo"]

        response = reranker.rerank("greeting", documents)

        assert response.model == model
        assert len(response.results) == 3

    def test_relevance_scores(self, cohere_reranker):
        """Check that every result carries a numeric score within [0, 1]."""
        documents = [
            "Exact match for Python programming",
            "Somewhat related to Python",
            "Completely unrelated topic",
        ]

        response = cohere_reranker.rerank("Python programming", documents)

        # All results should have relevance scores.
        assert all(hasattr(r, 'relevance_score') for r in response.results)
        assert all(isinstance(r.relevance_score, (int, float)) for r in response.results)
        # Scores should be between 0 and 1 (typical range).
        assert all(0 <= r.relevance_score <= 1 for r in response.results)

    def test_cost_tracking(self, cohere_reranker, monkeypatch):
        """Check the reported cost while ``config.track_costs`` is enabled."""
        from router.config import config

        # monkeypatch puts the previous flag value back at teardown, even when
        # an assertion below fails, so no manual try/finally is needed.
        monkeypatch.setattr(config, "track_costs", True)

        response = cohere_reranker.rerank("test", ["doc1", "doc2"])

        # A response without a cost is tolerated; only a reported cost is validated.
        if response.cost is not None:
            # Type first, so a wrong type fails here rather than in the comparison.
            assert isinstance(response.cost, float)
            assert response.cost > 0
@pytest.mark.integration
class TestRerankEdgeCases:
    """Integration tests for unusual queries, characters and batch sizes.

    Each test receives a reranker through the ``cohere_reranker`` fixture,
    which is defined outside this file.
    """

    def test_very_long_query(self, cohere_reranker):
        """Send a query made of one word repeated 100 times."""
        long_query = "Python " * 100  # Very long query
        documents = ["Python doc", "Java doc"]

        response = cohere_reranker.rerank(long_query, documents)

        # Only the response type is checked here.
        assert isinstance(response, RerankResponse)

    def test_special_characters(self, cohere_reranker):
        """Rerank documents containing symbols, accents and an emoji."""
        documents = [
            "Document with @#$% special chars",
            "Normal document",
            "Document with émojis 🐍",
        ]

        response = cohere_reranker.rerank("special", documents)

        # Derive the expected count from the input so the two cannot drift apart.
        assert len(response.results) == len(documents)

    @pytest.mark.parametrize("num_docs", [1, 10, 50])
    def test_various_document_counts(self, cohere_reranker, num_docs):
        """Rerank 1, 10 and 50 generated documents and check the counts."""
        documents = [f"Document {i}" for i in range(num_docs)]

        response = cohere_reranker.rerank("Document", documents)

        assert len(response.results) == num_docs
        assert response.num_documents == num_docs