"""Embedding router implementation for multiple providers.""" from typing import List, Dict, Any, Optional, Union, Literal from dataclasses import dataclass, field import time from abc import ABC, abstractmethod from .config import config from .exceptions import ( ConfigurationError, map_provider_error ) @dataclass class EmbeddingResponse: """Response from embedding operation.""" embeddings: List[List[float]] # List of embedding vectors model: str provider: str latency: float dimension: int # Dimension of embeddings num_inputs: int # Number of inputs embedded total_tokens: Optional[int] = None cost: Optional[float] = None metadata: Dict[str, Any] = field(default_factory=dict) raw_response: Optional[Any] = None class BaseEmbedding(ABC): """Base class for embedding routers.""" def __init__(self, model: str, api_key: str, **kwargs: Any) -> None: """Initialize embedding router. Args: model: Model identifier api_key: API key **kwargs: Additional configuration """ self.model = model self.api_key = api_key self.config = kwargs @abstractmethod def embed( self, texts: Union[str, List[str]], **kwargs: Any ) -> EmbeddingResponse: """Generate embeddings for texts. Args: texts: Single text or list of texts to embed **kwargs: Additional parameters Returns: EmbeddingResponse """ pass @abstractmethod async def aembed( self, texts: Union[str, List[str]], **kwargs: Any ) -> EmbeddingResponse: """Asynchronously generate embeddings for texts. Args: texts: Single text or list of texts to embed **kwargs: Additional parameters Returns: EmbeddingResponse """ pass class CohereEmbedding(BaseEmbedding): """Router for Cohere embedding models.""" def __init__( self, model: str = "embed-english-v3.0", api_key: Optional[str] = None, **kwargs: Any ) -> None: """Initialize Cohere embedding router. Args: model: Embedding model (embed-english-v3.0, embed-multilingual-v3.0, etc.) api_key: Cohere API key (optional if set in environment) **kwargs: Additional configuration """ # Get API key from config if not provided if not api_key: api_key = config.get_api_key("cohere") if not api_key: raise ConfigurationError( "Cohere API key not found. Set COHERE_API_KEY environment variable.", provider="cohere" ) super().__init__(model, api_key, **kwargs) # Initialize Cohere client try: import cohere # For v5.15.0, use the standard Client self.client = cohere.Client(api_key) self.async_client = cohere.AsyncClient(api_key) except ImportError: raise ConfigurationError( "cohere package not installed. Install with: pip install cohere", provider="cohere" ) def embed( self, texts: Union[str, List[str]], input_type: Optional[Literal["search_document", "search_query", "classification", "clustering"]] = None, truncate: Optional[Literal["NONE", "START", "END"]] = None, **kwargs: Any ) -> EmbeddingResponse: """Generate embeddings for texts. Args: texts: Single text or list of texts to embed input_type: Purpose of embeddings (affects vector space) truncate: How to handle inputs longer than max tokens **kwargs: Additional parameters Returns: EmbeddingResponse """ # Ensure texts is a list if isinstance(texts, str): texts = [texts] params = self._prepare_request(texts, input_type, truncate, **kwargs) try: start_time = time.time() response = self.client.embed(**params) latency = time.time() - start_time return self._parse_response(response, latency, len(texts)) except Exception as e: raise map_provider_error("cohere", e) async def aembed( self, texts: Union[str, List[str]], input_type: Optional[Literal["search_document", "search_query", "classification", "clustering"]] = None, truncate: Optional[Literal["NONE", "START", "END"]] = None, **kwargs: Any ) -> EmbeddingResponse: """Asynchronously generate embeddings for texts. Args: texts: Single text or list of texts to embed input_type: Purpose of embeddings truncate: How to handle long inputs **kwargs: Additional parameters Returns: EmbeddingResponse """ # Ensure texts is a list if isinstance(texts, str): texts = [texts] params = self._prepare_request(texts, input_type, truncate, **kwargs) try: start_time = time.time() response = await self.async_client.embed(**params) latency = time.time() - start_time return self._parse_response(response, latency, len(texts)) except Exception as e: raise map_provider_error("cohere", e) def _prepare_request( self, texts: List[str], input_type: Optional[str], truncate: Optional[str], **kwargs: Any ) -> Dict[str, Any]: """Prepare embed request parameters.""" params = { "model": kwargs.get("model", self.model), "texts": texts, } # For v3 models, input_type is required - default to search_document if input_type: params["input_type"] = input_type elif "v3" in self.model or "3.0" in self.model or "3.5" in self.model: params["input_type"] = "search_document" if truncate: params["truncate"] = truncate return params def _parse_response( self, raw_response: Any, latency: float, num_inputs: int ) -> EmbeddingResponse: """Parse Cohere embed response.""" # Extract embeddings embeddings = [] if hasattr(raw_response, "embeddings"): embeddings = raw_response.embeddings # Get dimension from first embedding dimension = len(embeddings[0]) if embeddings else 0 # Get token count from response metadata total_tokens = None if hasattr(raw_response, "meta") and hasattr(raw_response.meta, "billed_units"): total_tokens = getattr(raw_response.meta.billed_units, "input_tokens", None) # Calculate cost cost = None if config.track_costs and total_tokens: cost = config.calculate_embed_cost(self.model, total_tokens) return EmbeddingResponse( embeddings=embeddings, model=self.model, provider="cohere", latency=latency, dimension=dimension, num_inputs=num_inputs, total_tokens=total_tokens, cost=cost, metadata={ "id": getattr(raw_response, "id", None), "response_type": getattr(raw_response, "response_type", None), }, raw_response=raw_response ) class GeminiEmbedding(BaseEmbedding): """Router for Google Gemini embedding models.""" def __init__( self, model: str = "text-embedding-004", api_key: Optional[str] = None, **kwargs: Any ) -> None: """Initialize Gemini embedding router. Args: model: Embedding model (text-embedding-004, etc.) api_key: Google API key (optional if set in environment) **kwargs: Additional configuration """ # Get API key from config if not provided if not api_key: api_key = config.get_api_key("gemini") if not api_key: raise ConfigurationError( "Gemini API key not found. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.", provider="gemini" ) super().__init__(model, api_key, **kwargs) # Initialize Gemini client using new google.genai library try: from google import genai from google.genai import types self.genai = genai self.types = types self.client = genai.Client(api_key=api_key) except ImportError: raise ConfigurationError( "google-genai package not installed. Install with: pip install google-genai", provider="gemini" ) def embed( self, texts: Union[str, List[str]], task_type: Optional[str] = None, title: Optional[str] = None, **kwargs: Any ) -> EmbeddingResponse: """Generate embeddings for texts. Args: texts: Single text or list of texts to embed task_type: Task type for embeddings title: Optional title for context **kwargs: Additional parameters Returns: EmbeddingResponse """ # Ensure texts is a list if isinstance(texts, str): texts = [texts] try: start_time = time.time() # The Google genai SDK supports batch embedding # Pass all texts at once for better performance params = { "model": kwargs.get("model", self.model), "contents": texts, # Can pass multiple texts } # Add optional config config_params = {} if task_type: config_params["task_type"] = task_type if title: config_params["title"] = title if config_params: params["config"] = self.types.EmbedContentConfig(**config_params) response = self.client.models.embed_content(**params) # Extract embeddings from response # The response always has an 'embeddings' attribute (even for single text) embeddings = [] # Check if embeddings exist if response.embeddings is None: raise ValueError("No embeddings returned in response") # response.embeddings is always present for emb in response.embeddings: # ContentEmbedding objects have a 'values' attribute containing the float list if hasattr(emb, "values") and emb.values is not None: embeddings.append(list(emb.values)) elif hasattr(emb, "__iter__"): # If the embedding is directly iterable (list-like) try: embeddings.append(list(emb)) except Exception as e: print(f"Warning: Could not extract embedding values from {type(emb)}: {e}") else: print(f"Warning: Unknown embedding format: {type(emb)}, attributes: {dir(emb)}") latency = time.time() - start_time # Token counting is not directly available in the response # Set to None for now total_tokens = None return self._create_response( embeddings=embeddings, latency=latency, num_inputs=len(texts), total_tokens=total_tokens ) except Exception as e: raise map_provider_error("gemini", e) async def aembed( self, texts: Union[str, List[str]], task_type: Optional[str] = None, title: Optional[str] = None, **kwargs: Any ) -> EmbeddingResponse: """Asynchronously generate embeddings for texts. Args: texts: Single text or list of texts to embed task_type: Task type for embeddings title: Optional title for context **kwargs: Additional parameters Returns: EmbeddingResponse """ # Ensure texts is a list if isinstance(texts, str): texts = [texts] try: start_time = time.time() # The Google genai SDK supports batch embedding # Pass all texts at once for better performance params = { "model": kwargs.get("model", self.model), "contents": texts, # Can pass multiple texts } # Add optional config config_params = {} if task_type: config_params["task_type"] = task_type if title: config_params["title"] = title if config_params: params["config"] = self.types.EmbedContentConfig(**config_params) response = await self.client.aio.models.embed_content(**params) # Extract embeddings from response # The response always has an 'embeddings' attribute (even for single text) embeddings = [] # Check if embeddings exist if response.embeddings is None: raise ValueError("No embeddings returned in response") # response.embeddings is always present for emb in response.embeddings: # ContentEmbedding objects have a 'values' attribute containing the float list if hasattr(emb, "values") and emb.values is not None: embeddings.append(list(emb.values)) elif hasattr(emb, "__iter__"): # If the embedding is directly iterable (list-like) try: embeddings.append(list(emb)) except Exception as e: print(f"Warning: Could not extract embedding values from {type(emb)}: {e}") else: print(f"Warning: Unknown embedding format: {type(emb)}, attributes: {dir(emb)}") latency = time.time() - start_time # Token counting is not directly available in the response # Set to None for now total_tokens = None return self._create_response( embeddings=embeddings, latency=latency, num_inputs=len(texts), total_tokens=total_tokens ) except Exception as e: raise map_provider_error("gemini", e) def _create_response( self, embeddings: List[List[float]], latency: float, num_inputs: int, total_tokens: Optional[int] = None ) -> EmbeddingResponse: """Create embedding response.""" # Get dimension from first embedding dimension = len(embeddings[0]) if embeddings else 0 # Calculate cost cost = None if config.track_costs and total_tokens: cost = config.calculate_embed_cost(self.model, total_tokens) return EmbeddingResponse( embeddings=embeddings, model=self.model, provider="gemini", latency=latency, dimension=dimension, num_inputs=num_inputs, total_tokens=total_tokens, cost=cost, metadata={}, raw_response=None ) class OllamaEmbedding(BaseEmbedding): """Router for Ollama local embedding models.""" def __init__( self, model: str = "nomic-embed-text:latest", base_url: Optional[str] = None, **kwargs: Any ) -> None: """Initialize Ollama embedding router. Args: model: Ollama embedding model (mxbai-embed-large, nomic-embed-text, etc.) base_url: Ollama API base URL (default: http://localhost:11434) **kwargs: Additional configuration """ # No API key needed for local Ollama super().__init__(model, api_key="local", **kwargs) # Get base URL from config or parameter self.base_url = base_url or config.ollama_base_url self.embeddings_url = f"{self.base_url}/api/embeddings" # Check if Ollama is available try: import requests self.requests = requests except ImportError: raise ConfigurationError( "requests package not installed. Install with: pip install requests", provider="ollama" ) # For async support try: import httpx self.httpx = httpx except ImportError: self.httpx = None # Async support is optional def embed( self, texts: Union[str, List[str]], **kwargs: Any ) -> EmbeddingResponse: """Generate embeddings for texts using Ollama. Args: texts: Single text or list of texts to embed **kwargs: Additional parameters Returns: EmbeddingResponse """ # Ensure texts is a list if isinstance(texts, str): texts = [texts] try: start_time = time.time() embeddings = [] # Ollama API currently handles one text at a time for text in texts: payload = { "model": kwargs.get("model", self.model), "prompt": text } try: response = self.requests.post( self.embeddings_url, json=payload, timeout=kwargs.get("timeout", 30) ) response.raise_for_status() data = response.json() if "embedding" in data: embeddings.append(data["embedding"]) else: raise ValueError(f"No embedding in response: {data}") except self.requests.exceptions.ConnectionError: raise ConfigurationError( f"Cannot connect to Ollama at {self.base_url}. " "Make sure Ollama is running (ollama serve).", provider="ollama" ) except self.requests.exceptions.HTTPError as e: if e.response.status_code == 404: raise ValueError( f"Model '{self.model}' not found. " f"Pull it first with: ollama pull {self.model}" ) raise latency = time.time() - start_time return self._create_response( embeddings=embeddings, latency=latency, num_inputs=len(texts) ) except Exception as e: raise map_provider_error("ollama", e) async def aembed( self, texts: Union[str, List[str]], **kwargs: Any ) -> EmbeddingResponse: """Asynchronously generate embeddings for texts using Ollama. Args: texts: Single text or list of texts to embed **kwargs: Additional parameters Returns: EmbeddingResponse """ if self.httpx is None: raise ConfigurationError( "httpx package not installed for async support. Install with: pip install httpx", provider="ollama" ) # Ensure texts is a list if isinstance(texts, str): texts = [texts] try: start_time = time.time() embeddings = [] async with self.httpx.AsyncClient() as client: # Process texts one at a time (Ollama limitation) for text in texts: payload = { "model": kwargs.get("model", self.model), "prompt": text } try: response = await client.post( self.embeddings_url, json=payload, timeout=kwargs.get("timeout", 30) ) response.raise_for_status() data = response.json() if "embedding" in data: embeddings.append(data["embedding"]) else: raise ValueError(f"No embedding in response: {data}") except self.httpx.ConnectError: raise ConfigurationError( f"Cannot connect to Ollama at {self.base_url}. " "Make sure Ollama is running (ollama serve).", provider="ollama" ) except self.httpx.HTTPStatusError as e: if e.response.status_code == 404: raise ValueError( f"Model '{self.model}' not found. " f"Pull it first with: ollama pull {self.model}" ) raise latency = time.time() - start_time return self._create_response( embeddings=embeddings, latency=latency, num_inputs=len(texts) ) except Exception as e: raise map_provider_error("ollama", e) def _create_response( self, embeddings: List[List[float]], latency: float, num_inputs: int ) -> EmbeddingResponse: """Create embedding response.""" # Get dimension from first embedding dimension = len(embeddings[0]) if embeddings else 0 # No cost for local models cost = 0.0 # Always 0.0 for local models return EmbeddingResponse( embeddings=embeddings, model=self.model, provider="ollama", latency=latency, dimension=dimension, num_inputs=num_inputs, total_tokens=None, # Ollama doesn't provide token counts cost=cost, metadata={ "base_url": self.base_url, "local": True }, raw_response=None ) # Factory function to create embedding instances def create_embedding( provider: str = "cohere", model: Optional[str] = None, api_key: Optional[str] = None, **kwargs: Any ) -> BaseEmbedding: """Create an embedding instance based on provider. Args: provider: Embedding provider (cohere, gemini) model: Model to use (provider-specific) api_key: API key (optional if set in environment) **kwargs: Additional configuration Returns: BaseEmbedding instance """ provider = provider.lower() if provider == "cohere": return CohereEmbedding( model=model or "embed-english-v3.0", api_key=api_key, **kwargs ) elif provider in ["gemini", "google"]: return GeminiEmbedding( model=model or "text-embedding-004", api_key=api_key, **kwargs ) elif provider == "ollama": return OllamaEmbedding( model=model or "nomic-embed-text:latest", base_url=kwargs.pop("base_url", None), # Extract base_url from kwargs **kwargs ) else: raise ValueError(f"Unknown embedding provider: {provider}") # Convenience aliases Embed = create_embedding CohereEmbed = CohereEmbedding GeminiEmbed = GeminiEmbedding OllamaEmbed = OllamaEmbedding