"""Test text generation with real APIs."""
|
|
|
|
import pytest
|
|
from router import AIRouter, RouterResponse
|
|
from router.gemini import Gemini
|
|
from router.openai_compatible import OpenAI
|
|
from router.cohere import Cohere
|
|
|
|
|
|
@pytest.mark.integration
class TestOpenAICompatible:
    """Integration tests against an OpenAI-compatible endpoint."""

    def test_simple_generation(self, openai_router):
        """A plain prompt yields a populated RouterResponse."""
        result = openai_router.call("Say hello in 5 words or less")

        assert isinstance(result, RouterResponse)
        assert result.content is not None
        assert result.provider == "openai"
        assert result.latency > 0
        # Allow some flexibility beyond the requested 5 words.
        assert len(result.content.split()) <= 10

    def test_generation_with_parameters(self, openai_router):
        """Custom temperature/max_tokens still produce content and token counts."""
        result = openai_router.call(
            "Write a haiku about coding",
            temperature=0.7,
            max_tokens=100,
        )

        assert len(result.content) > 0
        for field in (result.total_tokens, result.input_tokens, result.output_tokens):
            assert field is not None

    @pytest.mark.asyncio
    async def test_async_generation(self, openai_router):
        """acall() returns a response through the async path."""
        result = await openai_router.acall("Say 'async works' if this is working")

        assert "async" in result.content.lower()
        assert result.provider == "openai"

    def test_streaming(self, openai_router):
        """Every streamed chunk is a RouterResponse and the pieces join non-empty."""
        pieces = []
        for piece in openai_router.stream("Count from 1 to 3"):
            assert isinstance(piece, RouterResponse)
            assert piece.provider == "openai"
            pieces.append(piece.content)

        assert pieces
        assert "".join(pieces)

    @pytest.mark.asyncio
    async def test_async_streaming(self, openai_router):
        """astream() yields at least one chunk whose contents join non-empty."""
        pieces = []
        async for piece in openai_router.astream("List 3 colors"):
            assert isinstance(piece, RouterResponse)
            pieces.append(piece.content)

        assert pieces
        assert "".join(pieces)
|
|
|
|
|
|
@pytest.mark.integration
class TestGemini:
    """Integration tests against the Gemini provider."""

    def test_simple_generation(self, gemini_router):
        """A trivial arithmetic prompt yields a non-empty RouterResponse."""
        result = gemini_router.call("What is 2+2? Answer in one word.")

        assert isinstance(result, RouterResponse)
        assert result.provider == "gemini"
        assert result.content is not None
        assert result.content

    def test_generation_with_temperature(self, gemini_router):
        """Temperature is accepted and the answer stays on topic."""
        result = gemini_router.call(
            "List 3 benefits of Python programming. Be concise.",
            temperature=0.5,
        )

        assert result.content is not None
        assert result.total_tokens is not None
        # The reply should mention at least one programming benefit,
        # even if it never names Python explicitly.
        expected_terms = ("readability", "versatility", "libraries", "development", "programming")
        lowered = result.content.lower()
        assert any(term in lowered for term in expected_terms)

    @pytest.mark.asyncio
    async def test_async_generation(self, gemini_router):
        """acall() resolves a simple arithmetic question."""
        result = await gemini_router.acall("What is 5x5? Answer only the number.")

        assert result.content is not None
        assert "25" in result.content

    def test_streaming(self, gemini_router):
        """Streamed chunks carry the gemini provider tag and join into the prompt echo."""
        pieces = []
        for piece in gemini_router.stream("Say 'streaming works'"):
            assert piece.provider == "gemini"
            pieces.append(piece.content)

        assert pieces
        assert "streaming" in "".join(pieces).lower()
|
|
|
|
|
|
@pytest.mark.integration
class TestCohere:
    """Integration tests against the Cohere provider."""

    def test_simple_generation(self, cohere_router):
        """A completion prompt returns the expected fact."""
        result = cohere_router.call("Complete this: The capital of France is")

        assert isinstance(result, RouterResponse)
        assert result.provider == "cohere"
        assert "Paris" in result.content

    def test_generation_with_preamble(self, cohere_router):
        """A preamble plus max_tokens produces content and a finish reason."""
        result = cohere_router.call(
            "What is machine learning?",
            preamble="You are a helpful AI assistant. Be concise.",
            max_tokens=100,
        )

        assert result.content
        assert result.finish_reason is not None

    @pytest.mark.asyncio
    async def test_async_generation(self, cohere_router):
        """acall() returns content tagged with the cohere provider."""
        result = await cohere_router.acall("Say hello asynchronously")

        assert result.content is not None
        assert result.provider == "cohere"

    def test_chat_history(self, cohere_router):
        """A prior USER/CHATBOT exchange is accepted as chat_history."""
        history = [
            {"role": "USER", "message": "What is Python?"},
            {"role": "CHATBOT", "message": "Python is a programming language."},
        ]
        result = cohere_router.call("What did I just ask?", chat_history=history)

        assert result.content is not None
        assert result.content
|
|
|
|
|
|
@pytest.mark.integration
class TestFactoryFunction:
    """Exercise the AIRouter factory for each supported provider."""

    def test_create_openai_router(self, api_keys):
        """The factory builds a working OpenAI router from explicit arguments."""
        instance = AIRouter(
            provider="openai",
            model="gpt-4o",
            base_url=api_keys["openai_base_url"],
        )

        reply = instance.call("Say 'factory works'")
        assert "factory" in reply.content.lower()
        assert reply.provider == "openai"

    def test_create_gemini_router(self):
        """The factory builds a Gemini router with default settings."""
        instance = AIRouter(provider="gemini")

        reply = instance.call("What is 1+1?")
        assert reply.provider == "gemini"
        assert "2" in reply.content

    def test_create_cohere_router(self):
        """The factory builds a Cohere router with default settings."""
        instance = AIRouter(provider="cohere")

        reply = instance.call("Hello")
        assert reply.provider == "cohere"
        assert reply.content

    def test_invalid_provider(self):
        """An unrecognized provider name raises ValueError."""
        with pytest.raises(ValueError, match="Unknown provider"):
            AIRouter(provider="invalid")
|
|
|
|
|
|
@pytest.mark.integration
class TestGenerationFeatures:
    """Cross-provider checks for shared generation parameters."""

    @pytest.mark.parametrize("provider,router_fixture", [
        ("openai", "openai_router"),
        ("gemini", "gemini_router"),
        ("cohere", "cohere_router"),
    ])
    def test_max_tokens_limit(self, provider, router_fixture, request):
        """max_tokens caps output length on every provider."""
        active_router = request.getfixturevalue(router_fixture)

        result = active_router.call(
            "Write a very long story about dragons",
            max_tokens=10,
        )

        assert result.content is not None
        assert result.provider == provider
        # The exact token count varies by tokenizer, so only bound the word count.
        assert len(result.content.split()) < 50

    @pytest.mark.parametrize("temperature", [0.0, 0.5, 1.0])
    def test_temperature_variations(self, gemini_router, temperature):
        """Every sampled temperature value is accepted and yields content."""
        result = gemini_router.call(
            "Generate a random word",
            temperature=temperature,
        )

        assert result.content is not None
        assert result.content

    def test_cost_tracking(self, openai_router):
        """With cost tracking enabled, any reported cost is a non-negative float."""
        from router.config import config

        saved_flag = config.track_costs
        config.track_costs = True
        try:
            result = openai_router.call("Hello")

            # Cost may be absent (unknown pricing); only validate when reported.
            if result.cost is not None:
                assert isinstance(result.cost, float)
                assert result.cost >= 0
        finally:
            # Restore the global flag so other tests see the original setting.
            config.track_costs = saved_flag