"""Embedding router implementation for multiple providers."""
|
|
|
|
from typing import List, Dict, Any, Optional, Union, Literal
|
|
from dataclasses import dataclass, field
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
|
|
from .config import config
|
|
from .exceptions import (
|
|
ConfigurationError,
|
|
map_provider_error
|
|
)
|
|
|
|
|
|
@dataclass
class EmbeddingResponse:
    """Response from an embedding operation."""

    embeddings: List[List[float]]  # List of embedding vectors
    model: str
    provider: str
    latency: float
    dimension: int  # Dimension of each embedding vector
    num_inputs: int  # Number of inputs embedded
    total_tokens: Optional[int] = None
    cost: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    raw_response: Optional[Any] = None

class BaseEmbedding(ABC):
    """Base class for embedding routers."""

    def __init__(self, model: str, api_key: str, **kwargs: Any) -> None:
        """Initialize the embedding router.

        Args:
            model: Model identifier
            api_key: API key
            **kwargs: Additional configuration
        """
        self.model = model
        self.api_key = api_key
        self.config = kwargs

    @abstractmethod
    def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs: Any
    ) -> EmbeddingResponse:
        """Generate embeddings for texts.

        Args:
            texts: Single text or list of texts to embed
            **kwargs: Additional parameters

        Returns:
            EmbeddingResponse
        """
        pass

    @abstractmethod
    async def aembed(
        self,
        texts: Union[str, List[str]],
        **kwargs: Any
    ) -> EmbeddingResponse:
        """Asynchronously generate embeddings for texts.

        Args:
            texts: Single text or list of texts to embed
            **kwargs: Additional parameters

        Returns:
            EmbeddingResponse
        """
        pass

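# The two abstract methods above are the entire provider contract. A minimal
# sketch of how a new provider would plug in (the "EchoEmbedding" name and its
# toy one-dimensional vectors are purely illustrative, not a real provider):
#
#     class EchoEmbedding(BaseEmbedding):
#         def embed(self, texts, **kwargs):
#             texts = [texts] if isinstance(texts, str) else texts
#             vectors = [[float(len(t))] for t in texts]  # toy 1-d vectors
#             return EmbeddingResponse(
#                 embeddings=vectors, model=self.model, provider="echo",
#                 latency=0.0, dimension=1, num_inputs=len(texts),
#             )
#
#         async def aembed(self, texts, **kwargs):
#             return self.embed(texts, **kwargs)
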
class CohereEmbedding(BaseEmbedding):
    """Router for Cohere embedding models."""

    def __init__(
        self,
        model: str = "embed-english-v3.0",
        api_key: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """Initialize the Cohere embedding router.

        Args:
            model: Embedding model (embed-english-v3.0, embed-multilingual-v3.0, etc.)
            api_key: Cohere API key (optional if set in the environment)
            **kwargs: Additional configuration
        """
        # Get the API key from config if not provided
        if not api_key:
            api_key = config.get_api_key("cohere")

        if not api_key:
            raise ConfigurationError(
                "Cohere API key not found. Set the COHERE_API_KEY environment variable.",
                provider="cohere"
            )

        super().__init__(model, api_key, **kwargs)

        # Initialize the Cohere clients
        try:
            import cohere
            # For v5.15.0, use the standard Client
            self.client = cohere.Client(api_key)
            self.async_client = cohere.AsyncClient(api_key)
        except ImportError:
            raise ConfigurationError(
                "cohere package not installed. Install with: pip install cohere",
                provider="cohere"
            )

    def embed(
        self,
        texts: Union[str, List[str]],
        input_type: Optional[Literal["search_document", "search_query", "classification", "clustering"]] = None,
        truncate: Optional[Literal["NONE", "START", "END"]] = None,
        **kwargs: Any
    ) -> EmbeddingResponse:
        """Generate embeddings for texts.

        Args:
            texts: Single text or list of texts to embed
            input_type: Purpose of the embeddings (affects the vector space)
            truncate: How to handle inputs longer than the model's max tokens
            **kwargs: Additional parameters

        Returns:
            EmbeddingResponse
        """
        # Ensure texts is a list
        if isinstance(texts, str):
            texts = [texts]

        params = self._prepare_request(texts, input_type, truncate, **kwargs)

        try:
            start_time = time.time()
            response = self.client.embed(**params)
            latency = time.time() - start_time

            return self._parse_response(response, latency, len(texts))
        except Exception as e:
            raise map_provider_error("cohere", e)

    async def aembed(
        self,
        texts: Union[str, List[str]],
        input_type: Optional[Literal["search_document", "search_query", "classification", "clustering"]] = None,
        truncate: Optional[Literal["NONE", "START", "END"]] = None,
        **kwargs: Any
    ) -> EmbeddingResponse:
        """Asynchronously generate embeddings for texts.

        Args:
            texts: Single text or list of texts to embed
            input_type: Purpose of the embeddings
            truncate: How to handle long inputs
            **kwargs: Additional parameters

        Returns:
            EmbeddingResponse
        """
        # Ensure texts is a list
        if isinstance(texts, str):
            texts = [texts]

        params = self._prepare_request(texts, input_type, truncate, **kwargs)

        try:
            start_time = time.time()
            response = await self.async_client.embed(**params)
            latency = time.time() - start_time

            return self._parse_response(response, latency, len(texts))
        except Exception as e:
            raise map_provider_error("cohere", e)

    def _prepare_request(
        self,
        texts: List[str],
        input_type: Optional[str],
        truncate: Optional[str],
        **kwargs: Any
    ) -> Dict[str, Any]:
        """Prepare embed request parameters."""
        params: Dict[str, Any] = {
            "model": kwargs.get("model", self.model),
            "texts": texts,
        }

        # For v3 models, input_type is required; default to search_document
        if input_type:
            params["input_type"] = input_type
        elif "v3" in self.model or "3.0" in self.model or "3.5" in self.model:
            params["input_type"] = "search_document"

        if truncate:
            params["truncate"] = truncate

        return params

    def _parse_response(
        self,
        raw_response: Any,
        latency: float,
        num_inputs: int
    ) -> EmbeddingResponse:
        """Parse a Cohere embed response."""
        # Extract embeddings
        embeddings = []
        if hasattr(raw_response, "embeddings"):
            embeddings = raw_response.embeddings

        # Get the dimension from the first embedding
        dimension = len(embeddings[0]) if embeddings else 0

        # Get the token count from the response metadata
        total_tokens = None
        if hasattr(raw_response, "meta") and hasattr(raw_response.meta, "billed_units"):
            total_tokens = getattr(raw_response.meta.billed_units, "input_tokens", None)

        # Calculate cost
        cost = None
        if config.track_costs and total_tokens:
            cost = config.calculate_embed_cost(self.model, total_tokens)

        return EmbeddingResponse(
            embeddings=embeddings,
            model=self.model,
            provider="cohere",
            latency=latency,
            dimension=dimension,
            num_inputs=num_inputs,
            total_tokens=total_tokens,
            cost=cost,
            metadata={
                "id": getattr(raw_response, "id", None),
                "response_type": getattr(raw_response, "response_type", None),
            },
            raw_response=raw_response
        )

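# Usage sketch for the Cohere router (assumes COHERE_API_KEY is set; the texts
# are illustrative). v3 models embed documents and queries into the same space
# but expect the matching input_type on each side:
#
#     embedder = CohereEmbedding(model="embed-english-v3.0")
#     docs = embedder.embed(
#         ["Paris is the capital of France."],
#         input_type="search_document",
#     )
#     query = embedder.embed("capital of France?", input_type="search_query")
#     assert docs.dimension == query.dimension
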
class GeminiEmbedding(BaseEmbedding):
    """Router for Google Gemini embedding models."""

    def __init__(
        self,
        model: str = "text-embedding-004",
        api_key: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """Initialize the Gemini embedding router.

        Args:
            model: Embedding model (text-embedding-004, etc.)
            api_key: Google API key (optional if set in the environment)
            **kwargs: Additional configuration
        """
        # Get the API key from config if not provided
        if not api_key:
            api_key = config.get_api_key("gemini")

        if not api_key:
            raise ConfigurationError(
                "Gemini API key not found. Set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable.",
                provider="gemini"
            )

        super().__init__(model, api_key, **kwargs)

        # Initialize the Gemini client using the google.genai library
        try:
            from google import genai
            from google.genai import types
            self.genai = genai
            self.types = types
            self.client = genai.Client(api_key=api_key)
        except ImportError:
            raise ConfigurationError(
                "google-genai package not installed. Install with: pip install google-genai",
                provider="gemini"
            )

    def embed(
        self,
        texts: Union[str, List[str]],
        task_type: Optional[str] = None,
        title: Optional[str] = None,
        **kwargs: Any
    ) -> EmbeddingResponse:
        """Generate embeddings for texts.

        Args:
            texts: Single text or list of texts to embed
            task_type: Task type for the embeddings
            title: Optional title for context
            **kwargs: Additional parameters

        Returns:
            EmbeddingResponse
        """
        # Ensure texts is a list
        if isinstance(texts, str):
            texts = [texts]

        try:
            start_time = time.time()

            # The google.genai SDK supports batch embedding, so pass all
            # texts in a single request for better performance
            params: Dict[str, Any] = {
                "model": kwargs.get("model", self.model),
                "contents": texts,
            }

            # Add optional config
            config_params = {}
            if task_type:
                config_params["task_type"] = task_type
            if title:
                config_params["title"] = title

            if config_params:
                params["config"] = self.types.EmbedContentConfig(**config_params)

            response = self.client.models.embed_content(**params)

            # The response exposes an 'embeddings' attribute even for a
            # single input, but it may be None on failure
            if response.embeddings is None:
                raise ValueError("No embeddings returned in response")

            embeddings = []
            for emb in response.embeddings:
                # ContentEmbedding objects carry the float list in 'values'
                if hasattr(emb, "values") and emb.values is not None:
                    embeddings.append(list(emb.values))
                elif hasattr(emb, "__iter__"):
                    # Fall back to treating the embedding as directly iterable
                    try:
                        embeddings.append(list(emb))
                    except Exception as e:
                        print(f"Warning: Could not extract embedding values from {type(emb)}: {e}")
                else:
                    print(f"Warning: Unknown embedding format: {type(emb)}, attributes: {dir(emb)}")

            latency = time.time() - start_time

            # Token counts are not exposed in the embed response
            total_tokens = None

            return self._create_response(
                embeddings=embeddings,
                latency=latency,
                num_inputs=len(texts),
                total_tokens=total_tokens
            )
        except Exception as e:
            raise map_provider_error("gemini", e)

    async def aembed(
        self,
        texts: Union[str, List[str]],
        task_type: Optional[str] = None,
        title: Optional[str] = None,
        **kwargs: Any
    ) -> EmbeddingResponse:
        """Asynchronously generate embeddings for texts.

        Args:
            texts: Single text or list of texts to embed
            task_type: Task type for the embeddings
            title: Optional title for context
            **kwargs: Additional parameters

        Returns:
            EmbeddingResponse
        """
        # Ensure texts is a list
        if isinstance(texts, str):
            texts = [texts]

        try:
            start_time = time.time()

            # Same batch request as the sync path
            params: Dict[str, Any] = {
                "model": kwargs.get("model", self.model),
                "contents": texts,
            }

            # Add optional config
            config_params = {}
            if task_type:
                config_params["task_type"] = task_type
            if title:
                config_params["title"] = title

            if config_params:
                params["config"] = self.types.EmbedContentConfig(**config_params)

            response = await self.client.aio.models.embed_content(**params)

            # The response exposes an 'embeddings' attribute even for a
            # single input, but it may be None on failure
            if response.embeddings is None:
                raise ValueError("No embeddings returned in response")

            embeddings = []
            for emb in response.embeddings:
                # ContentEmbedding objects carry the float list in 'values'
                if hasattr(emb, "values") and emb.values is not None:
                    embeddings.append(list(emb.values))
                elif hasattr(emb, "__iter__"):
                    # Fall back to treating the embedding as directly iterable
                    try:
                        embeddings.append(list(emb))
                    except Exception as e:
                        print(f"Warning: Could not extract embedding values from {type(emb)}: {e}")
                else:
                    print(f"Warning: Unknown embedding format: {type(emb)}, attributes: {dir(emb)}")

            latency = time.time() - start_time

            # Token counts are not exposed in the embed response
            total_tokens = None

            return self._create_response(
                embeddings=embeddings,
                latency=latency,
                num_inputs=len(texts),
                total_tokens=total_tokens
            )
        except Exception as e:
            raise map_provider_error("gemini", e)

    def _create_response(
        self,
        embeddings: List[List[float]],
        latency: float,
        num_inputs: int,
        total_tokens: Optional[int] = None
    ) -> EmbeddingResponse:
        """Create an embedding response."""
        # Get the dimension from the first embedding
        dimension = len(embeddings[0]) if embeddings else 0

        # Calculate cost
        cost = None
        if config.track_costs and total_tokens:
            cost = config.calculate_embed_cost(self.model, total_tokens)

        return EmbeddingResponse(
            embeddings=embeddings,
            model=self.model,
            provider="gemini",
            latency=latency,
            dimension=dimension,
            num_inputs=num_inputs,
            total_tokens=total_tokens,
            cost=cost,
            metadata={},
            raw_response=None
        )

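# Usage sketch for the Gemini router (assumes GEMINI_API_KEY is set; the
# task_type shown is one of the values the Gemini API accepts, such as
# RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, SEMANTIC_SIMILARITY, or CLUSTERING):
#
#     embedder = GeminiEmbedding(model="text-embedding-004")
#     result = embedder.embed(
#         ["first passage", "second passage"],
#         task_type="RETRIEVAL_DOCUMENT",
#     )
#     print(result.num_inputs, result.dimension)  # e.g. 2, 768 for this model
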
class OllamaEmbedding(BaseEmbedding):
    """Router for Ollama local embedding models."""

    def __init__(
        self,
        model: str = "nomic-embed-text:latest",
        base_url: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """Initialize the Ollama embedding router.

        Args:
            model: Ollama embedding model (mxbai-embed-large, nomic-embed-text, etc.)
            base_url: Ollama API base URL (default: http://localhost:11434)
            **kwargs: Additional configuration
        """
        # No API key needed for local Ollama
        super().__init__(model, api_key="local", **kwargs)

        # Get the base URL from the parameter or config
        self.base_url = base_url or config.ollama_base_url
        self.embeddings_url = f"{self.base_url}/api/embeddings"

        # requests is required for synchronous calls
        try:
            import requests
            self.requests = requests
        except ImportError:
            raise ConfigurationError(
                "requests package not installed. Install with: pip install requests",
                provider="ollama"
            )

        # httpx is optional and only needed for async support
        try:
            import httpx
            self.httpx = httpx
        except ImportError:
            self.httpx = None

    def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs: Any
    ) -> EmbeddingResponse:
        """Generate embeddings for texts using Ollama.

        Args:
            texts: Single text or list of texts to embed
            **kwargs: Additional parameters

        Returns:
            EmbeddingResponse
        """
        # Ensure texts is a list
        if isinstance(texts, str):
            texts = [texts]

        try:
            start_time = time.time()
            embeddings = []

            # The Ollama embeddings API handles one text per request
            for text in texts:
                payload = {
                    "model": kwargs.get("model", self.model),
                    "prompt": text
                }

                try:
                    response = self.requests.post(
                        self.embeddings_url,
                        json=payload,
                        timeout=kwargs.get("timeout", 30)
                    )
                    response.raise_for_status()

                    data = response.json()
                    if "embedding" in data:
                        embeddings.append(data["embedding"])
                    else:
                        raise ValueError(f"No embedding in response: {data}")

                except self.requests.exceptions.ConnectionError:
                    raise ConfigurationError(
                        f"Cannot connect to Ollama at {self.base_url}. "
                        "Make sure Ollama is running (ollama serve).",
                        provider="ollama"
                    )
                except self.requests.exceptions.HTTPError as e:
                    if e.response.status_code == 404:
                        raise ValueError(
                            f"Model '{self.model}' not found. "
                            f"Pull it first with: ollama pull {self.model}"
                        )
                    raise

            latency = time.time() - start_time

            return self._create_response(
                embeddings=embeddings,
                latency=latency,
                num_inputs=len(texts)
            )

        except Exception as e:
            raise map_provider_error("ollama", e)

    async def aembed(
        self,
        texts: Union[str, List[str]],
        **kwargs: Any
    ) -> EmbeddingResponse:
        """Asynchronously generate embeddings for texts using Ollama.

        Args:
            texts: Single text or list of texts to embed
            **kwargs: Additional parameters

        Returns:
            EmbeddingResponse
        """
        if self.httpx is None:
            raise ConfigurationError(
                "httpx package not installed for async support. Install with: pip install httpx",
                provider="ollama"
            )

        # Ensure texts is a list
        if isinstance(texts, str):
            texts = [texts]

        try:
            start_time = time.time()
            embeddings = []

            async with self.httpx.AsyncClient() as client:
                # Process texts one at a time (Ollama limitation)
                for text in texts:
                    payload = {
                        "model": kwargs.get("model", self.model),
                        "prompt": text
                    }

                    try:
                        response = await client.post(
                            self.embeddings_url,
                            json=payload,
                            timeout=kwargs.get("timeout", 30)
                        )
                        response.raise_for_status()

                        data = response.json()
                        if "embedding" in data:
                            embeddings.append(data["embedding"])
                        else:
                            raise ValueError(f"No embedding in response: {data}")

                    except self.httpx.ConnectError:
                        raise ConfigurationError(
                            f"Cannot connect to Ollama at {self.base_url}. "
                            "Make sure Ollama is running (ollama serve).",
                            provider="ollama"
                        )
                    except self.httpx.HTTPStatusError as e:
                        if e.response.status_code == 404:
                            raise ValueError(
                                f"Model '{self.model}' not found. "
                                f"Pull it first with: ollama pull {self.model}"
                            )
                        raise

            latency = time.time() - start_time

            return self._create_response(
                embeddings=embeddings,
                latency=latency,
                num_inputs=len(texts)
            )

        except Exception as e:
            raise map_provider_error("ollama", e)

    def _create_response(
        self,
        embeddings: List[List[float]],
        latency: float,
        num_inputs: int
    ) -> EmbeddingResponse:
        """Create an embedding response."""
        # Get the dimension from the first embedding
        dimension = len(embeddings[0]) if embeddings else 0

        # Local models incur no cost
        cost = 0.0

        return EmbeddingResponse(
            embeddings=embeddings,
            model=self.model,
            provider="ollama",
            latency=latency,
            dimension=dimension,
            num_inputs=num_inputs,
            total_tokens=None,  # Ollama doesn't provide token counts
            cost=cost,
            metadata={
                "base_url": self.base_url,
                "local": True
            },
            raw_response=None
        )

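# Usage sketch for the Ollama router (assumes a local Ollama server is running
# and the model has been pulled, e.g. `ollama pull nomic-embed-text`):
#
#     embedder = OllamaEmbedding(model="nomic-embed-text:latest")
#     result = embedder.embed(["hello world", "goodbye world"])
#     # Each text costs one round trip, so latency grows with the input count
#     print(result.dimension, result.cost)  # e.g. 768, 0.0
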
# Factory function to create embedding instances
def create_embedding(
    provider: str = "cohere",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    **kwargs: Any
) -> BaseEmbedding:
    """Create an embedding instance based on the provider.

    Args:
        provider: Embedding provider (cohere, gemini, ollama)
        model: Model to use (provider-specific)
        api_key: API key (optional if set in the environment)
        **kwargs: Additional configuration

    Returns:
        BaseEmbedding instance
    """
    provider = provider.lower()

    if provider == "cohere":
        return CohereEmbedding(
            model=model or "embed-english-v3.0",
            api_key=api_key,
            **kwargs
        )
    elif provider in ["gemini", "google"]:
        return GeminiEmbedding(
            model=model or "text-embedding-004",
            api_key=api_key,
            **kwargs
        )
    elif provider == "ollama":
        return OllamaEmbedding(
            model=model or "nomic-embed-text:latest",
            base_url=kwargs.pop("base_url", None),  # Extract base_url from kwargs
            **kwargs
        )
    else:
        raise ValueError(f"Unknown embedding provider: {provider}")


# Convenience aliases
Embed = create_embedding
CohereEmbed = CohereEmbedding
GeminiEmbed = GeminiEmbedding
OllamaEmbed = OllamaEmbedding
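

if __name__ == "__main__":
    # Smoke-test sketch using the factory. The provider/model choices are
    # illustrative; this assumes a local Ollama server so no API key is needed.
    embedder = create_embedding(provider="ollama")
    result = embedder.embed(["hello world", "goodbye world"])
    print(
        f"{result.provider}/{result.model}: {result.num_inputs} inputs, "
        f"dim={result.dimension}, latency={result.latency:.3f}s"
    )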