template-with-ai/tests/test_generation.py
2025-07-01 17:07:02 +05:30

245 lines
No EOL
8.6 KiB
Python

"""Test text generation with real APIs."""
import pytest
from router import AIRouter, RouterResponse
from router.gemini import Gemini
from router.openai_compatible import OpenAI
from router.cohere import Cohere
@pytest.mark.integration
class TestOpenAICompatible:
    """Test OpenAI-compatible API functionality."""

    def test_simple_generation(self, openai_router):
        """A short prompt should produce a well-formed RouterResponse."""
        result = openai_router.call("Say hello in 5 words or less")
        assert isinstance(result, RouterResponse)
        assert result.content is not None
        assert result.provider == "openai"
        assert result.latency > 0
        # Allow some flexibility beyond the requested 5-word limit.
        assert len(result.content.split()) <= 10

    def test_generation_with_parameters(self, openai_router):
        """Custom sampling parameters should be accepted and usage reported."""
        result = openai_router.call(
            "Write a haiku about coding",
            temperature=0.7,
            max_tokens=100,
        )
        assert len(result.content) > 0
        assert result.total_tokens is not None
        assert result.input_tokens is not None
        assert result.output_tokens is not None

    @pytest.mark.asyncio
    async def test_async_generation(self, openai_router):
        """The async call path should return a provider-tagged answer."""
        result = await openai_router.acall("Say 'async works' if this is working")
        assert "async" in result.content.lower()
        assert result.provider == "openai"

    def test_streaming(self, openai_router):
        """Streaming should yield typed, provider-tagged chunks."""
        pieces = []
        for part in openai_router.stream("Count from 1 to 3"):
            assert isinstance(part, RouterResponse)
            assert part.provider == "openai"
            pieces.append(part.content)
        assert len(pieces) > 0
        joined = "".join(pieces)
        assert len(joined) > 0

    @pytest.mark.asyncio
    async def test_async_streaming(self, openai_router):
        """Async streaming should yield typed chunks that join into a reply."""
        pieces = []
        async for part in openai_router.astream("List 3 colors"):
            assert isinstance(part, RouterResponse)
            pieces.append(part.content)
        assert len(pieces) > 0
        joined = "".join(pieces)
        assert len(joined) > 0
@pytest.mark.integration
class TestGemini:
    """Test Gemini API functionality."""

    def test_simple_generation(self, gemini_router):
        """A trivial arithmetic prompt should yield a non-empty response."""
        reply = gemini_router.call("What is 2+2? Answer in one word.")
        assert isinstance(reply, RouterResponse)
        assert reply.provider == "gemini"
        assert reply.content is not None
        assert len(reply.content) > 0

    def test_generation_with_temperature(self, gemini_router):
        """The temperature parameter should be honored and usage reported."""
        reply = gemini_router.call(
            "List 3 benefits of Python programming. Be concise.",
            temperature=0.5,
        )
        assert reply.content is not None
        assert reply.total_tokens is not None
        # Verify the response contains programming benefits (may not always
        # mention Python by name).
        keywords = ("readability", "versatility", "libraries", "development", "programming")
        lowered = reply.content.lower()
        assert any(word in lowered for word in keywords)

    @pytest.mark.asyncio
    async def test_async_generation(self, gemini_router):
        """The async call path should return a usable answer."""
        reply = await gemini_router.acall("What is 5x5? Answer only the number.")
        assert reply.content is not None
        assert "25" in reply.content

    def test_streaming(self, gemini_router):
        """Streaming should yield provider-tagged chunks forming a reply."""
        pieces = []
        for part in gemini_router.stream("Say 'streaming works'"):
            assert part.provider == "gemini"
            pieces.append(part.content)
        assert len(pieces) > 0
        joined = "".join(pieces)
        assert "streaming" in joined.lower()
@pytest.mark.integration
class TestCohere:
    """Test Cohere API functionality."""

    def test_simple_generation(self, cohere_router):
        """A completion prompt should yield the expected factual answer."""
        reply = cohere_router.call("Complete this: The capital of France is")
        assert isinstance(reply, RouterResponse)
        assert reply.provider == "cohere"
        assert "Paris" in reply.content

    def test_generation_with_preamble(self, cohere_router):
        """The Cohere-specific preamble parameter should be accepted."""
        reply = cohere_router.call(
            "What is machine learning?",
            preamble="You are a helpful AI assistant. Be concise.",
            max_tokens=100,
        )
        assert len(reply.content) > 0
        assert reply.finish_reason is not None

    @pytest.mark.asyncio
    async def test_async_generation(self, cohere_router):
        """The async call path should return a provider-tagged answer."""
        reply = await cohere_router.acall("Say hello asynchronously")
        assert reply.content is not None
        assert reply.provider == "cohere"

    def test_chat_history(self, cohere_router):
        """Prior conversation turns should be accepted via chat_history."""
        history = [
            {"role": "USER", "message": "What is Python?"},
            {"role": "CHATBOT", "message": "Python is a programming language."},
        ]
        reply = cohere_router.call(
            "What did I just ask?",
            chat_history=history,
        )
        assert reply.content is not None
        assert len(reply.content) > 0
@pytest.mark.integration
class TestFactoryFunction:
    """Test the AIRouter factory functionality."""

    def test_create_openai_router(self, api_keys):
        """The factory should build a working OpenAI-backed router."""
        built = AIRouter(
            provider="openai",
            model="gpt-4o",
            base_url=api_keys["openai_base_url"],
        )
        reply = built.call("Say 'factory works'")
        assert "factory" in reply.content.lower()
        assert reply.provider == "openai"

    def test_create_gemini_router(self):
        """The factory should build a working Gemini-backed router."""
        built = AIRouter(provider="gemini")
        reply = built.call("What is 1+1?")
        assert reply.provider == "gemini"
        assert "2" in reply.content

    def test_create_cohere_router(self):
        """The factory should build a working Cohere-backed router."""
        built = AIRouter(provider="cohere")
        reply = built.call("Hello")
        assert reply.provider == "cohere"
        assert len(reply.content) > 0

    def test_invalid_provider(self):
        """An unrecognized provider name should raise ValueError."""
        with pytest.raises(ValueError, match="Unknown provider"):
            AIRouter(provider="invalid")
@pytest.mark.integration
class TestGenerationFeatures:
    """Test various generation features across providers."""

    @pytest.mark.parametrize("provider,router_fixture", [
        ("openai", "openai_router"),
        ("gemini", "gemini_router"),
        ("cohere", "cohere_router"),
    ])
    def test_max_tokens_limit(self, provider, router_fixture, request):
        """max_tokens should cap the output length for every provider."""
        target = request.getfixturevalue(router_fixture)
        reply = target.call(
            "Write a very long story about dragons",
            max_tokens=10,
        )
        assert reply.content is not None
        assert reply.provider == provider
        # Output should be limited (though exact token count varies by tokenizer).
        assert len(reply.content.split()) < 50

    @pytest.mark.parametrize("temperature", [0.0, 0.5, 1.0])
    def test_temperature_variations(self, gemini_router, temperature):
        """Every temperature setting should still yield non-empty output."""
        reply = gemini_router.call(
            "Generate a random word",
            temperature=temperature,
        )
        assert reply.content is not None
        assert len(reply.content) > 0

    def test_cost_tracking(self, openai_router):
        """With tracking enabled, any reported cost is a non-negative float."""
        from router.config import config

        previous = config.track_costs
        config.track_costs = True
        try:
            reply = openai_router.call("Hello")
            # Cost may be unavailable for some models; only validate when present.
            if reply.cost is not None:
                assert reply.cost >= 0
                assert isinstance(reply.cost, float)
        finally:
            # Restore the global flag so other tests see the original setting.
            config.track_costs = previous