All test cases working. stable code
This commit is contained in:
parent
9961cb55a6
commit
c0fa06973e
15 changed files with 1121 additions and 9 deletions
90
conftest.py
Normal file
90
conftest.py
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
"""Pytest configuration and fixtures."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Suppress warnings
|
||||||
|
warnings.filterwarnings("ignore", category=DeprecationWarning, module="cohere")
|
||||||
|
warnings.filterwarnings("ignore", message=".*__fields__.*", category=DeprecationWarning)
|
||||||
|
|
||||||
|
# Add router to path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
# Load environment variables once
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser):
|
||||||
|
"""Add custom command line options."""
|
||||||
|
parser.addoption(
|
||||||
|
"--ollama",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Run Ollama tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def api_keys():
|
||||||
|
"""Fixture to provide API keys."""
|
||||||
|
return {
|
||||||
|
"cohere": os.getenv("COHERE_API_KEY"),
|
||||||
|
"gemini": os.getenv("GEMINI_API_KEY"),
|
||||||
|
"openai": os.getenv("OPENAI_API_KEY"),
|
||||||
|
"openai_base_url": os.getenv("OPENAI_BASE_URL", "https://veronica.pratikn.com"),
|
||||||
|
"ollama_base_url": os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cohere_embedder():
|
||||||
|
"""Create a Cohere embedding instance."""
|
||||||
|
from router.embed import CohereEmbedding
|
||||||
|
return CohereEmbedding()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def gemini_embedder():
|
||||||
|
"""Create a Gemini embedding instance."""
|
||||||
|
from router.embed import GeminiEmbedding
|
||||||
|
return GeminiEmbedding()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ollama_embedder():
|
||||||
|
"""Create an Ollama embedding instance."""
|
||||||
|
from router.embed import OllamaEmbedding
|
||||||
|
return OllamaEmbedding()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cohere_reranker():
|
||||||
|
"""Create a Cohere rerank instance."""
|
||||||
|
from router.rerank import CohereRerank
|
||||||
|
return CohereRerank()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def openai_router(api_keys):
|
||||||
|
"""Create an OpenAI router instance."""
|
||||||
|
from router.openai_compatible import OpenAI
|
||||||
|
return OpenAI(base_url=api_keys["openai_base_url"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def gemini_router():
|
||||||
|
"""Create a Gemini router instance."""
|
||||||
|
from router.gemini import Gemini
|
||||||
|
return Gemini()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cohere_router():
|
||||||
|
"""Create a Cohere router instance."""
|
||||||
|
from router.cohere import Cohere
|
||||||
|
return Cohere()
|
16
pytest.ini
Normal file
16
pytest.ini
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
[pytest]
|
||||||
|
testpaths = tests
|
||||||
|
python_files = test_*.py
|
||||||
|
python_classes = Test*
|
||||||
|
python_functions = test_*
|
||||||
|
addopts = -v --tb=short --strict-markers
|
||||||
|
asyncio_mode = auto
|
||||||
|
markers =
|
||||||
|
slow: marks tests as slow (deselect with '-m "not slow"')
|
||||||
|
integration: marks tests as integration tests
|
||||||
|
unit: marks tests as unit tests
|
||||||
|
ollama: marks tests that require Ollama to be running
|
||||||
|
filterwarnings =
|
||||||
|
ignore::DeprecationWarning:cohere.*
|
||||||
|
ignore:.*__fields__.*:DeprecationWarning
|
||||||
|
ignore:.*model_fields.*:DeprecationWarning
|
11
requirements-test.txt
Normal file
11
requirements-test.txt
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
# Test dependencies
|
||||||
|
pytest>=7.0.0
|
||||||
|
pytest-asyncio>=0.21.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
|
||||||
|
# Required packages for the router
|
||||||
|
cohere>=5.11.0
|
||||||
|
google-genai>=0.1.0
|
||||||
|
openai>=1.0.0
|
||||||
|
httpx>=0.24.0 # For async HTTP in Ollama
|
||||||
|
requests>=2.31.0 # For sync HTTP in Ollama
|
|
@ -1,6 +1,7 @@
|
||||||
"""AI Router - Clean abstraction for AI providers."""
|
"""AI Router - Clean abstraction for AI providers."""
|
||||||
|
|
||||||
from .base import AIRouter, RouterResponse
|
from .base import AIRouter as BaseAIRouter, RouterResponse
|
||||||
|
from .factory import AIRouter, create_router
|
||||||
from .config import RouterConfig, config
|
from .config import RouterConfig, config
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
RouterError,
|
RouterError,
|
||||||
|
@ -35,9 +36,11 @@ __version__ = "0.1.0"
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Base classes
|
# Base classes
|
||||||
"AIRouter",
|
"AIRouter",
|
||||||
|
"BaseAIRouter",
|
||||||
"RouterResponse",
|
"RouterResponse",
|
||||||
"RouterConfig",
|
"RouterConfig",
|
||||||
"config",
|
"config",
|
||||||
|
"create_router",
|
||||||
# Generation Providers
|
# Generation Providers
|
||||||
"Gemini",
|
"Gemini",
|
||||||
"OpenAI",
|
"OpenAI",
|
||||||
|
|
|
@ -196,8 +196,11 @@ class CohereEmbedding(BaseEmbedding):
|
||||||
"texts": texts,
|
"texts": texts,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# For v3 models, input_type is required - default to search_document
|
||||||
if input_type:
|
if input_type:
|
||||||
params["input_type"] = input_type
|
params["input_type"] = input_type
|
||||||
|
elif "v3" in self.model or "3.0" in self.model or "3.5" in self.model:
|
||||||
|
params["input_type"] = "search_document"
|
||||||
|
|
||||||
if truncate:
|
if truncate:
|
||||||
params["truncate"] = truncate
|
params["truncate"] = truncate
|
||||||
|
@ -674,7 +677,7 @@ class OllamaEmbedding(BaseEmbedding):
|
||||||
dimension = len(embeddings[0]) if embeddings else 0
|
dimension = len(embeddings[0]) if embeddings else 0
|
||||||
|
|
||||||
# No cost for local models
|
# No cost for local models
|
||||||
cost = 0.0 if config.track_costs else None
|
cost = 0.0 # Always 0.0 for local models
|
||||||
|
|
||||||
return EmbeddingResponse(
|
return EmbeddingResponse(
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
|
|
57
router/factory.py
Normal file
57
router/factory.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
"""Factory function for creating AI router instances."""
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
from .base import AIRouter
|
||||||
|
from .gemini import Gemini
|
||||||
|
from .openai_compatible import OpenAI
|
||||||
|
from .cohere import Cohere
|
||||||
|
|
||||||
|
|
||||||
|
def create_router(
|
||||||
|
provider: str,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> AIRouter:
|
||||||
|
"""Create an AI router instance based on provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Provider name (openai, gemini, cohere)
|
||||||
|
model: Model to use (provider-specific default if not provided)
|
||||||
|
api_key: API key (optional if set in environment)
|
||||||
|
**kwargs: Additional provider-specific configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AIRouter instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If provider is unknown
|
||||||
|
"""
|
||||||
|
provider = provider.lower()
|
||||||
|
|
||||||
|
if provider == "openai":
|
||||||
|
return OpenAI(
|
||||||
|
model=model or "gpt-4o",
|
||||||
|
api_key=api_key,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
elif provider in ["gemini", "google"]:
|
||||||
|
return Gemini(
|
||||||
|
model=model or "gemini-2.0-flash-001",
|
||||||
|
api_key=api_key,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
elif provider == "cohere":
|
||||||
|
return Cohere(
|
||||||
|
model=model or "command-r-plus",
|
||||||
|
api_key=api_key,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience alias that matches the test expectations
|
||||||
|
def AIRouter(provider: str, **kwargs: Any) -> AIRouter:
|
||||||
|
"""Factory function alias for backward compatibility."""
|
||||||
|
return create_router(provider, **kwargs)
|
|
@ -16,7 +16,7 @@ class OpenAICompatible(AIRouter):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str = "gpt-3.5-turbo",
|
model: str = "gpt-4o",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
organization: Optional[str] = None,
|
organization: Optional[str] = None,
|
||||||
|
|
|
@ -107,7 +107,7 @@ class CohereRerank:
|
||||||
response = self.client.rerank(**params)
|
response = self.client.rerank(**params)
|
||||||
latency = time.time() - start_time
|
latency = time.time() - start_time
|
||||||
|
|
||||||
return self._parse_response(response, latency, len(params["documents"]))
|
return self._parse_response(response, latency, len(params["documents"]), documents)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise map_provider_error("cohere", e)
|
raise map_provider_error("cohere", e)
|
||||||
|
|
||||||
|
@ -136,7 +136,7 @@ class CohereRerank:
|
||||||
response = await self.async_client.rerank(**params)
|
response = await self.async_client.rerank(**params)
|
||||||
latency = time.time() - start_time
|
latency = time.time() - start_time
|
||||||
|
|
||||||
return self._parse_response(response, latency, len(params["documents"]))
|
return self._parse_response(response, latency, len(params["documents"]), documents)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise map_provider_error("cohere", e)
|
raise map_provider_error("cohere", e)
|
||||||
|
|
||||||
|
@ -158,6 +158,10 @@ class CohereRerank:
|
||||||
Returns:
|
Returns:
|
||||||
Request parameters
|
Request parameters
|
||||||
"""
|
"""
|
||||||
|
# Validate query is not empty
|
||||||
|
if not query or not query.strip():
|
||||||
|
raise ValueError("Query cannot be empty")
|
||||||
|
|
||||||
# Convert documents to the format expected by Cohere
|
# Convert documents to the format expected by Cohere
|
||||||
formatted_docs = []
|
formatted_docs = []
|
||||||
for i, doc in enumerate(documents):
|
for i, doc in enumerate(documents):
|
||||||
|
@ -176,23 +180,29 @@ class CohereRerank:
|
||||||
"model": kwargs.get("model", self.model),
|
"model": kwargs.get("model", self.model),
|
||||||
"query": query,
|
"query": query,
|
||||||
"documents": formatted_docs,
|
"documents": formatted_docs,
|
||||||
|
"return_documents": True, # Always return documents for parsing
|
||||||
}
|
}
|
||||||
|
|
||||||
if top_n is not None:
|
if top_n is not None:
|
||||||
params["top_n"] = top_n
|
params["top_n"] = top_n
|
||||||
|
|
||||||
# Add optional parameters for v5.15.0
|
# Add optional parameters for v5.15.0
|
||||||
for key in ["max_chunks_per_doc", "return_documents", "rank_fields"]:
|
for key in ["max_chunks_per_doc", "rank_fields"]:
|
||||||
if key in kwargs:
|
if key in kwargs:
|
||||||
params[key] = kwargs[key]
|
params[key] = kwargs[key]
|
||||||
|
|
||||||
|
# Allow override of return_documents if explicitly set to False
|
||||||
|
if "return_documents" in kwargs and kwargs["return_documents"] is False:
|
||||||
|
params["return_documents"] = False
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def _parse_response(
|
def _parse_response(
|
||||||
self,
|
self,
|
||||||
raw_response: Any,
|
raw_response: Any,
|
||||||
latency: float,
|
latency: float,
|
||||||
num_documents: int
|
num_documents: int,
|
||||||
|
original_documents: Optional[List[Any]] = None
|
||||||
) -> RerankResponse:
|
) -> RerankResponse:
|
||||||
"""Parse Cohere rerank response.
|
"""Parse Cohere rerank response.
|
||||||
|
|
||||||
|
@ -200,6 +210,7 @@ class CohereRerank:
|
||||||
raw_response: Raw response from Cohere
|
raw_response: Raw response from Cohere
|
||||||
latency: Request latency
|
latency: Request latency
|
||||||
num_documents: Total number of documents submitted
|
num_documents: Total number of documents submitted
|
||||||
|
original_documents: Original documents for fallback
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
RerankResponse
|
RerankResponse
|
||||||
|
@ -211,8 +222,24 @@ class CohereRerank:
|
||||||
for result in raw_response.results:
|
for result in raw_response.results:
|
||||||
# Extract document info
|
# Extract document info
|
||||||
doc_text = ""
|
doc_text = ""
|
||||||
if hasattr(result, "document") and hasattr(result.document, "text"):
|
|
||||||
doc_text = result.document.text
|
# Try to get text from the returned document
|
||||||
|
if hasattr(result, "document"):
|
||||||
|
if hasattr(result.document, "text"):
|
||||||
|
doc_text = result.document.text
|
||||||
|
elif isinstance(result.document, dict) and "text" in result.document:
|
||||||
|
doc_text = result.document["text"]
|
||||||
|
|
||||||
|
# If no document text found and we have the original documents,
|
||||||
|
# use the index to get the original text
|
||||||
|
if not doc_text and original_documents and 0 <= result.index < len(original_documents):
|
||||||
|
orig_doc = original_documents[result.index]
|
||||||
|
if isinstance(orig_doc, str):
|
||||||
|
doc_text = orig_doc
|
||||||
|
elif isinstance(orig_doc, dict) and "text" in orig_doc:
|
||||||
|
doc_text = orig_doc["text"]
|
||||||
|
elif hasattr(orig_doc, "text"):
|
||||||
|
doc_text = getattr(orig_doc, "text")
|
||||||
|
|
||||||
results.append(RerankResult(
|
results.append(RerankResult(
|
||||||
index=result.index,
|
index=result.index,
|
||||||
|
|
96
tests/README.md
Normal file
96
tests/README.md
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
# AI Router Test Suite
|
||||||
|
|
||||||
|
This test suite contains real API tests for the AI Router library using pytest.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
1. Install test dependencies:
|
||||||
|
```bash
|
||||||
|
pip install -r requirements-test.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Create a `.env` file in the project root with your API keys:
|
||||||
|
```bash
|
||||||
|
COHERE_API_KEY=your_cohere_key
|
||||||
|
GEMINI_API_KEY=your_gemini_key
|
||||||
|
OPENAI_API_KEY=your_openai_key
|
||||||
|
OPENAI_BASE_URL=your_openai_base_url # Optional custom endpoint
|
||||||
|
OLLAMA_BASE_URL=http://localhost:11434 # For local Ollama
|
||||||
|
```
|
||||||
|
|
||||||
|
3. For Ollama tests, ensure Ollama is running:
|
||||||
|
```bash
|
||||||
|
ollama serve
|
||||||
|
ollama pull nomic-embed-text:latest
|
||||||
|
ollama pull mxbai-embed-large:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running Tests with Pytest
|
||||||
|
|
||||||
|
### Run all tests:
|
||||||
|
```bash
|
||||||
|
pytest
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run with verbose output:
|
||||||
|
```bash
|
||||||
|
pytest -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run specific test file:
|
||||||
|
```bash
|
||||||
|
pytest tests/test_embeddings.py
|
||||||
|
pytest tests/test_generation.py -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run specific test class or method:
|
||||||
|
```bash
|
||||||
|
pytest tests/test_embeddings.py::TestCohereEmbeddings
|
||||||
|
pytest tests/test_embeddings.py::TestCohereEmbeddings::test_single_text_embedding
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run tests by marker:
|
||||||
|
```bash
|
||||||
|
# Run only integration tests
|
||||||
|
pytest -m integration
|
||||||
|
|
||||||
|
# Run tests excluding Ollama
|
||||||
|
pytest -m "not ollama"
|
||||||
|
|
||||||
|
# Run only Ollama tests
|
||||||
|
pytest -m ollama
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run with coverage:
|
||||||
|
```bash
|
||||||
|
pytest --cov=router tests/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run tests in parallel:
|
||||||
|
```bash
|
||||||
|
# Install pytest-xdist first
|
||||||
|
pip install pytest-xdist
|
||||||
|
pytest -n auto
|
||||||
|
```
|
||||||
|
|
||||||
|
## Test Coverage
|
||||||
|
|
||||||
|
- **test_config.py**: Tests configuration loading and API key management
|
||||||
|
- **test_embeddings.py**: Tests embedding APIs (Cohere, Gemini, Ollama)
|
||||||
|
- **test_rerank.py**: Tests Cohere reranking functionality
|
||||||
|
- **test_generation.py**: Tests text generation (OpenAI-compatible, Gemini, Cohere)
|
||||||
|
|
||||||
|
## Test Markers
|
||||||
|
|
||||||
|
- `@pytest.mark.integration`: Marks integration tests that call real APIs
|
||||||
|
- `@pytest.mark.asyncio`: Marks async tests
|
||||||
|
- `@pytest.mark.ollama`: Marks tests that require Ollama to be running
|
||||||
|
- `@pytest.mark.parametrize`: Used for running tests with multiple input values
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- These tests use real APIs and will consume API credits
|
||||||
|
- Ollama tests are marked with `@pytest.mark.ollama` and can be skipped
|
||||||
|
- All tests include both sync and async variants
|
||||||
|
- Tests verify functionality without checking exact output values
|
||||||
|
- The test suite is configured for `asyncio_mode = auto` in pytest.ini
|
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
"""Test suite for AI Router."""
|
95
tests/run_all_tests.py
Normal file
95
tests/run_all_tests.py
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
"""Run all tests in the test suite."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(test_file: str) -> bool:
|
||||||
|
"""Run a single test file and return success status."""
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"Running {test_file}...")
|
||||||
|
print('='*60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[sys.executable, test_file],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True
|
||||||
|
)
|
||||||
|
print(result.stdout)
|
||||||
|
if result.stderr:
|
||||||
|
print("Warnings:", result.stderr)
|
||||||
|
return True
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"❌ Test failed: {test_file}")
|
||||||
|
print("STDOUT:", e.stdout)
|
||||||
|
print("STDERR:", e.stderr)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run all tests."""
|
||||||
|
print("🚀 AI Router Test Suite")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
env_path = Path(__file__).parent.parent / ".env"
|
||||||
|
if env_path.exists():
|
||||||
|
load_dotenv(env_path)
|
||||||
|
print(f"✓ Loaded environment from {env_path}")
|
||||||
|
else:
|
||||||
|
print("⚠ No .env file found, using system environment")
|
||||||
|
|
||||||
|
# Test files in order
|
||||||
|
test_files = [
|
||||||
|
"test_config.py",
|
||||||
|
"test_embeddings.py",
|
||||||
|
"test_rerank.py",
|
||||||
|
"test_generation.py",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Track results
|
||||||
|
results = {}
|
||||||
|
tests_dir = Path(__file__).parent
|
||||||
|
|
||||||
|
for test_file in test_files:
|
||||||
|
test_path = tests_dir / test_file
|
||||||
|
if test_path.exists():
|
||||||
|
results[test_file] = run_test(str(test_path))
|
||||||
|
else:
|
||||||
|
print(f"⚠ Test file not found: {test_file}")
|
||||||
|
results[test_file] = False
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("📊 TEST SUMMARY")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
passed = sum(1 for success in results.values() if success)
|
||||||
|
total = len(results)
|
||||||
|
|
||||||
|
for test_file, success in results.items():
|
||||||
|
status = "✅ PASSED" if success else "❌ FAILED"
|
||||||
|
print(f"{test_file:<30} {status}")
|
||||||
|
|
||||||
|
print("-"*60)
|
||||||
|
print(f"Total: {passed}/{total} tests passed")
|
||||||
|
|
||||||
|
if passed == total:
|
||||||
|
print("\n🎉 All tests passed!")
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
print(f"\n❌ {total - passed} test(s) failed")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit_code = main()
|
||||||
|
sys.exit(exit_code)
|
71
tests/test_config.py
Normal file
71
tests/test_config.py
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
"""Test configuration loading from environment."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from router.config import RouterConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestRouterConfig:
|
||||||
|
"""Test RouterConfig functionality."""
|
||||||
|
|
||||||
|
def test_config_from_env(self, api_keys):
|
||||||
|
"""Test that config loads API keys from environment."""
|
||||||
|
config = RouterConfig.from_env()
|
||||||
|
|
||||||
|
# Check API keys are loaded
|
||||||
|
assert config.cohere_api_key is not None, "Cohere API key not loaded"
|
||||||
|
assert config.gemini_api_key is not None, "Gemini API key not loaded"
|
||||||
|
assert config.openai_api_key is not None, "OpenAI API key not loaded"
|
||||||
|
|
||||||
|
# Check Ollama configuration
|
||||||
|
assert config.ollama_base_url == api_keys["ollama_base_url"]
|
||||||
|
|
||||||
|
def test_get_api_key(self):
|
||||||
|
"""Test getting API keys by provider."""
|
||||||
|
config = RouterConfig.from_env()
|
||||||
|
|
||||||
|
assert config.get_api_key("cohere") is not None
|
||||||
|
assert config.get_api_key("gemini") is not None
|
||||||
|
assert config.get_api_key("openai") is not None
|
||||||
|
assert config.get_api_key("unknown") is None
|
||||||
|
|
||||||
|
def test_cost_calculation(self):
|
||||||
|
"""Test cost calculation methods."""
|
||||||
|
config = RouterConfig.from_env()
|
||||||
|
config.track_costs = True
|
||||||
|
|
||||||
|
# Test token cost calculation
|
||||||
|
cost = config.calculate_cost("gpt-4o", input_tokens=1000, output_tokens=500)
|
||||||
|
assert cost is not None
|
||||||
|
assert cost > 0
|
||||||
|
|
||||||
|
# Test embedding cost
|
||||||
|
embed_cost = config.calculate_embed_cost("text-embedding-004", num_tokens=1000)
|
||||||
|
assert embed_cost is not None
|
||||||
|
assert embed_cost >= 0
|
||||||
|
|
||||||
|
# Test rerank cost
|
||||||
|
rerank_cost = config.calculate_rerank_cost("rerank-english-v3.0", num_searches=10)
|
||||||
|
assert rerank_cost is not None
|
||||||
|
assert rerank_cost > 0
|
||||||
|
|
||||||
|
def test_ollama_models_config(self):
|
||||||
|
"""Test Ollama models configuration."""
|
||||||
|
config = RouterConfig.from_env()
|
||||||
|
|
||||||
|
assert "mxbai-embed-large:latest" in config.ollama_embedding_models
|
||||||
|
assert "nomic-embed-text:latest" in config.ollama_embedding_models
|
||||||
|
assert len(config.ollama_embedding_models) >= 2
|
||||||
|
|
||||||
|
def test_config_to_dict(self):
|
||||||
|
"""Test config serialization."""
|
||||||
|
config = RouterConfig.from_env()
|
||||||
|
config_dict = config.to_dict()
|
||||||
|
|
||||||
|
# Check that API keys are masked
|
||||||
|
assert config_dict["openai_api_key"] == "***"
|
||||||
|
assert config_dict["cohere_api_key"] == "***"
|
||||||
|
assert config_dict["gemini_api_key"] == "***"
|
||||||
|
|
||||||
|
# Check other settings are present
|
||||||
|
assert "default_temperature" in config_dict
|
||||||
|
assert "track_costs" in config_dict
|
209
tests/test_embeddings.py
Normal file
209
tests/test_embeddings.py
Normal file
|
@ -0,0 +1,209 @@
|
||||||
|
"""Test embeddings with real API calls."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from router.embed import (
|
||||||
|
create_embedding,
|
||||||
|
CohereEmbedding,
|
||||||
|
GeminiEmbedding,
|
||||||
|
OllamaEmbedding,
|
||||||
|
EmbeddingResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Removed skip decorator - using pytest marks instead
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestCohereEmbeddings:
|
||||||
|
"""Test Cohere embedding functionality."""
|
||||||
|
|
||||||
|
def test_single_text_embedding(self, cohere_embedder):
|
||||||
|
"""Test embedding a single text."""
|
||||||
|
response = cohere_embedder.embed("Hello world")
|
||||||
|
|
||||||
|
assert isinstance(response, EmbeddingResponse)
|
||||||
|
assert len(response.embeddings) == 1
|
||||||
|
assert len(response.embeddings[0]) > 0
|
||||||
|
assert response.dimension > 0
|
||||||
|
assert response.provider == "cohere"
|
||||||
|
assert response.latency > 0
|
||||||
|
assert response.num_inputs == 1
|
||||||
|
|
||||||
|
def test_multiple_texts_embedding(self, cohere_embedder):
|
||||||
|
"""Test embedding multiple texts."""
|
||||||
|
texts = ["First text", "Second text", "Third text"]
|
||||||
|
response = cohere_embedder.embed(texts)
|
||||||
|
|
||||||
|
assert len(response.embeddings) == 3
|
||||||
|
assert response.num_inputs == 3
|
||||||
|
assert all(len(emb) == response.dimension for emb in response.embeddings)
|
||||||
|
|
||||||
|
def test_input_type_parameter(self, cohere_embedder):
|
||||||
|
"""Test embedding with input_type parameter."""
|
||||||
|
response = cohere_embedder.embed("Search query", input_type="search_query")
|
||||||
|
assert response.embeddings[0] is not None
|
||||||
|
|
||||||
|
response = cohere_embedder.embed("Document text", input_type="search_document")
|
||||||
|
assert response.embeddings[0] is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_embedding(self, cohere_embedder):
|
||||||
|
"""Test async embedding."""
|
||||||
|
response = await cohere_embedder.aembed("Hello async world")
|
||||||
|
|
||||||
|
assert isinstance(response, EmbeddingResponse)
|
||||||
|
assert len(response.embeddings) == 1
|
||||||
|
assert response.provider == "cohere"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_multiple_embeddings(self, cohere_embedder):
|
||||||
|
"""Test async embedding of multiple texts."""
|
||||||
|
texts = ["Async text 1", "Async text 2"]
|
||||||
|
response = await cohere_embedder.aembed(texts)
|
||||||
|
|
||||||
|
assert len(response.embeddings) == 2
|
||||||
|
assert response.num_inputs == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestGeminiEmbeddings:
|
||||||
|
"""Test Gemini embedding functionality."""
|
||||||
|
|
||||||
|
def test_single_text_embedding(self, gemini_embedder):
|
||||||
|
"""Test embedding a single text."""
|
||||||
|
response = gemini_embedder.embed("Hello from Gemini")
|
||||||
|
|
||||||
|
assert isinstance(response, EmbeddingResponse)
|
||||||
|
assert len(response.embeddings) == 1
|
||||||
|
assert len(response.embeddings[0]) > 0
|
||||||
|
assert response.provider == "gemini"
|
||||||
|
assert response.dimension > 0
|
||||||
|
|
||||||
|
def test_multiple_texts_embedding(self, gemini_embedder):
|
||||||
|
"""Test embedding multiple texts."""
|
||||||
|
texts = ["Gemini text one", "Gemini text two"]
|
||||||
|
response = gemini_embedder.embed(texts)
|
||||||
|
|
||||||
|
assert len(response.embeddings) == 2
|
||||||
|
assert all(len(emb) == response.dimension for emb in response.embeddings)
|
||||||
|
|
||||||
|
def test_task_type_parameter(self, gemini_embedder):
|
||||||
|
"""Test embedding with task_type parameter."""
|
||||||
|
response = gemini_embedder.embed(
|
||||||
|
"Test with task type",
|
||||||
|
task_type="retrieval_document"
|
||||||
|
)
|
||||||
|
assert response.embeddings[0] is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_embedding(self, gemini_embedder):
|
||||||
|
"""Test async embedding."""
|
||||||
|
response = await gemini_embedder.aembed("Async Gemini test")
|
||||||
|
|
||||||
|
assert isinstance(response, EmbeddingResponse)
|
||||||
|
assert response.provider == "gemini"
|
||||||
|
assert len(response.embeddings) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestOllamaEmbeddings:
|
||||||
|
"""Test Ollama embedding functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.ollama
|
||||||
|
def test_single_text_embedding(self, ollama_embedder):
|
||||||
|
"""Test embedding a single text with Ollama."""
|
||||||
|
try:
|
||||||
|
response = ollama_embedder.embed("Test Ollama")
|
||||||
|
|
||||||
|
assert isinstance(response, EmbeddingResponse)
|
||||||
|
assert response.provider == "ollama"
|
||||||
|
assert response.cost == 0.0 # Local model should be free
|
||||||
|
assert len(response.embeddings) == 1
|
||||||
|
assert response.dimension > 0
|
||||||
|
except Exception as e:
|
||||||
|
if "Cannot connect to Ollama" in str(e):
|
||||||
|
pytest.skip("Ollama not running")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@pytest.mark.ollama
|
||||||
|
def test_multiple_texts_embedding(self, ollama_embedder):
|
||||||
|
"""Test embedding multiple texts with Ollama."""
|
||||||
|
try:
|
||||||
|
response = ollama_embedder.embed(["Local text 1", "Local text 2"])
|
||||||
|
|
||||||
|
assert len(response.embeddings) == 2
|
||||||
|
assert response.num_inputs == 2
|
||||||
|
assert response.metadata.get("local") is True
|
||||||
|
except Exception as e:
|
||||||
|
if "Cannot connect to Ollama" in str(e):
|
||||||
|
pytest.skip("Ollama not running")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestEmbeddingFactory:
|
||||||
|
"""Test the embedding factory function."""
|
||||||
|
|
||||||
|
def test_create_cohere_embedder(self):
|
||||||
|
"""Test creating Cohere embedder via factory."""
|
||||||
|
embedder = create_embedding(provider="cohere")
|
||||||
|
response = embedder.embed("Factory test")
|
||||||
|
|
||||||
|
assert response.provider == "cohere"
|
||||||
|
assert isinstance(embedder, CohereEmbedding)
|
||||||
|
|
||||||
|
def test_create_gemini_embedder(self):
|
||||||
|
"""Test creating Gemini embedder via factory."""
|
||||||
|
embedder = create_embedding(provider="gemini", model="text-embedding-004")
|
||||||
|
response = embedder.embed("Factory Gemini")
|
||||||
|
|
||||||
|
assert response.provider == "gemini"
|
||||||
|
assert response.model == "text-embedding-004"
|
||||||
|
assert isinstance(embedder, GeminiEmbedding)
|
||||||
|
|
||||||
|
def test_create_ollama_embedder(self):
|
||||||
|
"""Test creating Ollama embedder via factory."""
|
||||||
|
embedder = create_embedding(
|
||||||
|
provider="ollama",
|
||||||
|
model="mxbai-embed-large:latest"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embedder.model == "mxbai-embed-large:latest"
|
||||||
|
assert isinstance(embedder, OllamaEmbedding)
|
||||||
|
|
||||||
|
def test_invalid_provider(self):
|
||||||
|
"""Test factory with invalid provider."""
|
||||||
|
with pytest.raises(ValueError, match="Unknown embedding provider"):
|
||||||
|
create_embedding(provider="invalid")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestEmbeddingComparison:
|
||||||
|
"""Test comparing embeddings across providers."""
|
||||||
|
|
||||||
|
def test_embedding_dimensions(self, cohere_embedder, gemini_embedder):
|
||||||
|
"""Test that different models produce different dimensions."""
|
||||||
|
test_text = "dimension test"
|
||||||
|
|
||||||
|
cohere_resp = cohere_embedder.embed(test_text)
|
||||||
|
gemini_resp = gemini_embedder.embed(test_text)
|
||||||
|
|
||||||
|
assert cohere_resp.dimension > 0
|
||||||
|
assert gemini_resp.dimension > 0
|
||||||
|
|
||||||
|
# Both should return valid embeddings
|
||||||
|
assert len(cohere_resp.embeddings[0]) == cohere_resp.dimension
|
||||||
|
assert len(gemini_resp.embeddings[0]) == gemini_resp.dimension
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("text", [
|
||||||
|
"Short text",
|
||||||
|
"A much longer text that contains multiple sentences and should still work fine.",
|
||||||
|
"Text with special characters: @#$%^&*()",
|
||||||
|
"文本与中文字符", # Text with Chinese characters
|
||||||
|
])
|
||||||
|
def test_various_text_inputs(self, cohere_embedder, text):
|
||||||
|
"""Test embedding various types of text inputs."""
|
||||||
|
response = cohere_embedder.embed(text)
|
||||||
|
|
||||||
|
assert len(response.embeddings) == 1
|
||||||
|
assert len(response.embeddings[0]) > 0
|
245
tests/test_generation.py
Normal file
245
tests/test_generation.py
Normal file
|
@ -0,0 +1,245 @@
|
||||||
|
"""Test text generation with real APIs."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from router import AIRouter, RouterResponse
|
||||||
|
from router.gemini import Gemini
|
||||||
|
from router.openai_compatible import OpenAI
|
||||||
|
from router.cohere import Cohere
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestOpenAICompatible:
|
||||||
|
"""Test OpenAI-compatible API functionality."""
|
||||||
|
|
||||||
|
def test_simple_generation(self, openai_router):
|
||||||
|
"""Test simple text generation."""
|
||||||
|
response = openai_router.call("Say hello in 5 words or less")
|
||||||
|
|
||||||
|
assert isinstance(response, RouterResponse)
|
||||||
|
assert response.content is not None
|
||||||
|
assert response.provider == "openai"
|
||||||
|
assert response.latency > 0
|
||||||
|
assert len(response.content.split()) <= 10 # Allow some flexibility
|
||||||
|
|
||||||
|
def test_generation_with_parameters(self, openai_router):
|
||||||
|
"""Test generation with custom parameters."""
|
||||||
|
response = openai_router.call(
|
||||||
|
"Write a haiku about coding",
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=100
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(response.content) > 0
|
||||||
|
assert response.total_tokens is not None
|
||||||
|
assert response.input_tokens is not None
|
||||||
|
assert response.output_tokens is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_generation(self, openai_router):
|
||||||
|
"""Test async generation."""
|
||||||
|
response = await openai_router.acall("Say 'async works' if this is working")
|
||||||
|
|
||||||
|
assert "async" in response.content.lower()
|
||||||
|
assert response.provider == "openai"
|
||||||
|
|
||||||
|
def test_streaming(self, openai_router):
|
||||||
|
"""Test streaming responses."""
|
||||||
|
chunks = []
|
||||||
|
for chunk in openai_router.stream("Count from 1 to 3"):
|
||||||
|
chunks.append(chunk.content)
|
||||||
|
assert isinstance(chunk, RouterResponse)
|
||||||
|
assert chunk.provider == "openai"
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
full_response = "".join(chunks)
|
||||||
|
assert len(full_response) > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_streaming(self, openai_router):
|
||||||
|
"""Test async streaming."""
|
||||||
|
chunks = []
|
||||||
|
async for chunk in openai_router.astream("List 3 colors"):
|
||||||
|
chunks.append(chunk.content)
|
||||||
|
assert isinstance(chunk, RouterResponse)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
full_response = "".join(chunks)
|
||||||
|
assert len(full_response) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestGemini:
|
||||||
|
"""Test Gemini API functionality."""
|
||||||
|
|
||||||
|
def test_simple_generation(self, gemini_router):
|
||||||
|
"""Test simple text generation."""
|
||||||
|
response = gemini_router.call("What is 2+2? Answer in one word.")
|
||||||
|
|
||||||
|
assert isinstance(response, RouterResponse)
|
||||||
|
assert response.provider == "gemini"
|
||||||
|
assert response.content is not None
|
||||||
|
assert len(response.content) > 0
|
||||||
|
|
||||||
|
def test_generation_with_temperature(self, gemini_router):
|
||||||
|
"""Test generation with temperature parameter."""
|
||||||
|
response = gemini_router.call(
|
||||||
|
"List 3 benefits of Python programming. Be concise.",
|
||||||
|
temperature=0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.content is not None
|
||||||
|
assert response.total_tokens is not None
|
||||||
|
# Verify the response contains programming benefits (may not always mention Python by name)
|
||||||
|
assert any(keyword in response.content.lower() for keyword in ["readability", "versatility", "libraries", "development", "programming"])
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_generation(self, gemini_router):
|
||||||
|
"""Test async generation."""
|
||||||
|
response = await gemini_router.acall("What is 5x5? Answer only the number.")
|
||||||
|
|
||||||
|
assert response.content is not None
|
||||||
|
assert "25" in response.content
|
||||||
|
|
||||||
|
def test_streaming(self, gemini_router):
|
||||||
|
"""Test streaming responses."""
|
||||||
|
chunks = []
|
||||||
|
for chunk in gemini_router.stream("Say 'streaming works'"):
|
||||||
|
chunks.append(chunk.content)
|
||||||
|
assert chunk.provider == "gemini"
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
full_response = "".join(chunks)
|
||||||
|
assert "streaming" in full_response.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestCohere:
|
||||||
|
"""Test Cohere API functionality."""
|
||||||
|
|
||||||
|
def test_simple_generation(self, cohere_router):
|
||||||
|
"""Test simple text generation."""
|
||||||
|
response = cohere_router.call("Complete this: The capital of France is")
|
||||||
|
|
||||||
|
assert isinstance(response, RouterResponse)
|
||||||
|
assert response.provider == "cohere"
|
||||||
|
assert "Paris" in response.content
|
||||||
|
|
||||||
|
def test_generation_with_preamble(self, cohere_router):
|
||||||
|
"""Test generation with preamble."""
|
||||||
|
response = cohere_router.call(
|
||||||
|
"What is machine learning?",
|
||||||
|
preamble="You are a helpful AI assistant. Be concise.",
|
||||||
|
max_tokens=100
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(response.content) > 0
|
||||||
|
assert response.finish_reason is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_generation(self, cohere_router):
|
||||||
|
"""Test async generation."""
|
||||||
|
response = await cohere_router.acall("Say hello asynchronously")
|
||||||
|
|
||||||
|
assert response.content is not None
|
||||||
|
assert response.provider == "cohere"
|
||||||
|
|
||||||
|
def test_chat_history(self, cohere_router):
|
||||||
|
"""Test generation with chat history."""
|
||||||
|
response = cohere_router.call(
|
||||||
|
"What did I just ask?",
|
||||||
|
chat_history=[
|
||||||
|
{"role": "USER", "message": "What is Python?"},
|
||||||
|
{"role": "CHATBOT", "message": "Python is a programming language."}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.content is not None
|
||||||
|
assert len(response.content) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestFactoryFunction:
|
||||||
|
"""Test the AIRouter factory functionality."""
|
||||||
|
|
||||||
|
def test_create_openai_router(self, api_keys):
|
||||||
|
"""Test creating OpenAI router via factory."""
|
||||||
|
router = AIRouter(
|
||||||
|
provider="openai",
|
||||||
|
model="gpt-4o",
|
||||||
|
base_url=api_keys["openai_base_url"]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = router.call("Say 'factory works'")
|
||||||
|
assert "factory" in response.content.lower()
|
||||||
|
assert response.provider == "openai"
|
||||||
|
|
||||||
|
def test_create_gemini_router(self):
|
||||||
|
"""Test creating Gemini router via factory."""
|
||||||
|
router = AIRouter(provider="gemini")
|
||||||
|
|
||||||
|
response = router.call("What is 1+1?")
|
||||||
|
assert response.provider == "gemini"
|
||||||
|
assert "2" in response.content
|
||||||
|
|
||||||
|
def test_create_cohere_router(self):
|
||||||
|
"""Test creating Cohere router via factory."""
|
||||||
|
router = AIRouter(provider="cohere")
|
||||||
|
|
||||||
|
response = router.call("Hello")
|
||||||
|
assert response.provider == "cohere"
|
||||||
|
assert len(response.content) > 0
|
||||||
|
|
||||||
|
def test_invalid_provider(self):
|
||||||
|
"""Test factory with invalid provider."""
|
||||||
|
with pytest.raises(ValueError, match="Unknown provider"):
|
||||||
|
AIRouter(provider="invalid")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestGenerationFeatures:
|
||||||
|
"""Test various generation features across providers."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("provider,router_fixture", [
|
||||||
|
("openai", "openai_router"),
|
||||||
|
("gemini", "gemini_router"),
|
||||||
|
("cohere", "cohere_router"),
|
||||||
|
])
|
||||||
|
def test_max_tokens_limit(self, provider, router_fixture, request):
|
||||||
|
"""Test max_tokens parameter across providers."""
|
||||||
|
router = request.getfixturevalue(router_fixture)
|
||||||
|
|
||||||
|
response = router.call(
|
||||||
|
"Write a very long story about dragons",
|
||||||
|
max_tokens=10
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.content is not None
|
||||||
|
assert response.provider == provider
|
||||||
|
# Output should be limited (though exact token count varies by tokenizer)
|
||||||
|
assert len(response.content.split()) < 50
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("temperature", [0.0, 0.5, 1.0])
|
||||||
|
def test_temperature_variations(self, gemini_router, temperature):
|
||||||
|
"""Test different temperature settings."""
|
||||||
|
response = gemini_router.call(
|
||||||
|
"Generate a random word",
|
||||||
|
temperature=temperature
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.content is not None
|
||||||
|
assert len(response.content) > 0
|
||||||
|
|
||||||
|
def test_cost_tracking(self, openai_router):
|
||||||
|
"""Test cost tracking functionality."""
|
||||||
|
from router.config import config
|
||||||
|
original_track_costs = config.track_costs
|
||||||
|
config.track_costs = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = openai_router.call("Hello")
|
||||||
|
|
||||||
|
if response.cost is not None:
|
||||||
|
assert response.cost >= 0
|
||||||
|
assert isinstance(response.cost, float)
|
||||||
|
finally:
|
||||||
|
config.track_costs = original_track_costs
|
188
tests/test_rerank.py
Normal file
188
tests/test_rerank.py
Normal file
|
@ -0,0 +1,188 @@
|
||||||
|
"""Test reranking with real Cohere API."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from router.rerank import CohereRerank, RerankDocument, RerankResponse
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestCohereRerank:
|
||||||
|
"""Test Cohere reranking functionality."""
|
||||||
|
|
||||||
|
def test_basic_reranking(self, cohere_reranker):
|
||||||
|
"""Test basic document reranking."""
|
||||||
|
documents = [
|
||||||
|
"Python is a programming language",
|
||||||
|
"The weather is nice today",
|
||||||
|
"Machine learning with Python",
|
||||||
|
"Coffee is a popular beverage",
|
||||||
|
"Deep learning frameworks in Python"
|
||||||
|
]
|
||||||
|
|
||||||
|
query = "Python programming"
|
||||||
|
|
||||||
|
response = cohere_reranker.rerank(query, documents)
|
||||||
|
|
||||||
|
assert isinstance(response, RerankResponse)
|
||||||
|
assert response.provider == "cohere"
|
||||||
|
assert response.num_documents == 5
|
||||||
|
assert len(response.results) > 0
|
||||||
|
|
||||||
|
# Check results are sorted by relevance
|
||||||
|
scores = [r.relevance_score for r in response.results]
|
||||||
|
assert scores == sorted(scores, reverse=True), "Results should be sorted by relevance"
|
||||||
|
|
||||||
|
# The most relevant document should be about Python
|
||||||
|
top_result = response.results[0]
|
||||||
|
assert "Python" in top_result.document.text
|
||||||
|
|
||||||
|
def test_rerank_with_top_n(self, cohere_reranker):
|
||||||
|
"""Test reranking with top_n parameter."""
|
||||||
|
documents = [
|
||||||
|
"Document 1", "Document 2", "Document 3",
|
||||||
|
"Document 4", "Document 5"
|
||||||
|
]
|
||||||
|
|
||||||
|
response = cohere_reranker.rerank("test", documents, top_n=3)
|
||||||
|
|
||||||
|
assert len(response.results) == 3
|
||||||
|
assert all(0 <= r.index < 5 for r in response.results)
|
||||||
|
|
||||||
|
def test_rerank_with_document_objects(self, cohere_reranker):
|
||||||
|
"""Test reranking with RerankDocument objects."""
|
||||||
|
documents = [
|
||||||
|
RerankDocument(text="Python tutorial", id="doc1"),
|
||||||
|
RerankDocument(text="Java guide", id="doc2"),
|
||||||
|
RerankDocument(text="Python cookbook", id="doc3")
|
||||||
|
]
|
||||||
|
|
||||||
|
response = cohere_reranker.rerank("Python", documents)
|
||||||
|
|
||||||
|
assert response.results[0].document.text in ["Python tutorial", "Python cookbook"]
|
||||||
|
assert len(response.results) == 3
|
||||||
|
|
||||||
|
def test_rerank_mixed_document_types(self, cohere_reranker):
|
||||||
|
"""Test reranking with mixed document types."""
|
||||||
|
documents = [
|
||||||
|
"Plain string document",
|
||||||
|
{"text": "Dict with text field"},
|
||||||
|
RerankDocument(text="RerankDocument object")
|
||||||
|
]
|
||||||
|
|
||||||
|
response = cohere_reranker.rerank("test", documents)
|
||||||
|
|
||||||
|
assert len(response.results) == 3
|
||||||
|
assert all(hasattr(r.document, 'text') for r in response.results)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_reranking(self, cohere_reranker):
|
||||||
|
"""Test async reranking."""
|
||||||
|
documents = [
|
||||||
|
"Async programming in Python",
|
||||||
|
"Synchronous vs asynchronous",
|
||||||
|
"Python asyncio tutorial",
|
||||||
|
"Database operations",
|
||||||
|
]
|
||||||
|
|
||||||
|
query = "async Python"
|
||||||
|
|
||||||
|
response = await cohere_reranker.arerank(query, documents, top_n=2)
|
||||||
|
|
||||||
|
assert isinstance(response, RerankResponse)
|
||||||
|
assert len(response.results) == 2
|
||||||
|
assert response.results[0].relevance_score >= response.results[1].relevance_score
|
||||||
|
|
||||||
|
def test_single_document_reranking(self, cohere_reranker):
|
||||||
|
"""Test reranking with a single document."""
|
||||||
|
response = cohere_reranker.rerank("test", ["single document"])
|
||||||
|
|
||||||
|
assert len(response.results) == 1
|
||||||
|
assert response.results[0].index == 0
|
||||||
|
assert response.results[0].document.text == "single document"
|
||||||
|
|
||||||
|
def test_empty_query(self, cohere_reranker):
|
||||||
|
"""Test reranking with empty query."""
|
||||||
|
documents = ["doc1", "doc2"]
|
||||||
|
|
||||||
|
# Empty query should raise ValueError
|
||||||
|
with pytest.raises(ValueError, match="Query cannot be empty"):
|
||||||
|
cohere_reranker.rerank("", documents)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", [
|
||||||
|
"rerank-english-v3.0",
|
||||||
|
"rerank-multilingual-v3.0"
|
||||||
|
])
|
||||||
|
def test_different_models(self, model):
|
||||||
|
"""Test different rerank models."""
|
||||||
|
reranker = CohereRerank(model=model)
|
||||||
|
|
||||||
|
documents = ["Hello world", "Bonjour monde", "Hola mundo"]
|
||||||
|
response = reranker.rerank("greeting", documents)
|
||||||
|
|
||||||
|
assert response.model == model
|
||||||
|
assert len(response.results) == 3
|
||||||
|
|
||||||
|
def test_relevance_scores(self, cohere_reranker):
|
||||||
|
"""Test that relevance scores are properly set."""
|
||||||
|
documents = [
|
||||||
|
"Exact match for Python programming",
|
||||||
|
"Somewhat related to Python",
|
||||||
|
"Completely unrelated topic"
|
||||||
|
]
|
||||||
|
|
||||||
|
response = cohere_reranker.rerank("Python programming", documents)
|
||||||
|
|
||||||
|
# All results should have relevance scores
|
||||||
|
assert all(hasattr(r, 'relevance_score') for r in response.results)
|
||||||
|
assert all(isinstance(r.relevance_score, (int, float)) for r in response.results)
|
||||||
|
|
||||||
|
# Scores should be between 0 and 1 (typical range)
|
||||||
|
assert all(0 <= r.relevance_score <= 1 for r in response.results)
|
||||||
|
|
||||||
|
def test_cost_tracking(self, cohere_reranker):
|
||||||
|
"""Test cost tracking for reranking."""
|
||||||
|
# Enable cost tracking
|
||||||
|
from router.config import config
|
||||||
|
original_track_costs = config.track_costs
|
||||||
|
config.track_costs = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = cohere_reranker.rerank("test", ["doc1", "doc2"])
|
||||||
|
|
||||||
|
if response.cost is not None:
|
||||||
|
assert response.cost > 0
|
||||||
|
assert isinstance(response.cost, float)
|
||||||
|
finally:
|
||||||
|
config.track_costs = original_track_costs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestRerankEdgeCases:
|
||||||
|
"""Test edge cases and error handling."""
|
||||||
|
|
||||||
|
def test_very_long_query(self, cohere_reranker):
|
||||||
|
"""Test with a very long query."""
|
||||||
|
long_query = "Python " * 100 # Very long query
|
||||||
|
documents = ["Python doc", "Java doc"]
|
||||||
|
|
||||||
|
response = cohere_reranker.rerank(long_query, documents)
|
||||||
|
assert isinstance(response, RerankResponse)
|
||||||
|
|
||||||
|
def test_special_characters(self, cohere_reranker):
|
||||||
|
"""Test documents with special characters."""
|
||||||
|
documents = [
|
||||||
|
"Document with @#$% special chars",
|
||||||
|
"Normal document",
|
||||||
|
"Document with émojis 🐍"
|
||||||
|
]
|
||||||
|
|
||||||
|
response = cohere_reranker.rerank("special", documents)
|
||||||
|
assert len(response.results) == 3
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_docs", [1, 10, 50])
|
||||||
|
def test_various_document_counts(self, cohere_reranker, num_docs):
|
||||||
|
"""Test with different numbers of documents."""
|
||||||
|
documents = [f"Document {i}" for i in range(num_docs)]
|
||||||
|
|
||||||
|
response = cohere_reranker.rerank("Document", documents)
|
||||||
|
assert len(response.results) == num_docs
|
||||||
|
assert response.num_documents == num_docs
|
Loading…
Reference in a new issue