# template-with-ai/router/embed.py
"""Embedding router implementation for multiple providers."""
from typing import List, Dict, Any, Optional, Union, Literal
from dataclasses import dataclass, field
import time
import warnings
from abc import ABC, abstractmethod
from .config import config
from .exceptions import (
ConfigurationError,
map_provider_error
)
@dataclass
class EmbeddingResponse:
"""Response from embedding operation."""
embeddings: List[List[float]] # List of embedding vectors
model: str
provider: str
latency: float
dimension: int # Dimension of embeddings
num_inputs: int # Number of inputs embedded
total_tokens: Optional[int] = None
cost: Optional[float] = None
metadata: Dict[str, Any] = field(default_factory=dict)
raw_response: Optional[Any] = None
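# EmbeddingResponse is a plain dataclass, so callers read fields directly.
# Illustrative sketch with made-up values (not from any real provider call):
#
#     resp = EmbeddingResponse(
#         embeddings=[[0.1, 0.2], [0.3, 0.4]],
#         model="embed-english-v3.0", provider="cohere",
#         latency=0.12, dimension=2, num_inputs=2,
#     )
#     assert resp.dimension == len(resp.embeddings[0])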
class BaseEmbedding(ABC):
"""Base class for embedding routers."""
def __init__(self, model: str, api_key: str, **kwargs: Any) -> None:
"""Initialize embedding router.
Args:
model: Model identifier
api_key: API key
**kwargs: Additional configuration
"""
self.model = model
self.api_key = api_key
self.config = kwargs
@abstractmethod
def embed(
self,
texts: Union[str, List[str]],
**kwargs: Any
) -> EmbeddingResponse:
"""Generate embeddings for texts.
Args:
texts: Single text or list of texts to embed
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
pass
@abstractmethod
async def aembed(
self,
texts: Union[str, List[str]],
**kwargs: Any
) -> EmbeddingResponse:
"""Asynchronously generate embeddings for texts.
Args:
texts: Single text or list of texts to embed
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
pass
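# Adding a provider means subclassing BaseEmbedding and implementing both the
# sync and async entry points. A minimal sketch (hypothetical provider, not a
# real integration):
#
#     class MyProviderEmbedding(BaseEmbedding):
#         def embed(self, texts, **kwargs):
#             texts = [texts] if isinstance(texts, str) else texts
#             vectors = ...  # call the provider SDK here
#             return EmbeddingResponse(
#                 embeddings=vectors, model=self.model, provider="myprovider",
#                 latency=0.0, dimension=len(vectors[0]), num_inputs=len(texts),
#             )
#
#         async def aembed(self, texts, **kwargs):
#             ...  # same shape, awaiting the provider's async client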
class CohereEmbedding(BaseEmbedding):
"""Router for Cohere embedding models."""
def __init__(
self,
model: str = "embed-english-v3.0",
api_key: Optional[str] = None,
**kwargs: Any
) -> None:
"""Initialize Cohere embedding router.
Args:
model: Embedding model (embed-english-v3.0, embed-multilingual-v3.0, etc.)
api_key: Cohere API key (optional if set in environment)
**kwargs: Additional configuration
"""
# Get API key from config if not provided
if not api_key:
api_key = config.get_api_key("cohere")
if not api_key:
raise ConfigurationError(
"Cohere API key not found. Set COHERE_API_KEY environment variable.",
provider="cohere"
)
super().__init__(model, api_key, **kwargs)
# Initialize Cohere client
try:
import cohere
            # Cohere SDK v5.x exposes both sync and async clients
self.client = cohere.Client(api_key)
self.async_client = cohere.AsyncClient(api_key)
except ImportError:
raise ConfigurationError(
"cohere package not installed. Install with: pip install cohere",
provider="cohere"
)
def embed(
self,
texts: Union[str, List[str]],
input_type: Optional[Literal["search_document", "search_query", "classification", "clustering"]] = None,
truncate: Optional[Literal["NONE", "START", "END"]] = None,
**kwargs: Any
) -> EmbeddingResponse:
"""Generate embeddings for texts.
Args:
texts: Single text or list of texts to embed
input_type: Purpose of embeddings (affects vector space)
truncate: How to handle inputs longer than max tokens
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
params = self._prepare_request(texts, input_type, truncate, **kwargs)
try:
start_time = time.time()
response = self.client.embed(**params)
latency = time.time() - start_time
return self._parse_response(response, latency, len(texts))
except Exception as e:
raise map_provider_error("cohere", e)
async def aembed(
self,
texts: Union[str, List[str]],
input_type: Optional[Literal["search_document", "search_query", "classification", "clustering"]] = None,
truncate: Optional[Literal["NONE", "START", "END"]] = None,
**kwargs: Any
) -> EmbeddingResponse:
"""Asynchronously generate embeddings for texts.
Args:
texts: Single text or list of texts to embed
input_type: Purpose of embeddings
truncate: How to handle long inputs
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
params = self._prepare_request(texts, input_type, truncate, **kwargs)
try:
start_time = time.time()
response = await self.async_client.embed(**params)
latency = time.time() - start_time
return self._parse_response(response, latency, len(texts))
except Exception as e:
raise map_provider_error("cohere", e)
def _prepare_request(
self,
texts: List[str],
input_type: Optional[str],
truncate: Optional[str],
**kwargs: Any
) -> Dict[str, Any]:
"""Prepare embed request parameters."""
params = {
"model": kwargs.get("model", self.model),
"texts": texts,
}
        # Embed v3 models require input_type; default to search_document
        if input_type:
            params["input_type"] = input_type
        elif any(tag in self.model for tag in ("v3", "3.0", "3.5")):
            params["input_type"] = "search_document"
if truncate:
params["truncate"] = truncate
return params
def _parse_response(
self,
raw_response: Any,
latency: float,
num_inputs: int
) -> EmbeddingResponse:
"""Parse Cohere embed response."""
# Extract embeddings
embeddings = []
if hasattr(raw_response, "embeddings"):
embeddings = raw_response.embeddings
# Get dimension from first embedding
dimension = len(embeddings[0]) if embeddings else 0
# Get token count from response metadata
total_tokens = None
if hasattr(raw_response, "meta") and hasattr(raw_response.meta, "billed_units"):
total_tokens = getattr(raw_response.meta.billed_units, "input_tokens", None)
# Calculate cost
cost = None
if config.track_costs and total_tokens:
cost = config.calculate_embed_cost(self.model, total_tokens)
return EmbeddingResponse(
embeddings=embeddings,
model=self.model,
provider="cohere",
latency=latency,
dimension=dimension,
num_inputs=num_inputs,
total_tokens=total_tokens,
cost=cost,
metadata={
"id": getattr(raw_response, "id", None),
"response_type": getattr(raw_response, "response_type", None),
},
raw_response=raw_response
)
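# Usage sketch for the Cohere router (untested; assumes COHERE_API_KEY is set
# in the environment):
#
#     embedder = CohereEmbedding(model="embed-english-v3.0")
#     resp = embedder.embed(
#         ["What is retrieval?", "Vectors for search."],
#         input_type="search_document",
#     )
#     print(resp.dimension, resp.total_tokens, resp.cost)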
class GeminiEmbedding(BaseEmbedding):
"""Router for Google Gemini embedding models."""
def __init__(
self,
model: str = "text-embedding-004",
api_key: Optional[str] = None,
**kwargs: Any
) -> None:
"""Initialize Gemini embedding router.
Args:
model: Embedding model (text-embedding-004, etc.)
api_key: Google API key (optional if set in environment)
**kwargs: Additional configuration
"""
# Get API key from config if not provided
if not api_key:
api_key = config.get_api_key("gemini")
if not api_key:
raise ConfigurationError(
"Gemini API key not found. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.",
provider="gemini"
)
super().__init__(model, api_key, **kwargs)
# Initialize Gemini client using new google.genai library
try:
from google import genai
from google.genai import types
self.genai = genai
self.types = types
self.client = genai.Client(api_key=api_key)
except ImportError:
raise ConfigurationError(
"google-genai package not installed. Install with: pip install google-genai",
provider="gemini"
)
def embed(
self,
texts: Union[str, List[str]],
task_type: Optional[str] = None,
title: Optional[str] = None,
**kwargs: Any
) -> EmbeddingResponse:
"""Generate embeddings for texts.
Args:
texts: Single text or list of texts to embed
task_type: Task type for embeddings
title: Optional title for context
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
try:
start_time = time.time()
# The Google genai SDK supports batch embedding
# Pass all texts at once for better performance
params = {
"model": kwargs.get("model", self.model),
"contents": texts, # Can pass multiple texts
}
# Add optional config
config_params = {}
if task_type:
config_params["task_type"] = task_type
if title:
config_params["title"] = title
if config_params:
params["config"] = self.types.EmbedContentConfig(**config_params)
response = self.client.models.embed_content(**params)
            # Extract embeddings from the response. The SDK returns an
            # 'embeddings' attribute even for a single input, but it may be
            # None on an empty result, so guard before iterating.
            embeddings = []
            if response.embeddings is None:
                raise ValueError("No embeddings returned in response")
            for emb in response.embeddings:
                # ContentEmbedding objects expose the float list as 'values'
                if hasattr(emb, "values") and emb.values is not None:
                    embeddings.append(list(emb.values))
                elif hasattr(emb, "__iter__"):
                    # Fall back to treating the embedding as a plain iterable
                    try:
                        embeddings.append(list(emb))
                    except Exception as exc:
                        warnings.warn(
                            f"Could not extract embedding values from {type(emb)}: {exc}"
                        )
                else:
                    warnings.warn(f"Unknown embedding format: {type(emb)}")
latency = time.time() - start_time
# Token counting is not directly available in the response
# Set to None for now
total_tokens = None
return self._create_response(
embeddings=embeddings,
latency=latency,
num_inputs=len(texts),
total_tokens=total_tokens
)
except Exception as e:
raise map_provider_error("gemini", e)
async def aembed(
self,
texts: Union[str, List[str]],
task_type: Optional[str] = None,
title: Optional[str] = None,
**kwargs: Any
) -> EmbeddingResponse:
"""Asynchronously generate embeddings for texts.
Args:
texts: Single text or list of texts to embed
task_type: Task type for embeddings
title: Optional title for context
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
try:
start_time = time.time()
# The Google genai SDK supports batch embedding
# Pass all texts at once for better performance
params = {
"model": kwargs.get("model", self.model),
"contents": texts, # Can pass multiple texts
}
# Add optional config
config_params = {}
if task_type:
config_params["task_type"] = task_type
if title:
config_params["title"] = title
if config_params:
params["config"] = self.types.EmbedContentConfig(**config_params)
response = await self.client.aio.models.embed_content(**params)
            # Extract embeddings from the response. The SDK returns an
            # 'embeddings' attribute even for a single input, but it may be
            # None on an empty result, so guard before iterating.
            embeddings = []
            if response.embeddings is None:
                raise ValueError("No embeddings returned in response")
            for emb in response.embeddings:
                # ContentEmbedding objects expose the float list as 'values'
                if hasattr(emb, "values") and emb.values is not None:
                    embeddings.append(list(emb.values))
                elif hasattr(emb, "__iter__"):
                    # Fall back to treating the embedding as a plain iterable
                    try:
                        embeddings.append(list(emb))
                    except Exception as exc:
                        warnings.warn(
                            f"Could not extract embedding values from {type(emb)}: {exc}"
                        )
                else:
                    warnings.warn(f"Unknown embedding format: {type(emb)}")
latency = time.time() - start_time
# Token counting is not directly available in the response
# Set to None for now
total_tokens = None
return self._create_response(
embeddings=embeddings,
latency=latency,
num_inputs=len(texts),
total_tokens=total_tokens
)
except Exception as e:
raise map_provider_error("gemini", e)
def _create_response(
self,
embeddings: List[List[float]],
latency: float,
num_inputs: int,
total_tokens: Optional[int] = None
) -> EmbeddingResponse:
"""Create embedding response."""
# Get dimension from first embedding
dimension = len(embeddings[0]) if embeddings else 0
# Calculate cost
cost = None
if config.track_costs and total_tokens:
cost = config.calculate_embed_cost(self.model, total_tokens)
return EmbeddingResponse(
embeddings=embeddings,
model=self.model,
provider="gemini",
latency=latency,
dimension=dimension,
num_inputs=num_inputs,
total_tokens=total_tokens,
cost=cost,
metadata={},
raw_response=None
)
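# Usage sketch for the Gemini router (untested; assumes GEMINI_API_KEY or
# GOOGLE_API_KEY is set; task_type values such as "RETRIEVAL_QUERY" follow
# the google-genai docs):
#
#     embedder = GeminiEmbedding(model="text-embedding-004")
#     resp = embedder.embed("What is retrieval?", task_type="RETRIEVAL_QUERY")
#     print(resp.dimension, resp.num_inputs)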
class OllamaEmbedding(BaseEmbedding):
"""Router for Ollama local embedding models."""
def __init__(
self,
model: str = "nomic-embed-text:latest",
base_url: Optional[str] = None,
**kwargs: Any
) -> None:
"""Initialize Ollama embedding router.
Args:
model: Ollama embedding model (mxbai-embed-large, nomic-embed-text, etc.)
base_url: Ollama API base URL (default: http://localhost:11434)
**kwargs: Additional configuration
"""
# No API key needed for local Ollama
super().__init__(model, api_key="local", **kwargs)
        # Get base URL from config or parameter, tolerating a trailing slash
        self.base_url = (base_url or config.ollama_base_url).rstrip("/")
        self.embeddings_url = f"{self.base_url}/api/embeddings"
# Check if Ollama is available
try:
import requests
self.requests = requests
except ImportError:
raise ConfigurationError(
"requests package not installed. Install with: pip install requests",
provider="ollama"
)
# For async support
try:
import httpx
self.httpx = httpx
except ImportError:
self.httpx = None # Async support is optional
def embed(
self,
texts: Union[str, List[str]],
**kwargs: Any
) -> EmbeddingResponse:
"""Generate embeddings for texts using Ollama.
Args:
texts: Single text or list of texts to embed
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
try:
start_time = time.time()
embeddings = []
            # The /api/embeddings endpoint accepts one prompt per request
for text in texts:
payload = {
"model": kwargs.get("model", self.model),
"prompt": text
}
try:
response = self.requests.post(
self.embeddings_url,
json=payload,
timeout=kwargs.get("timeout", 30)
)
response.raise_for_status()
data = response.json()
if "embedding" in data:
embeddings.append(data["embedding"])
else:
raise ValueError(f"No embedding in response: {data}")
except self.requests.exceptions.ConnectionError:
raise ConfigurationError(
f"Cannot connect to Ollama at {self.base_url}. "
"Make sure Ollama is running (ollama serve).",
provider="ollama"
)
except self.requests.exceptions.HTTPError as e:
if e.response.status_code == 404:
raise ValueError(
f"Model '{self.model}' not found. "
f"Pull it first with: ollama pull {self.model}"
)
raise
latency = time.time() - start_time
return self._create_response(
embeddings=embeddings,
latency=latency,
num_inputs=len(texts)
)
except Exception as e:
raise map_provider_error("ollama", e)
async def aembed(
self,
texts: Union[str, List[str]],
**kwargs: Any
) -> EmbeddingResponse:
"""Asynchronously generate embeddings for texts using Ollama.
Args:
texts: Single text or list of texts to embed
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
if self.httpx is None:
raise ConfigurationError(
"httpx package not installed for async support. Install with: pip install httpx",
provider="ollama"
)
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
try:
start_time = time.time()
embeddings = []
async with self.httpx.AsyncClient() as client:
# Process texts one at a time (Ollama limitation)
for text in texts:
payload = {
"model": kwargs.get("model", self.model),
"prompt": text
}
try:
response = await client.post(
self.embeddings_url,
json=payload,
timeout=kwargs.get("timeout", 30)
)
response.raise_for_status()
data = response.json()
if "embedding" in data:
embeddings.append(data["embedding"])
else:
raise ValueError(f"No embedding in response: {data}")
except self.httpx.ConnectError:
raise ConfigurationError(
f"Cannot connect to Ollama at {self.base_url}. "
"Make sure Ollama is running (ollama serve).",
provider="ollama"
)
except self.httpx.HTTPStatusError as e:
if e.response.status_code == 404:
raise ValueError(
f"Model '{self.model}' not found. "
f"Pull it first with: ollama pull {self.model}"
)
raise
latency = time.time() - start_time
return self._create_response(
embeddings=embeddings,
latency=latency,
num_inputs=len(texts)
)
except Exception as e:
raise map_provider_error("ollama", e)
def _create_response(
self,
embeddings: List[List[float]],
latency: float,
num_inputs: int
) -> EmbeddingResponse:
"""Create embedding response."""
# Get dimension from first embedding
dimension = len(embeddings[0]) if embeddings else 0
        # Local models incur no cost
        cost = 0.0
return EmbeddingResponse(
embeddings=embeddings,
model=self.model,
provider="ollama",
latency=latency,
dimension=dimension,
num_inputs=num_inputs,
total_tokens=None, # Ollama doesn't provide token counts
cost=cost,
metadata={
"base_url": self.base_url,
"local": True
},
raw_response=None
)
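# Usage sketch for the Ollama router (untested; assumes a local server started
# with `ollama serve` and the model pulled via `ollama pull nomic-embed-text`):
#
#     embedder = OllamaEmbedding(model="nomic-embed-text:latest")
#     resp = embedder.embed(["local embeddings, no API key"])
#     print(resp.dimension, resp.metadata["base_url"])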
# Factory function to create embedding instances
def create_embedding(
provider: str = "cohere",
model: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs: Any
) -> BaseEmbedding:
"""Create an embedding instance based on provider.
Args:
        provider: Embedding provider (cohere, gemini, ollama)
model: Model to use (provider-specific)
api_key: API key (optional if set in environment)
**kwargs: Additional configuration
Returns:
BaseEmbedding instance
"""
provider = provider.lower()
if provider == "cohere":
return CohereEmbedding(
model=model or "embed-english-v3.0",
api_key=api_key,
**kwargs
)
elif provider in ["gemini", "google"]:
return GeminiEmbedding(
model=model or "text-embedding-004",
api_key=api_key,
**kwargs
)
elif provider == "ollama":
return OllamaEmbedding(
model=model or "nomic-embed-text:latest",
base_url=kwargs.pop("base_url", None), # Extract base_url from kwargs
**kwargs
)
else:
raise ValueError(f"Unknown embedding provider: {provider}")
# Convenience aliases
Embed = create_embedding
CohereEmbed = CohereEmbedding
GeminiEmbed = GeminiEmbedding
OllamaEmbed = OllamaEmbedding
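
if __name__ == "__main__":
    # Minimal smoke test, assuming a local Ollama server is running with the
    # default nomic-embed-text model pulled. Swap provider/model to exercise
    # the hosted backends instead.
    embedder = create_embedding(provider="ollama")
    response = embedder.embed(["hello world", "goodbye world"])
    print(
        f"{response.provider}/{response.model}: {response.num_inputs} inputs, "
        f"dim={response.dimension}, latency={response.latency:.3f}s"
    )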