"""Test text generation with real APIs.

Every test class in this module is marked ``integration`` and sends live
requests to the provider backends, so outcomes depend on network access,
valid credentials, and (non-deterministic) model output.

The ``openai_router``, ``gemini_router``, ``cohere_router`` and ``api_keys``
fixtures are not defined in this module -- presumably they come from a
``conftest.py``; confirm there.
"""
import pytest

from router import AIRouter, RouterResponse
from router.gemini import Gemini
from router.openai_compatible import OpenAI
from router.cohere import Cohere


@pytest.mark.integration
class TestOpenAICompatible:
    """Test OpenAI-compatible API functionality."""

    def test_simple_generation(self, openai_router):
        """Test simple text generation."""
        response = openai_router.call("Say hello in 5 words or less")
        assert isinstance(response, RouterResponse)
        assert response.content is not None
        assert response.provider == "openai"
        assert response.latency > 0
        # The prompt asks for <= 5 words; allow slack for model verbosity.
        assert len(response.content.split()) <= 10

    def test_generation_with_parameters(self, openai_router):
        """Test generation with custom parameters."""
        response = openai_router.call(
            "Write a haiku about coding",
            temperature=0.7,
            max_tokens=100,
        )
        assert len(response.content) > 0
        assert response.total_tokens is not None
        assert response.input_tokens is not None
        assert response.output_tokens is not None

    @pytest.mark.asyncio
    async def test_async_generation(self, openai_router):
        """Test async generation."""
        response = await openai_router.acall("Say 'async works' if this is working")
        assert "async" in response.content.lower()
        assert response.provider == "openai"

    def test_streaming(self, openai_router):
        """Test streaming responses: every chunk is a RouterResponse from openai."""
        chunks = []
        for chunk in openai_router.stream("Count from 1 to 3"):
            # Check the type before reading attributes so a wrong chunk type
            # fails this assertion instead of raising AttributeError.
            assert isinstance(chunk, RouterResponse)
            assert chunk.provider == "openai"
            chunks.append(chunk.content)
        assert len(chunks) > 0
        full_response = "".join(chunks)
        assert len(full_response) > 0

    @pytest.mark.asyncio
    async def test_async_streaming(self, openai_router):
        """Test async streaming: mirrors the checks of ``test_streaming``."""
        chunks = []
        async for chunk in openai_router.astream("List 3 colors"):
            assert isinstance(chunk, RouterResponse)
            # Keep parity with the sync streaming test's provider check.
            assert chunk.provider == "openai"
            chunks.append(chunk.content)
        assert len(chunks) > 0
        full_response = "".join(chunks)
        assert len(full_response) > 0


@pytest.mark.integration
class TestGemini:
    """Test Gemini API functionality."""

    def test_simple_generation(self, gemini_router):
        """Test simple text generation."""
        response = gemini_router.call("What is 2+2? Answer in one word.")
        assert isinstance(response, RouterResponse)
        assert response.provider == "gemini"
        assert response.content is not None
        assert len(response.content) > 0

    def test_generation_with_temperature(self, gemini_router):
        """Test generation with temperature parameter."""
        response = gemini_router.call(
            "List 3 benefits of Python programming. Be concise.",
            temperature=0.5,
        )
        assert response.content is not None
        assert response.total_tokens is not None
        # Verify the response contains programming benefits (may not always mention Python by name)
        assert any(
            keyword in response.content.lower()
            for keyword in ["readability", "versatility", "libraries", "development", "programming"]
        )

    @pytest.mark.asyncio
    async def test_async_generation(self, gemini_router):
        """Test async generation."""
        response = await gemini_router.acall("What is 5x5? Answer only the number.")
        assert response.content is not None
        assert "25" in response.content

    def test_streaming(self, gemini_router):
        """Test streaming responses: every chunk is a RouterResponse from gemini."""
        chunks = []
        for chunk in gemini_router.stream("Say 'streaming works'"):
            # Same per-chunk guarantees as the OpenAI streaming test.
            assert isinstance(chunk, RouterResponse)
            assert chunk.provider == "gemini"
            chunks.append(chunk.content)
        assert len(chunks) > 0
        full_response = "".join(chunks)
        assert "streaming" in full_response.lower()


@pytest.mark.integration
class TestCohere:
    """Test Cohere API functionality."""

    def test_simple_generation(self, cohere_router):
        """Test simple text generation."""
        response = cohere_router.call("Complete this: The capital of France is")
        assert isinstance(response, RouterResponse)
        assert response.provider == "cohere"
        assert "Paris" in response.content

    def test_generation_with_preamble(self, cohere_router):
        """Test generation with preamble."""
        response = cohere_router.call(
            "What is machine learning?",
            preamble="You are a helpful AI assistant. Be concise.",
            max_tokens=100,
        )
        assert len(response.content) > 0
        assert response.finish_reason is not None

    @pytest.mark.asyncio
    async def test_async_generation(self, cohere_router):
        """Test async generation."""
        response = await cohere_router.acall("Say hello asynchronously")
        assert response.content is not None
        assert response.provider == "cohere"

    def test_chat_history(self, cohere_router):
        """Test generation with chat history."""
        response = cohere_router.call(
            "What did I just ask?",
            chat_history=[
                {"role": "USER", "message": "What is Python?"},
                {"role": "CHATBOT", "message": "Python is a programming language."},
            ],
        )
        assert response.content is not None
        assert len(response.content) > 0


@pytest.mark.integration
class TestFactoryFunction:
    """Test the AIRouter factory functionality."""

    def test_create_openai_router(self, api_keys):
        """Test creating OpenAI router via factory."""
        router = AIRouter(
            provider="openai",
            model="gpt-4o",
            base_url=api_keys["openai_base_url"],
        )
        response = router.call("Say 'factory works'")
        assert "factory" in response.content.lower()
        assert response.provider == "openai"

    def test_create_gemini_router(self):
        """Test creating Gemini router via factory."""
        # NOTE(review): no credentials passed here; presumably the router
        # reads them from the environment -- confirm.
        router = AIRouter(provider="gemini")
        response = router.call("What is 1+1?")
        assert response.provider == "gemini"
        assert "2" in response.content

    def test_create_cohere_router(self):
        """Test creating Cohere router via factory."""
        router = AIRouter(provider="cohere")
        response = router.call("Hello")
        assert response.provider == "cohere"
        assert len(response.content) > 0

    def test_invalid_provider(self):
        """Test factory with invalid provider."""
        with pytest.raises(ValueError, match="Unknown provider"):
            AIRouter(provider="invalid")


@pytest.mark.integration
class TestGenerationFeatures:
    """Test various generation features across providers."""

    @pytest.mark.parametrize("provider,router_fixture", [
        ("openai", "openai_router"),
        ("gemini", "gemini_router"),
        ("cohere", "cohere_router"),
    ])
    def test_max_tokens_limit(self, provider, router_fixture, request):
        """Test max_tokens parameter across providers."""
        # Resolve the fixture by name so one test body covers every provider.
        router = request.getfixturevalue(router_fixture)
        response = router.call(
            "Write a very long story about dragons",
            max_tokens=10,
        )
        assert response.content is not None
        assert response.provider == provider
        # Output should be limited (though exact token count varies by tokenizer)
        assert len(response.content.split()) < 50

    @pytest.mark.parametrize("temperature", [0.0, 0.5, 1.0])
    def test_temperature_variations(self, gemini_router, temperature):
        """Test different temperature settings."""
        response = gemini_router.call(
            "Generate a random word",
            temperature=temperature,
        )
        assert response.content is not None
        assert len(response.content) > 0

    def test_cost_tracking(self, openai_router):
        """Test cost tracking functionality.

        Temporarily enables ``config.track_costs`` and restores the previous
        value afterwards. When the response carries no cost, nothing is
        asserted about it.
        """
        from router.config import config

        original_track_costs = config.track_costs
        config.track_costs = True
        try:
            response = openai_router.call("Hello")
            if response.cost is not None:
                # Type check first so a non-numeric cost fails this assertion
                # rather than raising TypeError on the comparison below.
                assert isinstance(response.cost, float)
                assert response.cost >= 0
        finally:
            config.track_costs = original_track_costs