template-with-ai/tests/test_embeddings.py
2025-07-01 17:07:02 +05:30

209 lines
No EOL
7.6 KiB
Python

"""Test embeddings with real API calls."""
import pytest
from router.embed import (
create_embedding,
CohereEmbedding,
GeminiEmbedding,
OllamaEmbedding,
EmbeddingResponse
)
# Tests are gated with pytest marks (e.g. `integration`, `ollama`) rather than skip decorators.
@pytest.mark.integration
class TestCohereEmbeddings:
    """Integration checks for the Cohere embedder (sync and async paths).

    Each test receives a ``cohere_embedder`` fixture that is defined outside
    this file (presumably in ``conftest.py`` -- confirm).
    """

    def test_single_text_embedding(self, cohere_embedder):
        """A single string yields one non-empty vector and filled-in metadata."""
        result = cohere_embedder.embed("Hello world")

        assert isinstance(result, EmbeddingResponse)
        vectors = result.embeddings
        assert len(vectors) == 1
        assert len(vectors[0]) > 0
        assert result.dimension > 0
        assert result.provider == "cohere"
        assert result.latency > 0
        assert result.num_inputs == 1

    def test_multiple_texts_embedding(self, cohere_embedder):
        """A list of three strings yields three vectors of equal dimension."""
        batch = ["First text", "Second text", "Third text"]
        result = cohere_embedder.embed(batch)

        assert len(result.embeddings) == 3
        assert result.num_inputs == 3
        # Every vector must match the dimension reported on the response.
        for vector in result.embeddings:
            assert len(vector) == result.dimension

    def test_input_type_parameter(self, cohere_embedder):
        """The ``input_type`` keyword is accepted for query and document text."""
        cases = (
            ("Search query", "search_query"),
            ("Document text", "search_document"),
        )
        for text, kind in cases:
            result = cohere_embedder.embed(text, input_type=kind)
            assert result.embeddings[0] is not None

    @pytest.mark.asyncio
    async def test_async_embedding(self, cohere_embedder):
        """``aembed`` returns an ``EmbeddingResponse`` for a single string."""
        result = await cohere_embedder.aembed("Hello async world")

        assert isinstance(result, EmbeddingResponse)
        assert len(result.embeddings) == 1
        assert result.provider == "cohere"

    @pytest.mark.asyncio
    async def test_async_multiple_embeddings(self, cohere_embedder):
        """``aembed`` returns one vector per string when given a list."""
        batch = ["Async text 1", "Async text 2"]
        result = await cohere_embedder.aembed(batch)

        assert len(result.embeddings) == 2
        assert result.num_inputs == 2
@pytest.mark.integration
class TestGeminiEmbeddings:
    """Integration checks for the Gemini embedder (sync and async paths).

    Each test receives a ``gemini_embedder`` fixture that is defined outside
    this file (presumably in ``conftest.py`` -- confirm).
    """

    def test_single_text_embedding(self, gemini_embedder):
        """A single string yields one non-empty vector tagged as ``gemini``."""
        result = gemini_embedder.embed("Hello from Gemini")

        assert isinstance(result, EmbeddingResponse)
        vectors = result.embeddings
        assert len(vectors) == 1
        assert len(vectors[0]) > 0
        assert result.provider == "gemini"
        assert result.dimension > 0

    def test_multiple_texts_embedding(self, gemini_embedder):
        """Two strings yield two vectors that share the reported dimension."""
        batch = ["Gemini text one", "Gemini text two"]
        result = gemini_embedder.embed(batch)

        assert len(result.embeddings) == 2
        # Every vector must match the dimension reported on the response.
        for vector in result.embeddings:
            assert len(vector) == result.dimension

    def test_task_type_parameter(self, gemini_embedder):
        """The ``task_type`` keyword is accepted and a vector comes back."""
        text = "Test with task type"
        result = gemini_embedder.embed(text, task_type="retrieval_document")

        assert result.embeddings[0] is not None

    @pytest.mark.asyncio
    async def test_async_embedding(self, gemini_embedder):
        """``aembed`` returns an ``EmbeddingResponse`` for a single string."""
        result = await gemini_embedder.aembed("Async Gemini test")

        assert isinstance(result, EmbeddingResponse)
        assert result.provider == "gemini"
        assert len(result.embeddings) == 1
@pytest.mark.integration
class TestOllamaEmbeddings:
    """Integration checks for the local Ollama embedder.

    Each test receives an ``ollama_embedder`` fixture that is defined outside
    this file (presumably in ``conftest.py`` -- confirm). When the embed call
    fails with a message containing "Cannot connect to Ollama", the test is
    skipped rather than failed.
    """

    @staticmethod
    def _embed_or_skip(embedder, texts):
        """Return ``embedder.embed(texts)``, skipping if Ollama is unreachable.

        Args:
            embedder: Object exposing an ``embed`` method.
            texts: A string or list of strings, passed through to ``embed``.

        Returns:
            Whatever ``embedder.embed(texts)`` returns.

        Raises:
            Exception: Any error from ``embed`` whose message does not contain
                "Cannot connect to Ollama" is re-raised unchanged.
        """
        # Only the network call lives inside the ``try`` so that assertion
        # failures in the tests are never routed through the skip logic.
        try:
            return embedder.embed(texts)
        except Exception as exc:
            # The concrete exception type raised on a connection failure is
            # not visible from this file, so match on the message text.
            if "Cannot connect to Ollama" in str(exc):
                pytest.skip("Ollama not running")
            raise

    @pytest.mark.ollama
    def test_single_text_embedding(self, ollama_embedder):
        """Test embedding a single text with Ollama."""
        response = self._embed_or_skip(ollama_embedder, "Test Ollama")

        assert isinstance(response, EmbeddingResponse)
        assert response.provider == "ollama"
        assert response.cost == 0.0  # Local model should be free
        assert len(response.embeddings) == 1
        assert response.dimension > 0

    @pytest.mark.ollama
    def test_multiple_texts_embedding(self, ollama_embedder):
        """Test embedding multiple texts with Ollama."""
        response = self._embed_or_skip(
            ollama_embedder, ["Local text 1", "Local text 2"]
        )

        assert len(response.embeddings) == 2
        assert response.num_inputs == 2
        assert response.metadata.get("local") is True
@pytest.mark.integration
class TestEmbeddingFactory:
    """Checks for ``create_embedding``, the provider-selecting factory."""

    def test_create_cohere_embedder(self):
        """``provider="cohere"`` builds a working ``CohereEmbedding``."""
        cohere = create_embedding(provider="cohere")
        result = cohere.embed("Factory test")

        assert result.provider == "cohere"
        assert isinstance(cohere, CohereEmbedding)

    def test_create_gemini_embedder(self):
        """``provider="gemini"`` honours the requested model name."""
        model_name = "text-embedding-004"
        gemini = create_embedding(provider="gemini", model=model_name)
        result = gemini.embed("Factory Gemini")

        assert result.provider == "gemini"
        assert result.model == model_name
        assert isinstance(gemini, GeminiEmbedding)

    def test_create_ollama_embedder(self):
        """``provider="ollama"`` builds an ``OllamaEmbedding`` with the model set.

        No embed call is made here; only the constructed object is inspected.
        """
        model_name = "mxbai-embed-large:latest"
        ollama = create_embedding(provider="ollama", model=model_name)

        assert ollama.model == model_name
        assert isinstance(ollama, OllamaEmbedding)

    def test_invalid_provider(self):
        """An unrecognised provider name raises ``ValueError``."""
        expected_message = "Unknown embedding provider"
        with pytest.raises(ValueError, match=expected_message):
            create_embedding(provider="invalid")
@pytest.mark.integration
class TestEmbeddingComparison:
    """Cross-provider and input-variety checks for embedders.

    The ``cohere_embedder`` and ``gemini_embedder`` fixtures are defined
    outside this file (presumably in ``conftest.py`` -- confirm).
    """

    def test_embedding_dimensions(self, cohere_embedder, gemini_embedder):
        """Each provider reports a positive dimension matching its vector.

        The two providers' dimensions are not compared with each other; only
        the internal consistency of each response is verified.
        """
        sample = "dimension test"
        responses = [
            cohere_embedder.embed(sample),
            gemini_embedder.embed(sample),
        ]

        for resp in responses:
            assert resp.dimension > 0

        # Both should return valid embeddings
        for resp in responses:
            assert len(resp.embeddings[0]) == resp.dimension

    @pytest.mark.parametrize("text", [
        "Short text",
        "A much longer text that contains multiple sentences and should still work fine.",
        "Text with special characters: @#$%^&*()",
        "文本与中文字符",  # Text with Chinese characters
    ])
    def test_various_text_inputs(self, cohere_embedder, text):
        """Short, long, symbol-heavy and non-ASCII text each embed to one vector."""
        result = cohere_embedder.embed(text)
        vectors = result.embeddings

        assert len(vectors) == 1
        assert len(vectors[0]) > 0