209 lines
No EOL
7.6 KiB
Python
209 lines
No EOL
7.6 KiB
Python
"""Integration tests for ``router.embed`` that make real provider API calls.

Test classes are marked ``integration``; the Ollama tests are additionally
marked ``ollama`` and skip themselves when the embedder reports that it cannot
connect to Ollama. The ``cohere_embedder``, ``gemini_embedder`` and
``ollama_embedder`` fixtures are not defined in this module (presumably they
live in conftest.py).
"""
|
|
|
|
import pytest
|
|
from router.embed import (
|
|
create_embedding,
|
|
CohereEmbedding,
|
|
GeminiEmbedding,
|
|
OllamaEmbedding,
|
|
EmbeddingResponse
|
|
)
|
|
|
|
|
|
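# Provider tests are selected with pytest markers (``integration``, ``ollama``)
# rather than skip decorators, e.g. ``pytest -m "not integration"`` excludes them.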
|
|
|
|
|
|
@pytest.mark.integration
class TestCohereEmbeddings:
    """Integration tests for the Cohere embedder (real API calls).

    Relies on a ``cohere_embedder`` fixture defined outside this module
    (presumably in conftest.py).
    """

    def test_single_text_embedding(self, cohere_embedder):
        """Embedding one string yields one vector plus populated metadata."""
        response = cohere_embedder.embed("Hello world")

        assert isinstance(response, EmbeddingResponse)
        assert len(response.embeddings) == 1
        assert len(response.embeddings[0]) > 0
        assert response.dimension > 0
        assert response.provider == "cohere"
        assert response.latency > 0
        assert response.num_inputs == 1

    def test_multiple_texts_embedding(self, cohere_embedder):
        """Embedding a list yields one vector per input, all of equal dimension."""
        texts = ["First text", "Second text", "Third text"]
        response = cohere_embedder.embed(texts)

        assert len(response.embeddings) == 3
        assert response.num_inputs == 3
        assert all(len(emb) == response.dimension for emb in response.embeddings)

    def test_input_type_parameter(self, cohere_embedder):
        """Both ``search_query`` and ``search_document`` input types produce a vector."""
        for input_type, text in (
            ("search_query", "Search query"),
            ("search_document", "Document text"),
        ):
            response = cohere_embedder.embed(text, input_type=input_type)

            # Check the vector's shape, not merely that index 0 exists,
            # consistent with the dimension checks in the other tests.
            assert len(response.embeddings) == 1, input_type
            assert response.dimension > 0, input_type
            assert len(response.embeddings[0]) == response.dimension, input_type

    @pytest.mark.asyncio
    async def test_async_embedding(self, cohere_embedder):
        """``aembed`` returns the same response type as the sync path."""
        response = await cohere_embedder.aembed("Hello async world")

        assert isinstance(response, EmbeddingResponse)
        assert len(response.embeddings) == 1
        assert response.provider == "cohere"

    @pytest.mark.asyncio
    async def test_async_multiple_embeddings(self, cohere_embedder):
        """``aembed`` with a list yields one vector per input."""
        texts = ["Async text 1", "Async text 2"]
        response = await cohere_embedder.aembed(texts)

        assert len(response.embeddings) == 2
        assert response.num_inputs == 2
|
|
|
|
|
|
@pytest.mark.integration
class TestGeminiEmbeddings:
    """Integration tests for the Gemini embedder (real API calls).

    Relies on a ``gemini_embedder`` fixture defined outside this module
    (presumably in conftest.py).
    """

    def test_single_text_embedding(self, gemini_embedder):
        """Embedding one string yields one non-empty vector tagged as Gemini."""
        response = gemini_embedder.embed("Hello from Gemini")

        assert isinstance(response, EmbeddingResponse)
        assert len(response.embeddings) == 1
        assert len(response.embeddings[0]) > 0
        assert response.provider == "gemini"
        assert response.dimension > 0

    def test_multiple_texts_embedding(self, gemini_embedder):
        """Embedding a list yields one vector per input, all of equal dimension."""
        texts = ["Gemini text one", "Gemini text two"]
        response = gemini_embedder.embed(texts)

        assert len(response.embeddings) == 2
        assert all(len(emb) == response.dimension for emb in response.embeddings)

    def test_task_type_parameter(self, gemini_embedder):
        """Passing ``task_type`` still produces a single, correctly sized vector."""
        response = gemini_embedder.embed(
            "Test with task type",
            task_type="retrieval_document",
        )

        # Check the vector's shape, not merely that index 0 exists,
        # consistent with the dimension checks in the other tests.
        assert len(response.embeddings) == 1
        assert response.dimension > 0
        assert len(response.embeddings[0]) == response.dimension

    @pytest.mark.asyncio
    async def test_async_embedding(self, gemini_embedder):
        """``aembed`` returns the same response type as the sync path."""
        response = await gemini_embedder.aembed("Async Gemini test")

        assert isinstance(response, EmbeddingResponse)
        assert response.provider == "gemini"
        assert len(response.embeddings) == 1
|
|
|
|
|
|
@pytest.mark.integration
class TestOllamaEmbeddings:
    """Integration tests for the Ollama embedder.

    Each test is skipped when the embedder's error message reports that it
    cannot connect to Ollama. Relies on an ``ollama_embedder`` fixture defined
    outside this module (presumably in conftest.py).
    """

    @staticmethod
    def _embed_or_skip(embedder, texts):
        """Return ``embedder.embed(texts)``, skipping the test if Ollama is unreachable.

        Only the embed call is guarded. The caller's assertions therefore run
        outside the ``try``, and their failures are reported directly instead
        of passing through the connection-error check.
        """
        try:
            return embedder.embed(texts)
        except Exception as e:
            # The embedder's concrete exception type isn't visible here, so
            # match on its message; any other error propagates unchanged.
            if "Cannot connect to Ollama" in str(e):
                pytest.skip("Ollama not running")
            raise

    @pytest.mark.ollama
    def test_single_text_embedding(self, ollama_embedder):
        """Embedding one string locally yields one vector at zero cost."""
        response = self._embed_or_skip(ollama_embedder, "Test Ollama")

        assert isinstance(response, EmbeddingResponse)
        assert response.provider == "ollama"
        assert response.cost == 0.0  # Local model should be free
        assert len(response.embeddings) == 1
        assert response.dimension > 0

    @pytest.mark.ollama
    def test_multiple_texts_embedding(self, ollama_embedder):
        """Embedding a list locally yields one vector per input, flagged as local."""
        response = self._embed_or_skip(ollama_embedder, ["Local text 1", "Local text 2"])

        assert len(response.embeddings) == 2
        assert response.num_inputs == 2
        assert response.metadata.get("local") is True
|
|
|
|
|
|
@pytest.mark.integration
class TestEmbeddingFactory:
    """Exercise ``create_embedding``: provider dispatch, model selection, errors."""

    def test_create_cohere_embedder(self):
        """provider="cohere" builds a CohereEmbedding that can embed text."""
        cohere = create_embedding(provider="cohere")
        result = cohere.embed("Factory test")

        assert result.provider == "cohere"
        assert isinstance(cohere, CohereEmbedding)

    def test_create_gemini_embedder(self):
        """An explicit Gemini model name is carried through to the response."""
        model_name = "text-embedding-004"
        gemini = create_embedding(provider="gemini", model=model_name)
        result = gemini.embed("Factory Gemini")

        assert result.provider == "gemini"
        assert result.model == model_name
        assert isinstance(gemini, GeminiEmbedding)

    def test_create_ollama_embedder(self):
        """provider="ollama" builds an OllamaEmbedding; this test never calls ``embed``."""
        model_name = "mxbai-embed-large:latest"
        ollama = create_embedding(provider="ollama", model=model_name)

        assert ollama.model == model_name
        assert isinstance(ollama, OllamaEmbedding)

    def test_invalid_provider(self):
        """An unrecognised provider name is rejected with ValueError."""
        expected_message = "Unknown embedding provider"
        with pytest.raises(ValueError, match=expected_message):
            create_embedding(provider="invalid")
|
|
|
|
|
|
@pytest.mark.integration
class TestEmbeddingComparison:
    """Cross-provider sanity checks and coverage of varied text inputs."""

    def test_embedding_dimensions(self, cohere_embedder, gemini_embedder):
        """Each provider reports a positive dimension that matches its vector length.

        The two providers' dimensions are deliberately not compared with each
        other. The models behind the fixtures are configured outside this
        module, so nothing here guarantees that they differ.
        """
        test_text = "dimension test"

        cohere_resp = cohere_embedder.embed(test_text)
        gemini_resp = gemini_embedder.embed(test_text)

        assert cohere_resp.dimension > 0
        assert gemini_resp.dimension > 0

        # Each returned vector's length must match the dimension its provider reports.
        assert len(cohere_resp.embeddings[0]) == cohere_resp.dimension
        assert len(gemini_resp.embeddings[0]) == gemini_resp.dimension

    @pytest.mark.parametrize("text", [
        "Short text",
        "A much longer text that contains multiple sentences and should still work fine.",
        "Text with special characters: @#$%^&*()",
        "文本与中文字符",  # Text with Chinese characters
    ])
    def test_various_text_inputs(self, cohere_embedder, text):
        """Cohere embeds short, long, punctuation-heavy and non-ASCII text."""
        response = cohere_embedder.embed(text)

        assert len(response.embeddings) == 1
        assert len(response.embeddings[0]) > 0