All test cases working. Stable code.

Pratik Narola 2025-07-01 17:07:02 +05:30
parent 9961cb55a6
commit c0fa06973e
15 changed files with 1121 additions and 9 deletions

conftest.py Normal file

@@ -0,0 +1,90 @@
"""Pytest configuration and fixtures."""
import os
import sys
import warnings
from pathlib import Path
import pytest
from dotenv import load_dotenv
# Suppress warnings
warnings.filterwarnings("ignore", category=DeprecationWarning, module="cohere")
warnings.filterwarnings("ignore", message=".*__fields__.*", category=DeprecationWarning)
# Add router to path
sys.path.insert(0, str(Path(__file__).parent))
# Load environment variables once
load_dotenv()
def pytest_addoption(parser):
"""Add custom command line options."""
parser.addoption(
"--ollama",
action="store_true",
default=False,
help="Run Ollama tests"
)
@pytest.fixture(scope="session")
def api_keys():
"""Fixture to provide API keys."""
return {
"cohere": os.getenv("COHERE_API_KEY"),
"gemini": os.getenv("GEMINI_API_KEY"),
"openai": os.getenv("OPENAI_API_KEY"),
"openai_base_url": os.getenv("OPENAI_BASE_URL", "https://veronica.pratikn.com"),
"ollama_base_url": os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
}
@pytest.fixture
def cohere_embedder():
"""Create a Cohere embedding instance."""
from router.embed import CohereEmbedding
return CohereEmbedding()
@pytest.fixture
def gemini_embedder():
"""Create a Gemini embedding instance."""
from router.embed import GeminiEmbedding
return GeminiEmbedding()
@pytest.fixture
def ollama_embedder():
"""Create an Ollama embedding instance."""
from router.embed import OllamaEmbedding
return OllamaEmbedding()
@pytest.fixture
def cohere_reranker():
"""Create a Cohere rerank instance."""
from router.rerank import CohereRerank
return CohereRerank()
@pytest.fixture
def openai_router(api_keys):
"""Create an OpenAI router instance."""
from router.openai_compatible import OpenAI
return OpenAI(base_url=api_keys["openai_base_url"])
@pytest.fixture
def gemini_router():
"""Create a Gemini router instance."""
from router.gemini import Gemini
return Gemini()
@pytest.fixture
def cohere_router():
"""Create a Cohere router instance."""
from router.cohere import Cohere
return Cohere()

pytest.ini Normal file

@@ -0,0 +1,16 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short --strict-markers
asyncio_mode = auto
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
integration: marks tests as integration tests
unit: marks tests as unit tests
ollama: marks tests that require Ollama to be running
filterwarnings =
ignore::DeprecationWarning:cohere.*
ignore:.*__fields__.*:DeprecationWarning
ignore:.*model_fields.*:DeprecationWarning

requirements-test.txt Normal file

@@ -0,0 +1,11 @@
# Test dependencies
pytest>=7.0.0
pytest-asyncio>=0.21.0
python-dotenv>=1.0.0
# Required packages for the router
cohere>=5.11.0
google-genai>=0.1.0
openai>=1.0.0
httpx>=0.24.0 # For async HTTP in Ollama
requests>=2.31.0 # For sync HTTP in Ollama

router/__init__.py

@@ -1,6 +1,7 @@
 """AI Router - Clean abstraction for AI providers."""
-from .base import AIRouter, RouterResponse
+from .base import AIRouter as BaseAIRouter, RouterResponse
+from .factory import AIRouter, create_router
 from .config import RouterConfig, config
 from .exceptions import (
     RouterError,
@@ -35,9 +36,11 @@ __version__ = "0.1.0"
 __all__ = [
     # Base classes
     "AIRouter",
+    "BaseAIRouter",
     "RouterResponse",
     "RouterConfig",
     "config",
+    "create_router",
     # Generation Providers
     "Gemini",
     "OpenAI",

router/embed.py

@@ -196,8 +196,11 @@ class CohereEmbedding(BaseEmbedding):
             "texts": texts,
         }
+        # For v3 models, input_type is required - default to search_document
         if input_type:
             params["input_type"] = input_type
+        elif "v3" in self.model or "3.0" in self.model or "3.5" in self.model:
+            params["input_type"] = "search_document"
         if truncate:
             params["truncate"] = truncate
@@ -674,7 +677,7 @@ class OllamaEmbedding(BaseEmbedding):
         dimension = len(embeddings[0]) if embeddings else 0
         # No cost for local models
-        cost = 0.0 if config.track_costs else None
+        cost = 0.0  # Always 0.0 for local models
         return EmbeddingResponse(
             embeddings=embeddings,

router/factory.py Normal file

@@ -0,0 +1,57 @@
"""Factory function for creating AI router instances."""
from typing import Any, Optional
from .base import AIRouter
from .gemini import Gemini
from .openai_compatible import OpenAI
from .cohere import Cohere
def create_router(
provider: str,
model: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs: Any
) -> AIRouter:
"""Create an AI router instance based on provider.
Args:
provider: Provider name (openai, gemini, cohere)
model: Model to use (provider-specific default if not provided)
api_key: API key (optional if set in environment)
**kwargs: Additional provider-specific configuration
Returns:
AIRouter instance
Raises:
ValueError: If provider is unknown
"""
provider = provider.lower()
if provider == "openai":
return OpenAI(
model=model or "gpt-4o",
api_key=api_key,
**kwargs
)
elif provider in ["gemini", "google"]:
return Gemini(
model=model or "gemini-2.0-flash-001",
api_key=api_key,
**kwargs
)
elif provider == "cohere":
return Cohere(
model=model or "command-r-plus",
api_key=api_key,
**kwargs
)
else:
raise ValueError(f"Unknown provider: {provider}")
# Convenience alias that matches the test expectations
def AIRouter(provider: str, **kwargs: Any) -> AIRouter:
"""Factory function alias for backward compatibility."""
return create_router(provider, **kwargs)
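A minimal usage sketch for the factory above (illustrative only, not part of the commit; assumes the relevant API key, e.g. `GEMINI_API_KEY`, is set in the environment):

```python
# Illustrative sketch: create a router via the factory and make one call.
from router import create_router

router = create_router("gemini")  # defaults to gemini-2.0-flash-001
response = router.call("Say hello in five words or less")
print(response.provider, response.content)
```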

router/openai_compatible.py

@@ -16,7 +16,7 @@ class OpenAICompatible(AIRouter):
     def __init__(
         self,
-        model: str = "gpt-3.5-turbo",
+        model: str = "gpt-4o",
         api_key: Optional[str] = None,
         base_url: Optional[str] = None,
         organization: Optional[str] = None,

router/rerank.py

@@ -107,7 +107,7 @@ class CohereRerank:
             response = self.client.rerank(**params)
             latency = time.time() - start_time
-            return self._parse_response(response, latency, len(params["documents"]))
+            return self._parse_response(response, latency, len(params["documents"]), documents)
         except Exception as e:
             raise map_provider_error("cohere", e)
@@ -136,7 +136,7 @@ class CohereRerank:
             response = await self.async_client.rerank(**params)
             latency = time.time() - start_time
-            return self._parse_response(response, latency, len(params["documents"]))
+            return self._parse_response(response, latency, len(params["documents"]), documents)
         except Exception as e:
             raise map_provider_error("cohere", e)
@@ -158,6 +158,10 @@ class CohereRerank:
         Returns:
             Request parameters
         """
+        # Validate query is not empty
+        if not query or not query.strip():
+            raise ValueError("Query cannot be empty")
         # Convert documents to the format expected by Cohere
         formatted_docs = []
         for i, doc in enumerate(documents):
@@ -176,23 +180,29 @@ class CohereRerank:
             "model": kwargs.get("model", self.model),
             "query": query,
             "documents": formatted_docs,
+            "return_documents": True,  # Always return documents for parsing
         }
         if top_n is not None:
             params["top_n"] = top_n
         # Add optional parameters for v5.15.0
-        for key in ["max_chunks_per_doc", "return_documents", "rank_fields"]:
+        for key in ["max_chunks_per_doc", "rank_fields"]:
             if key in kwargs:
                 params[key] = kwargs[key]
+        # Allow override of return_documents if explicitly set to False
+        if "return_documents" in kwargs and kwargs["return_documents"] is False:
+            params["return_documents"] = False
         return params
     def _parse_response(
         self,
         raw_response: Any,
         latency: float,
-        num_documents: int
+        num_documents: int,
+        original_documents: Optional[List[Any]] = None
     ) -> RerankResponse:
         """Parse Cohere rerank response.
@@ -200,6 +210,7 @@ class CohereRerank:
             raw_response: Raw response from Cohere
             latency: Request latency
             num_documents: Total number of documents submitted
+            original_documents: Original documents for fallback
         Returns:
             RerankResponse
@@ -211,8 +222,24 @@ class CohereRerank:
         for result in raw_response.results:
             # Extract document info
             doc_text = ""
-            if hasattr(result, "document") and hasattr(result.document, "text"):
-                doc_text = result.document.text
+            # Try to get text from the returned document
+            if hasattr(result, "document"):
+                if hasattr(result.document, "text"):
+                    doc_text = result.document.text
+                elif isinstance(result.document, dict) and "text" in result.document:
+                    doc_text = result.document["text"]
+            # If no document text found and we have the original documents,
+            # use the index to get the original text
+            if not doc_text and original_documents and 0 <= result.index < len(original_documents):
+                orig_doc = original_documents[result.index]
+                if isinstance(orig_doc, str):
+                    doc_text = orig_doc
+                elif isinstance(orig_doc, dict) and "text" in orig_doc:
+                    doc_text = orig_doc["text"]
+                elif hasattr(orig_doc, "text"):
+                    doc_text = getattr(orig_doc, "text")
             results.append(RerankResult(
                 index=result.index,

tests/README.md Normal file

@@ -0,0 +1,96 @@
# AI Router Test Suite
This test suite contains real API tests for the AI Router library using pytest.
## Setup
1. Install test dependencies:
```bash
pip install -r requirements-test.txt
```
2. Create a `.env` file in the project root with your API keys:
```bash
COHERE_API_KEY=your_cohere_key
GEMINI_API_KEY=your_gemini_key
OPENAI_API_KEY=your_openai_key
OPENAI_BASE_URL=your_openai_base_url # Optional custom endpoint
OLLAMA_BASE_URL=http://localhost:11434 # For local Ollama
```
3. For Ollama tests, ensure Ollama is running:
```bash
ollama serve
ollama pull nomic-embed-text:latest
ollama pull mxbai-embed-large:latest
```
## Running Tests with Pytest
### Run all tests:
```bash
pytest
```
### Run with verbose output:
```bash
pytest -v
```
### Run specific test file:
```bash
pytest tests/test_embeddings.py
pytest tests/test_generation.py -v
```
### Run specific test class or method:
```bash
pytest tests/test_embeddings.py::TestCohereEmbeddings
pytest tests/test_embeddings.py::TestCohereEmbeddings::test_single_text_embedding
```
### Run tests by marker:
```bash
# Run only integration tests
pytest -m integration
# Run tests excluding Ollama
pytest -m "not ollama"
# Run only Ollama tests
pytest -m ollama
```
### Run with coverage:
```bash
pytest --cov=router tests/
```
### Run tests in parallel:
```bash
# Install pytest-xdist first
pip install pytest-xdist
pytest -n auto
```
## Test Coverage
- **test_config.py**: Tests configuration loading and API key management
- **test_embeddings.py**: Tests embedding APIs (Cohere, Gemini, Ollama)
- **test_rerank.py**: Tests Cohere reranking functionality
- **test_generation.py**: Tests text generation (OpenAI-compatible, Gemini, Cohere)
## Test Markers
- `@pytest.mark.integration`: Marks integration tests that call real APIs
- `@pytest.mark.asyncio`: Marks async tests
- `@pytest.mark.ollama`: Marks tests that require Ollama to be running
- `@pytest.mark.parametrize`: Used for running tests with multiple input values
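As a hypothetical illustration (not one of the real tests in this suite), several markers can be combined on a single test:

```python
import pytest


@pytest.mark.integration
@pytest.mark.ollama
@pytest.mark.parametrize("text", ["hello", "world"])
def test_example_embedding(ollama_embedder, text):
    # Deselected with: pytest -m "not ollama"
    response = ollama_embedder.embed(text)
    assert len(response.embeddings) == 1
```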
## Notes
- These tests use real APIs and will consume API credits
- Ollama tests are marked with `@pytest.mark.ollama` and can be skipped
- All tests include both sync and async variants
- Tests verify functionality without checking exact output values
- The test suite is configured for `asyncio_mode = auto` in pytest.ini

tests/__init__.py Normal file

@@ -0,0 +1 @@
"""Test suite for AI Router."""

tests/run_all_tests.py Normal file

@@ -0,0 +1,95 @@
"""Run all tests in the test suite."""
import os
import sys
import subprocess
from pathlib import Path
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
def run_test(test_file: str) -> bool:
"""Run a single test file and return success status."""
print(f"\n{'='*60}")
print(f"Running {test_file}...")
print('='*60)
try:
result = subprocess.run(
[sys.executable, test_file],
capture_output=True,
text=True,
check=True
)
print(result.stdout)
if result.stderr:
print("Warnings:", result.stderr)
return True
except subprocess.CalledProcessError as e:
print(f"❌ Test failed: {test_file}")
print("STDOUT:", e.stdout)
print("STDERR:", e.stderr)
return False
def main():
"""Run all tests."""
print("🚀 AI Router Test Suite")
print("=" * 60)
# Load environment variables
from dotenv import load_dotenv
env_path = Path(__file__).parent.parent / ".env"
if env_path.exists():
load_dotenv(env_path)
print(f"✓ Loaded environment from {env_path}")
else:
print("⚠ No .env file found, using system environment")
# Test files in order
test_files = [
"test_config.py",
"test_embeddings.py",
"test_rerank.py",
"test_generation.py",
]
# Track results
results = {}
tests_dir = Path(__file__).parent
for test_file in test_files:
test_path = tests_dir / test_file
if test_path.exists():
results[test_file] = run_test(str(test_path))
else:
print(f"⚠ Test file not found: {test_file}")
results[test_file] = False
# Summary
print("\n" + "="*60)
print("📊 TEST SUMMARY")
print("="*60)
passed = sum(1 for success in results.values() if success)
total = len(results)
for test_file, success in results.items():
status = "✅ PASSED" if success else "❌ FAILED"
print(f"{test_file:<30} {status}")
print("-"*60)
print(f"Total: {passed}/{total} tests passed")
if passed == total:
print("\n🎉 All tests passed!")
return 0
else:
print(f"\n{total - passed} test(s) failed")
return 1
if __name__ == "__main__":
exit_code = main()
sys.exit(exit_code)

tests/test_config.py Normal file

@@ -0,0 +1,71 @@
"""Test configuration loading from environment."""
import pytest
from router.config import RouterConfig
class TestRouterConfig:
"""Test RouterConfig functionality."""
def test_config_from_env(self, api_keys):
"""Test that config loads API keys from environment."""
config = RouterConfig.from_env()
# Check API keys are loaded
assert config.cohere_api_key is not None, "Cohere API key not loaded"
assert config.gemini_api_key is not None, "Gemini API key not loaded"
assert config.openai_api_key is not None, "OpenAI API key not loaded"
# Check Ollama configuration
assert config.ollama_base_url == api_keys["ollama_base_url"]
def test_get_api_key(self):
"""Test getting API keys by provider."""
config = RouterConfig.from_env()
assert config.get_api_key("cohere") is not None
assert config.get_api_key("gemini") is not None
assert config.get_api_key("openai") is not None
assert config.get_api_key("unknown") is None
def test_cost_calculation(self):
"""Test cost calculation methods."""
config = RouterConfig.from_env()
config.track_costs = True
# Test token cost calculation
cost = config.calculate_cost("gpt-4o", input_tokens=1000, output_tokens=500)
assert cost is not None
assert cost > 0
# Test embedding cost
embed_cost = config.calculate_embed_cost("text-embedding-004", num_tokens=1000)
assert embed_cost is not None
assert embed_cost >= 0
# Test rerank cost
rerank_cost = config.calculate_rerank_cost("rerank-english-v3.0", num_searches=10)
assert rerank_cost is not None
assert rerank_cost > 0
def test_ollama_models_config(self):
"""Test Ollama models configuration."""
config = RouterConfig.from_env()
assert "mxbai-embed-large:latest" in config.ollama_embedding_models
assert "nomic-embed-text:latest" in config.ollama_embedding_models
assert len(config.ollama_embedding_models) >= 2
def test_config_to_dict(self):
"""Test config serialization."""
config = RouterConfig.from_env()
config_dict = config.to_dict()
# Check that API keys are masked
assert config_dict["openai_api_key"] == "***"
assert config_dict["cohere_api_key"] == "***"
assert config_dict["gemini_api_key"] == "***"
# Check other settings are present
assert "default_temperature" in config_dict
assert "track_costs" in config_dict

tests/test_embeddings.py Normal file

@@ -0,0 +1,209 @@
"""Test embeddings with real API calls."""
import pytest
from router.embed import (
create_embedding,
CohereEmbedding,
GeminiEmbedding,
OllamaEmbedding,
EmbeddingResponse
)
# Removed skip decorator - using pytest marks instead
@pytest.mark.integration
class TestCohereEmbeddings:
"""Test Cohere embedding functionality."""
def test_single_text_embedding(self, cohere_embedder):
"""Test embedding a single text."""
response = cohere_embedder.embed("Hello world")
assert isinstance(response, EmbeddingResponse)
assert len(response.embeddings) == 1
assert len(response.embeddings[0]) > 0
assert response.dimension > 0
assert response.provider == "cohere"
assert response.latency > 0
assert response.num_inputs == 1
def test_multiple_texts_embedding(self, cohere_embedder):
"""Test embedding multiple texts."""
texts = ["First text", "Second text", "Third text"]
response = cohere_embedder.embed(texts)
assert len(response.embeddings) == 3
assert response.num_inputs == 3
assert all(len(emb) == response.dimension for emb in response.embeddings)
def test_input_type_parameter(self, cohere_embedder):
"""Test embedding with input_type parameter."""
response = cohere_embedder.embed("Search query", input_type="search_query")
assert response.embeddings[0] is not None
response = cohere_embedder.embed("Document text", input_type="search_document")
assert response.embeddings[0] is not None
@pytest.mark.asyncio
async def test_async_embedding(self, cohere_embedder):
"""Test async embedding."""
response = await cohere_embedder.aembed("Hello async world")
assert isinstance(response, EmbeddingResponse)
assert len(response.embeddings) == 1
assert response.provider == "cohere"
@pytest.mark.asyncio
async def test_async_multiple_embeddings(self, cohere_embedder):
"""Test async embedding of multiple texts."""
texts = ["Async text 1", "Async text 2"]
response = await cohere_embedder.aembed(texts)
assert len(response.embeddings) == 2
assert response.num_inputs == 2
@pytest.mark.integration
class TestGeminiEmbeddings:
"""Test Gemini embedding functionality."""
def test_single_text_embedding(self, gemini_embedder):
"""Test embedding a single text."""
response = gemini_embedder.embed("Hello from Gemini")
assert isinstance(response, EmbeddingResponse)
assert len(response.embeddings) == 1
assert len(response.embeddings[0]) > 0
assert response.provider == "gemini"
assert response.dimension > 0
def test_multiple_texts_embedding(self, gemini_embedder):
"""Test embedding multiple texts."""
texts = ["Gemini text one", "Gemini text two"]
response = gemini_embedder.embed(texts)
assert len(response.embeddings) == 2
assert all(len(emb) == response.dimension for emb in response.embeddings)
def test_task_type_parameter(self, gemini_embedder):
"""Test embedding with task_type parameter."""
response = gemini_embedder.embed(
"Test with task type",
task_type="retrieval_document"
)
assert response.embeddings[0] is not None
@pytest.mark.asyncio
async def test_async_embedding(self, gemini_embedder):
"""Test async embedding."""
response = await gemini_embedder.aembed("Async Gemini test")
assert isinstance(response, EmbeddingResponse)
assert response.provider == "gemini"
assert len(response.embeddings) == 1
@pytest.mark.integration
class TestOllamaEmbeddings:
"""Test Ollama embedding functionality."""
@pytest.mark.ollama
def test_single_text_embedding(self, ollama_embedder):
"""Test embedding a single text with Ollama."""
try:
response = ollama_embedder.embed("Test Ollama")
assert isinstance(response, EmbeddingResponse)
assert response.provider == "ollama"
assert response.cost == 0.0 # Local model should be free
assert len(response.embeddings) == 1
assert response.dimension > 0
except Exception as e:
if "Cannot connect to Ollama" in str(e):
pytest.skip("Ollama not running")
raise
@pytest.mark.ollama
def test_multiple_texts_embedding(self, ollama_embedder):
"""Test embedding multiple texts with Ollama."""
try:
response = ollama_embedder.embed(["Local text 1", "Local text 2"])
assert len(response.embeddings) == 2
assert response.num_inputs == 2
assert response.metadata.get("local") is True
except Exception as e:
if "Cannot connect to Ollama" in str(e):
pytest.skip("Ollama not running")
raise
@pytest.mark.integration
class TestEmbeddingFactory:
"""Test the embedding factory function."""
def test_create_cohere_embedder(self):
"""Test creating Cohere embedder via factory."""
embedder = create_embedding(provider="cohere")
response = embedder.embed("Factory test")
assert response.provider == "cohere"
assert isinstance(embedder, CohereEmbedding)
def test_create_gemini_embedder(self):
"""Test creating Gemini embedder via factory."""
embedder = create_embedding(provider="gemini", model="text-embedding-004")
response = embedder.embed("Factory Gemini")
assert response.provider == "gemini"
assert response.model == "text-embedding-004"
assert isinstance(embedder, GeminiEmbedding)
def test_create_ollama_embedder(self):
"""Test creating Ollama embedder via factory."""
embedder = create_embedding(
provider="ollama",
model="mxbai-embed-large:latest"
)
assert embedder.model == "mxbai-embed-large:latest"
assert isinstance(embedder, OllamaEmbedding)
def test_invalid_provider(self):
"""Test factory with invalid provider."""
with pytest.raises(ValueError, match="Unknown embedding provider"):
create_embedding(provider="invalid")
@pytest.mark.integration
class TestEmbeddingComparison:
"""Test comparing embeddings across providers."""
def test_embedding_dimensions(self, cohere_embedder, gemini_embedder):
"""Test that different models produce different dimensions."""
test_text = "dimension test"
cohere_resp = cohere_embedder.embed(test_text)
gemini_resp = gemini_embedder.embed(test_text)
assert cohere_resp.dimension > 0
assert gemini_resp.dimension > 0
# Both should return valid embeddings
assert len(cohere_resp.embeddings[0]) == cohere_resp.dimension
assert len(gemini_resp.embeddings[0]) == gemini_resp.dimension
@pytest.mark.parametrize("text", [
"Short text",
"A much longer text that contains multiple sentences and should still work fine.",
"Text with special characters: @#$%^&*()",
"文本与中文字符", # Text with Chinese characters
])
def test_various_text_inputs(self, cohere_embedder, text):
"""Test embedding various types of text inputs."""
response = cohere_embedder.embed(text)
assert len(response.embeddings) == 1
assert len(response.embeddings[0]) > 0

tests/test_generation.py Normal file

@@ -0,0 +1,245 @@
"""Test text generation with real APIs."""
import pytest
from router import AIRouter, RouterResponse
from router.gemini import Gemini
from router.openai_compatible import OpenAI
from router.cohere import Cohere
@pytest.mark.integration
class TestOpenAICompatible:
"""Test OpenAI-compatible API functionality."""
def test_simple_generation(self, openai_router):
"""Test simple text generation."""
response = openai_router.call("Say hello in 5 words or less")
assert isinstance(response, RouterResponse)
assert response.content is not None
assert response.provider == "openai"
assert response.latency > 0
assert len(response.content.split()) <= 10 # Allow some flexibility
def test_generation_with_parameters(self, openai_router):
"""Test generation with custom parameters."""
response = openai_router.call(
"Write a haiku about coding",
temperature=0.7,
max_tokens=100
)
assert len(response.content) > 0
assert response.total_tokens is not None
assert response.input_tokens is not None
assert response.output_tokens is not None
@pytest.mark.asyncio
async def test_async_generation(self, openai_router):
"""Test async generation."""
response = await openai_router.acall("Say 'async works' if this is working")
assert "async" in response.content.lower()
assert response.provider == "openai"
def test_streaming(self, openai_router):
"""Test streaming responses."""
chunks = []
for chunk in openai_router.stream("Count from 1 to 3"):
chunks.append(chunk.content)
assert isinstance(chunk, RouterResponse)
assert chunk.provider == "openai"
assert len(chunks) > 0
full_response = "".join(chunks)
assert len(full_response) > 0
@pytest.mark.asyncio
async def test_async_streaming(self, openai_router):
"""Test async streaming."""
chunks = []
async for chunk in openai_router.astream("List 3 colors"):
chunks.append(chunk.content)
assert isinstance(chunk, RouterResponse)
assert len(chunks) > 0
full_response = "".join(chunks)
assert len(full_response) > 0
@pytest.mark.integration
class TestGemini:
"""Test Gemini API functionality."""
def test_simple_generation(self, gemini_router):
"""Test simple text generation."""
response = gemini_router.call("What is 2+2? Answer in one word.")
assert isinstance(response, RouterResponse)
assert response.provider == "gemini"
assert response.content is not None
assert len(response.content) > 0
def test_generation_with_temperature(self, gemini_router):
"""Test generation with temperature parameter."""
response = gemini_router.call(
"List 3 benefits of Python programming. Be concise.",
temperature=0.5
)
assert response.content is not None
assert response.total_tokens is not None
# Verify the response contains programming benefits (may not always mention Python by name)
assert any(keyword in response.content.lower() for keyword in ["readability", "versatility", "libraries", "development", "programming"])
@pytest.mark.asyncio
async def test_async_generation(self, gemini_router):
"""Test async generation."""
response = await gemini_router.acall("What is 5x5? Answer only the number.")
assert response.content is not None
assert "25" in response.content
def test_streaming(self, gemini_router):
"""Test streaming responses."""
chunks = []
for chunk in gemini_router.stream("Say 'streaming works'"):
chunks.append(chunk.content)
assert chunk.provider == "gemini"
assert len(chunks) > 0
full_response = "".join(chunks)
assert "streaming" in full_response.lower()
@pytest.mark.integration
class TestCohere:
"""Test Cohere API functionality."""
def test_simple_generation(self, cohere_router):
"""Test simple text generation."""
response = cohere_router.call("Complete this: The capital of France is")
assert isinstance(response, RouterResponse)
assert response.provider == "cohere"
assert "Paris" in response.content
def test_generation_with_preamble(self, cohere_router):
"""Test generation with preamble."""
response = cohere_router.call(
"What is machine learning?",
preamble="You are a helpful AI assistant. Be concise.",
max_tokens=100
)
assert len(response.content) > 0
assert response.finish_reason is not None
@pytest.mark.asyncio
async def test_async_generation(self, cohere_router):
"""Test async generation."""
response = await cohere_router.acall("Say hello asynchronously")
assert response.content is not None
assert response.provider == "cohere"
def test_chat_history(self, cohere_router):
"""Test generation with chat history."""
response = cohere_router.call(
"What did I just ask?",
chat_history=[
{"role": "USER", "message": "What is Python?"},
{"role": "CHATBOT", "message": "Python is a programming language."}
]
)
assert response.content is not None
assert len(response.content) > 0
@pytest.mark.integration
class TestFactoryFunction:
"""Test the AIRouter factory functionality."""
def test_create_openai_router(self, api_keys):
"""Test creating OpenAI router via factory."""
router = AIRouter(
provider="openai",
model="gpt-4o",
base_url=api_keys["openai_base_url"]
)
response = router.call("Say 'factory works'")
assert "factory" in response.content.lower()
assert response.provider == "openai"
def test_create_gemini_router(self):
"""Test creating Gemini router via factory."""
router = AIRouter(provider="gemini")
response = router.call("What is 1+1?")
assert response.provider == "gemini"
assert "2" in response.content
def test_create_cohere_router(self):
"""Test creating Cohere router via factory."""
router = AIRouter(provider="cohere")
response = router.call("Hello")
assert response.provider == "cohere"
assert len(response.content) > 0
def test_invalid_provider(self):
"""Test factory with invalid provider."""
with pytest.raises(ValueError, match="Unknown provider"):
AIRouter(provider="invalid")
@pytest.mark.integration
class TestGenerationFeatures:
"""Test various generation features across providers."""
@pytest.mark.parametrize("provider,router_fixture", [
("openai", "openai_router"),
("gemini", "gemini_router"),
("cohere", "cohere_router"),
])
def test_max_tokens_limit(self, provider, router_fixture, request):
"""Test max_tokens parameter across providers."""
router = request.getfixturevalue(router_fixture)
response = router.call(
"Write a very long story about dragons",
max_tokens=10
)
assert response.content is not None
assert response.provider == provider
# Output should be limited (though exact token count varies by tokenizer)
assert len(response.content.split()) < 50
@pytest.mark.parametrize("temperature", [0.0, 0.5, 1.0])
def test_temperature_variations(self, gemini_router, temperature):
"""Test different temperature settings."""
response = gemini_router.call(
"Generate a random word",
temperature=temperature
)
assert response.content is not None
assert len(response.content) > 0
def test_cost_tracking(self, openai_router):
"""Test cost tracking functionality."""
from router.config import config
original_track_costs = config.track_costs
config.track_costs = True
try:
response = openai_router.call("Hello")
if response.cost is not None:
assert response.cost >= 0
assert isinstance(response.cost, float)
finally:
config.track_costs = original_track_costs

tests/test_rerank.py Normal file

@@ -0,0 +1,188 @@
"""Test reranking with real Cohere API."""
import pytest
from router.rerank import CohereRerank, RerankDocument, RerankResponse
@pytest.mark.integration
class TestCohereRerank:
"""Test Cohere reranking functionality."""
def test_basic_reranking(self, cohere_reranker):
"""Test basic document reranking."""
documents = [
"Python is a programming language",
"The weather is nice today",
"Machine learning with Python",
"Coffee is a popular beverage",
"Deep learning frameworks in Python"
]
query = "Python programming"
response = cohere_reranker.rerank(query, documents)
assert isinstance(response, RerankResponse)
assert response.provider == "cohere"
assert response.num_documents == 5
assert len(response.results) > 0
# Check results are sorted by relevance
scores = [r.relevance_score for r in response.results]
assert scores == sorted(scores, reverse=True), "Results should be sorted by relevance"
# The most relevant document should be about Python
top_result = response.results[0]
assert "Python" in top_result.document.text
def test_rerank_with_top_n(self, cohere_reranker):
"""Test reranking with top_n parameter."""
documents = [
"Document 1", "Document 2", "Document 3",
"Document 4", "Document 5"
]
response = cohere_reranker.rerank("test", documents, top_n=3)
assert len(response.results) == 3
assert all(0 <= r.index < 5 for r in response.results)
def test_rerank_with_document_objects(self, cohere_reranker):
"""Test reranking with RerankDocument objects."""
documents = [
RerankDocument(text="Python tutorial", id="doc1"),
RerankDocument(text="Java guide", id="doc2"),
RerankDocument(text="Python cookbook", id="doc3")
]
response = cohere_reranker.rerank("Python", documents)
assert response.results[0].document.text in ["Python tutorial", "Python cookbook"]
assert len(response.results) == 3
def test_rerank_mixed_document_types(self, cohere_reranker):
"""Test reranking with mixed document types."""
documents = [
"Plain string document",
{"text": "Dict with text field"},
RerankDocument(text="RerankDocument object")
]
response = cohere_reranker.rerank("test", documents)
assert len(response.results) == 3
assert all(hasattr(r.document, 'text') for r in response.results)
@pytest.mark.asyncio
async def test_async_reranking(self, cohere_reranker):
"""Test async reranking."""
documents = [
"Async programming in Python",
"Synchronous vs asynchronous",
"Python asyncio tutorial",
"Database operations",
]
query = "async Python"
response = await cohere_reranker.arerank(query, documents, top_n=2)
assert isinstance(response, RerankResponse)
assert len(response.results) == 2
assert response.results[0].relevance_score >= response.results[1].relevance_score
def test_single_document_reranking(self, cohere_reranker):
"""Test reranking with a single document."""
response = cohere_reranker.rerank("test", ["single document"])
assert len(response.results) == 1
assert response.results[0].index == 0
assert response.results[0].document.text == "single document"
def test_empty_query(self, cohere_reranker):
"""Test reranking with empty query."""
documents = ["doc1", "doc2"]
# Empty query should raise ValueError
with pytest.raises(ValueError, match="Query cannot be empty"):
cohere_reranker.rerank("", documents)
@pytest.mark.parametrize("model", [
"rerank-english-v3.0",
"rerank-multilingual-v3.0"
])
def test_different_models(self, model):
"""Test different rerank models."""
reranker = CohereRerank(model=model)
documents = ["Hello world", "Bonjour monde", "Hola mundo"]
response = reranker.rerank("greeting", documents)
assert response.model == model
assert len(response.results) == 3
def test_relevance_scores(self, cohere_reranker):
"""Test that relevance scores are properly set."""
documents = [
"Exact match for Python programming",
"Somewhat related to Python",
"Completely unrelated topic"
]
response = cohere_reranker.rerank("Python programming", documents)
# All results should have relevance scores
assert all(hasattr(r, 'relevance_score') for r in response.results)
assert all(isinstance(r.relevance_score, (int, float)) for r in response.results)
# Scores should be between 0 and 1 (typical range)
assert all(0 <= r.relevance_score <= 1 for r in response.results)
def test_cost_tracking(self, cohere_reranker):
"""Test cost tracking for reranking."""
# Enable cost tracking
from router.config import config
original_track_costs = config.track_costs
config.track_costs = True
try:
response = cohere_reranker.rerank("test", ["doc1", "doc2"])
if response.cost is not None:
assert response.cost > 0
assert isinstance(response.cost, float)
finally:
config.track_costs = original_track_costs
@pytest.mark.integration
class TestRerankEdgeCases:
"""Test edge cases and error handling."""
def test_very_long_query(self, cohere_reranker):
"""Test with a very long query."""
long_query = "Python " * 100 # Very long query
documents = ["Python doc", "Java doc"]
response = cohere_reranker.rerank(long_query, documents)
assert isinstance(response, RerankResponse)
def test_special_characters(self, cohere_reranker):
"""Test documents with special characters."""
documents = [
"Document with @#$% special chars",
"Normal document",
"Document with émojis 🐍"
]
response = cohere_reranker.rerank("special", documents)
assert len(response.results) == 3
@pytest.mark.parametrize("num_docs", [1, 10, 50])
def test_various_document_counts(self, cohere_reranker, num_docs):
"""Test with different numbers of documents."""
documents = [f"Document {i}" for i in range(num_docs)]
response = cohere_reranker.rerank("Document", documents)
assert len(response.results) == num_docs
assert response.num_documents == num_docs