Initial commit

This commit is contained in:
Pratik Narola 2025-07-01 16:22:20 +05:30
commit 9961cb55a6
12 changed files with 3415 additions and 0 deletions

61
.gitignore vendored Normal file
View file

@ -0,0 +1,61 @@
.venv/
.env/
# Python cache files
__pycache__/
*.py[cod]
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
eggs/
.eggs/
lib/
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
dev-debug.log
.pytest_cache/
# Dependency directories
node_modules/
# Environment variables
.env
.venv
# Editor directories and files
.idea
.vscode
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
# OS specific
.DS_Store
*.pyc
# Task files
# tasks.json
# tasks/
__pycache__/
*/__pycache__/
src/__pycache__/
tests/unit/__pycache__/
# Build artifacts
src/tinyrag.egg-info/
# Review files
review.md
chroma_db_poc/
tree-sitter-rescript/

259
docs/research/RAG.md Normal file
View file

@ -0,0 +1,259 @@
# 😎 Awesome Retrieval Augmented Generation (RAG) [![Awesome](https://awesome.re/badge-flat.svg)](https://awesome.re)
This repository contains a curated [Awesome List](https://github.com/sindresorhus/awesome) and general information on Retrieval-Augmented Generation (RAG) applications in Generative AI.
Retrieval-Augmented Generation (RAG) is a technique in Generative AI where additional context is retrieved from external sources to enrich the generative process of Large Language Models (LLMs). This approach allows LLMs to incorporate up-to-date, specific, or sensitive information that they may lack from their pre-training data alone.
## Content
- [ General Information on RAG](#%EF%B8%8F-general-information-on-rag)
- [🎯 Approaches](#-approaches)
- [🧰 Frameworks that Facilitate RAG](#-frameworks-that-facilitate-rag)
- [🛠️ Techniques](#-techniques)
- [📊 Metrics](#-metrics)
- [💾 Databases](#-databases)
## General Information on RAG
In traditional RAG approaches, a basic framework is employed to retrieve documents that enrich the context of an LLM prompt. For instance, when querying about materials for renovating a house, the LLM may possess general knowledge about renovation but lacks specific details about the particular house. Implementing an RAG architecture allows for quick searching and retrieval of relevant documents, such as blueprints, to offer more customized responses. This ensures that the LLM incorporates specific information to the renovation needs, thereby enhancing the accuracy of its responses.
**A typical RAG implementation follows these key steps:**
1. **Divide the knowledge base:** Break the document corpus into smaller, manageable chunks.
2. **Create embeddings:** Apply an embedding model to transform these text chunks into vector embeddings, capturing their semantic meaning.
3. **Store in a vector database:** Save the embeddings in a vector database, enabling fast retrieval based on semantic similarity.
4. **Handle user queries:** Convert the user's query into an embedding using the same model that was applied to the text chunks.
5. **Retrieve relevant data:** Search the vector database for embeddings that closely match the querys embedding based on semantic similarity.
6. **Enhance the prompt:** Incorporate the most relevant text chunks into the LLMs prompt to provide valuable context for generating a response.
7. **Generate a response:** The LLM leverages the augmented prompt to deliver a response that is accurate and tailored to the users query.
## 🎯 Approaches
RAG implementations vary in complexity, from simple document retrieval to advanced techniques integrating iterative feedback loops and domain-specific enhancements. Approaches may include:
- [Cache-Augmented Generation (CAG)](https://medium.com/@ronantech/cache-augmented-generation-cag-in-llms-a-step-by-step-tutorial-6ac35d415eec): Preloads relevant documents into a models context and stores the inference state (Key-Value (KV) cache).
- [Agentic RAG](https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag/): Also known as retrieval agents, can make decisions on retrieval processes.
- [Corrective RAG](https://arxiv.org/pdf/2401.15884.pdf) (CRAG): Methods to correct or refine the retrieved information before integration into LLM responses.
- [Retrieval-Augmented Fine-Tuning](https://techcommunity.microsoft.com/t5/ai-ai-platform-blog/raft-a-new-way-to-teach-llms-to-be-better-at-rag/ba-p/4084674) (RAFT): Techniques to fine-tune LLMs specifically for enhanced retrieval and generation tasks.
- [Self Reflective RAG](https://selfrag.github.io/): Models that dynamically adjust retrieval strategies based on model performance feedback.
- [RAG Fusion](https://arxiv.org/abs/2402.03367): Techniques combining multiple retrieval methods for improved context integration.
- [Temporal Augmented Retrieval](https://adam-rida.medium.com/temporal-augmented-retrieval-tar-dynamic-rag-ad737506dfcc) (TAR): Considering time-sensitive data in retrieval processes.
- [Plan-then-RAG](https://arxiv.org/abs/2406.12430) (PlanRAG): Strategies involving planning stages before executing RAG for complex tasks.
- [GraphRAG](https://github.com/microsoft/graphrag): A structured approach using knowledge graphs for enhanced context integration and reasoning.
- [FLARE](https://medium.com/etoai/better-rag-with-active-retrieval-augmented-generation-flare-3b66646e2a9f) - An approach that incorporates active retrieval-augmented generation to improve response quality.
- [Contextual Retrieval](https://www.anthropic.com/news/contextual-retrieval) - Improves retrieval by adding relevant context to document chunks before retrieval, enhancing the relevance of information retrieved from large knowledge bases.
- [GNN-RAG](https://github.com/cmavro/GNN-RAG): Graph neural retrieval for large language modeling reasoning.
## 🧰 Frameworks that Facilitate RAG
- [Haystack](https://github.com/deepset-ai/haystack): LLM orchestration framework to build customizable, production-ready LLM applications.
- [LangChain](https://python.langchain.com/docs/modules/data_connection/): An all-purpose framework for working with LLMs.
- [Semantic Kernel](https://github.com/microsoft/semantic-kernel): An SDK from Microsoft for developing Generative AI applications.
- [LlamaIndex](https://docs.llamaindex.ai/en/stable/optimizing/production_rag/): Framework for connecting custom data sources to LLMs.
- [Dify](https://github.com/langgenius/dify): An open-source LLM app development platform.
- [Cognita](https://github.com/truefoundry/cognita): Open-source RAG framework for building modular and production ready applications.
- [Verba](https://github.com/weaviate/Verba): Open-source application for RAG out of the box.
- [Mastra](https://github.com/mastra-ai/mastra): Typescript framework for building AI applications.
- [Letta](https://github.com/letta-ai/letta): Open source framework for building stateful LLM applications.
- [Flowise](https://github.com/FlowiseAI/Flowise): Drag & drop UI to build customized LLM flows.
- [Swiftide](https://github.com/bosun-ai/swiftide): Rust framework for building modular, streaming LLM applications.
- [CocoIndex](https://github.com/cocoindex-io/cocoindex): ETL framework to index data for AI, such as RAG; with realtime incremental updates.
## 🛠️ Techniques
### Data cleaning
- [Data cleaning techniques](https://medium.com/intel-tech/four-data-cleaning-techniques-to-improve-large-language-model-llm-performance-77bee9003625): Pre-processing steps to refine input data and improve model performance.
### Prompting
- **Strategies**
- [Tagging and Labeling](https://python.langchain.com/v0.1/docs/use_cases/tagging/): Adding semantic tags or labels to retrieved data to enhance relevance.
- [Chain of Thought (CoT)](https://www.promptingguide.ai/techniques/cot): Encouraging the model to think through problems step by step before providing an answer.
- [Chain of Verification (CoVe)](https://sourajit16-02-93.medium.com/chain-of-verification-cove-understanding-implementation-e7338c7f4cb5): Prompting the model to verify each step of its reasoning for accuracy.
- [Self-Consistency](https://www.promptingguide.ai/techniques/consistency): Generating multiple reasoning paths and selecting the most consistent answer.
- [Zero-Shot Prompting](https://www.promptingguide.ai/techniques/zeroshot): Designing prompts that guide the model without any examples.
- [Few-Shot Prompting](https://python.langchain.com/docs/how_to/few_shot_examples/): Providing a few examples in the prompt to demonstrate the desired response format.
- [Reason & Act (ReAct) prompting](https://www.promptingguide.ai/techniques/react): Combines reasoning (e.g. CoT) with acting (e.g. tool calling).
- **Caching**
- [Prompt Caching](https://medium.com/@1kg/prompt-cache-what-is-prompt-caching-a-comprehensive-guide-e6cbae48e6a3): Optimizes LLMs by storing and reusing precomputed attention states.
### Chunking
- **[Fixed-size chunking](https://medium.com/@anuragmishra_27746/five-levels-of-chunking-strategies-in-rag-notes-from-gregs-video-7b735895694d)**
- Dividing text into consistent-sized segments for efficient processing.
- Splits texts into chunks based on size and overlap.
- Example: [Split by character](https://python.langchain.com/v0.1/docs/modules/data_connection/document_transformers/character_text_splitter/) (LangChain).
- Example: [SentenceSplitter](https://docs.llamaindex.ai/en/stable/api_reference/node_parsers/sentence_splitter/) (LlamaIndex).
- **[Recursive chunking](https://medium.com/@AbhiramiVS/chunking-methods-all-to-know-about-it-65c10aa7b24e)**
- Hierarchical segmentation using recursive algorithms for complex document structures.
- Example: [Recursively split by character](https://python.langchain.com/v0.1/docs/modules/data_connection/document_transformers/recursive_text_splitter/) (LangChain).
- **[Document-based chunking](https://medium.com/@david.richards.tech/document-chunking-for-rag-ai-applications-04363d48fbf7)**
- Segmenting documents based on metadata or formatting cues for targeted analysis.
- Example: [MarkdownHeaderTextSplitter](https://python.langchain.com/v0.1/docs/modules/data_connection/document_transformers/markdown_header_metadata/) (LangChain).
- Example: Handle image and text embeddings with models like [OpenCLIP](https://github.com/mlfoundations/open_clip).
- **[Semantic chunking](https://www.youtube.com/watch?v=8OJC21T2SL4&t=1933s)**
- Extracting meaningful sections based on semantic relevance rather than arbitrary boundaries.
- **[Agentic chunking](https://youtu.be/8OJC21T2SL4?si=8VnYaGUaBmtZhCsg&t=2882)**
- Interactive chunking methods where LLMs guide segmentation.
### Embeddings
- **Select embedding model**
- **[MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard)**: Explore [Hugging Face's](https://github.com/huggingface) benchmark for evaluating model embeddings.
- **Custom Embeddings**: Develop tailored embeddings for specific domains or tasks to enhance model performance. Custom embeddings can capture domain-specific terminology and nuances. Techniques include fine-tuning pre-trained models on your own dataset or training embeddings from scratch using frameworks like TensorFlow or PyTorch.
### Retrieval
- **Search Methods**
- [Vector Store Flat Index](https://weaviate.io/developers/academy/py/vector_index/flat)
- Simple and efficient form of retrieval.
- Content is vectorized and stored as flat content vectors.
- [Hierarchical Index Retrieval](https://pixion.co/blog/rag-strategies-hierarchical-index-retrieval)
- Hierarchically narrow data to different levels.
- Executes retrievals by hierarchical order.
- [Hypothetical Questions](https://pixion.co/blog/rag-strategies-hypothetical-questions-hyde)
- Used to increase similarity between database chunks and queries (same with HyDE).
- LLM is used to generate specific questions for each text chunk.
- Converts these questions into vector embeddings.
- During search, matches queries against this index of question vectors.
- [Hypothetical Document Embeddings (HyDE)](https://pixion.co/blog/rag-strategies-hypothetical-questions-hyde)
- Used to increase similarity between database chunks and queries (same with Hypothetical Questions).
- LLM is used to generate a hypothetical response based on the query.
- Converts this response into a vector embedding.
- Compares the query vector with the hypothetical response vector.
- [Small to Big Retrieval](https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/retrieval-augmented-generation/small_to_big_rag/small_to_big_rag.ipynb)
- Improves retrieval by using smaller chunks for search and larger chunks for context.
- Smaller child chunks refers to bigger parent chunks
- **[Re-ranking](https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/)**: Enhances search results in RAG pipelines by reordering initially retrieved documents, prioritizing those most semantically relevant to the query.
### Response quality & safety
- **[Hallucination](https://machinelearningmastery.com/rag-hallucination-detection-techniques/):** When an AI model generates incorrect or fabricated information, which can be mitigated through grounding, refined retrieval, and verification techniques.
- **[Guardrails](https://developer.ibm.com/tutorials/awb-how-to-implement-llm-guardrails-for-rag-applications/):** Mechanisms to ensure accurate, ethical, and safe responses by applying content moderation, bias mitigation, and fact-checking.
- **[Prompt Injection Prevention](https://hiddenlayer.com/innovation-hub/prompt-injection-attacks-on-llms/):**
- **Input Validation:** Rigorously validate and sanitize all external inputs to ensure that only intended data is incorporated into the prompt.
- **Content Separation:** Clearly distinguish between trusted, static instructions and dynamic user data using templating or placeholders.
- **Output Monitoring:** Continuously monitor responses and logs for any anomalies that could indicate prompt manipulation, and adjust guardrails accordingly.
## 📊 Metrics
### Search metrics
These metrics are used to measure the similarity between embeddings, which is crucial for evaluating how effectively RAG systems retrieve and integrate external documents or data sources. By selecting appropriate similarity metrics, you can optimize the performance and accuracy of your RAG system. Alternatively, you may develop custom metrics tailored to your specific domain or niche to capture domain-specific nuances and improve relevance.
- **[Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity)**
- Measures the cosine of the angle between two vectors in a multi-dimensional space.
- Highly effective for comparing text embeddings where the direction of the vectors represents semantic information.
- Commonly used in RAG systems to measure semantic similarity between query embeddings and document embeddings.
- **[Dot Product](https://en.wikipedia.org/wiki/Dot_product)**
- Calculates the sum of the products of corresponding entries of two sequences of numbers.
- Equivalent to cosine similarity when vectors are normalized.
- Simple and efficient, often used with hardware acceleration for large-scale computations.
- **[Euclidean Distance](https://en.wikipedia.org/wiki/Euclidean_distance)**
- Computes the straight-line distance between two points in Euclidean space.
- Can be used with embeddings but may lose effectiveness in high-dimensional spaces due to the "[curse of dimensionality](https://stats.stackexchange.com/questions/99171/why-is-euclidean-distance-not-a-good-metric-in-high-dimensions)."
- Often used in clustering algorithms like K-means after dimensionality reduction.
- **[Jaccard Similarity](https://en.wikipedia.org/wiki/Jaccard_index)**
- Measures the similarity between two finite sets as the size of the intersection divided by the size of the union of the sets.
- Useful when comparing sets of tokens, such as in bag-of-words models or n-gram comparisons.
- Less applicable to continuous embeddings produced by LLMs.
> **Note:** Cosine Similarity and Dot Product are generally seen as the most effective metrics for measuring similarity between high-dimensional embeddings.
### Response Evaluation Metrics
Response evaluation in RAG solutions involves assessing the quality of language model outputs using diverse metrics. Here are structured approaches to evaluating these responses:
- **Automated Benchmarking**
- **[BLEU](https://en.wikipedia.org/wiki/BLEU):** Evaluates the overlap of n-grams between machine-generated and reference outputs, providing insight into precision.
- **[ROUGE](<https://en.wikipedia.org/wiki/ROUGE_(metric)>):** Measures recall by comparing n-grams, skip-bigrams, or longest common subsequence with reference outputs.
- **[METEOR](https://en.wikipedia.org/wiki/METEOR):** Focuses on exact matches, stemming, synonyms, and alignment for machine translation.
- **Human Evaluation**
Involves human judges assessing responses for:
- **Relevance:** Alignment with user queries.
- **Fluency:** Grammatical and stylistic quality.
- **Factual Accuracy:** Verifying claims against authoritative sources.
- **Coherence:** Logical consistency within responses.
- **Model Evaluation**
Leverages pre-trained evaluators to benchmark outputs against diverse criteria:
- **[TuringBench](https://turingbench.ist.psu.edu/):** Offers comprehensive evaluations across language benchmarks.
- **[Hugging Face Evaluate](https://huggingface.co/docs/evaluate/en/index):** Calculates alignment with human preferences.
- **Key Dimensions for Evaluation**
- **Groundedness:** Assesses if responses are based entirely on provided context. Low groundedness may indicate reliance on hallucinated or irrelevant information.
- **Completeness:** Measures if the response answers all aspects of a query.
- **Approaches:** AI-assisted retrieval scoring and prompt-based intent verification.
- **Utilization:** Evaluates the extent to which retrieved data contributes to the response.
- **Analysis:** Use LLMs to check the inclusion of retrieved chunks in responses.
#### Tools
These tools can assist in evaluating the performance of your RAG system, from tracking user feedback to logging query interactions and comparing multiple evaluation metrics over time.
- **[LangFuse](https://github.com/langfuse/langfuse)**: Open-source tool for tracking LLM metrics, observability, and prompt management.
- **[Ragas](https://docs.ragas.io/en/stable/)**: Framework that helps evaluate RAG pipelines.
- **[LangSmith](https://docs.smith.langchain.com/)**: A platform for building production-grade LLM applications, allows you to closely monitor and evaluate your application.
- **[Hugging Face Evaluate](https://github.com/huggingface/evaluate)**: Tool for computing metrics like BLEU and ROUGE to assess text quality.
- **[Weights & Biases](https://wandb.ai/wandb-japan/rag-hands-on/reports/Step-for-developing-and-evaluating-RAG-application-with-W-B--Vmlldzo1NzU4OTAx)**: Tracks experiments, logs metrics, and visualizes performance.
## 💾 Databases
The list below features several database systems suitable for Retrieval Augmented Generation (RAG) applications. They cover a range of RAG use cases, aiding in the efficient storage and retrieval of vectors to generate responses or recommendations.
### Benchmarks
- [Picking a vector database](https://benchmark.vectorview.ai/vectordbs.html)
### Distributed Data Processing and Serving Engines:
- [Apache Cassandra](https://cassandra.apache.org/doc/latest/cassandra/vector-search/concepts.html): Distributed NoSQL database management system.
- [MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search): Globally distributed, multi-model database service with integrated vector search.
- [Vespa](https://vespa.ai/): Open-source big data processing and serving engine designed for real-time applications.
### Search Engines with Vector Capabilities:
- [Elasticsearch](https://www.elastic.co/elasticsearch): Provides vector search capabilities along with traditional search functionalities.
- [OpenSearch](https://github.com/opensearch-project/OpenSearch): Distributed search and analytics engine, forked from Elasticsearch.
### Vector Databases:
- [Chroma DB](https://github.com/chroma-core/chroma): An AI-native open-source embedding database.
- [Milvus](https://github.com/milvus-io/milvus): An open-source vector database for AI-powered applications.
- [Pinecone](https://www.pinecone.io/): A serverless vector database, optimized for machine learning workflows.
- [Oracle AI Vector Search](https://www.oracle.com/database/ai-vector-search/#retrieval-augmented-generation): Integrates vector search capabilities within Oracle Database for semantic querying based on vector embeddings.
### Relational Database Extensions:
- [Pgvector](https://github.com/pgvector/pgvector): An open-source extension for vector similarity search in PostgreSQL.
### Other Database Systems:
- [Azure Cosmos DB](https://learn.microsoft.com/en-us/azure/cosmos-db/vector-database): Globally distributed, multi-model database service with integrated vector search.
- [Couchbase](https://www.couchbase.com/products/vector-search/): A distributed NoSQL cloud database.
- [Lantern](https://lantern.dev/): A privacy-aware personal search engine.
- [LlamaIndex](https://docs.llamaindex.ai/en/stable/module_guides/storing/vector_stores/): Employs a straightforward in-memory vector store for rapid experimentation.
- [Neo4j](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/vector-indexes/): Graph database management system.
- [Qdrant](https://github.com/neo4j/neo4j): An open-source vector database designed for similarity search.
- [Redis Stack](https://redis.io/docs/latest/develop/interact/search-and-query/): An in-memory data structure store used as a database, cache, and message broker.
- [SurrealDB](https://github.com/surrealdb/surrealdb): A scalable multi-model database optimized for time-series data.
- [Weaviate](https://github.com/weaviate/weaviate): A open-source cloud-native vector search engine.
### Vector Search Libraries and Tools:
- [FAISS](https://github.com/facebookresearch/faiss): A library for efficient similarity search and clustering of dense vectors, designed to handle large-scale datasets and optimized for fast retrieval of nearest neighbors.
---
This list continues to evolve. Contributions are welcome to make this resource more comprehensive 🙌

View file

@ -0,0 +1,645 @@
# RAG Research for Personal Knowledge Base
## Research Summary (Phase 1: Initial Repository Analysis)
This phase focuses on analyzing the information retrieved from the `https://github.com/NirDiamant/RAG_Techniques` repository.
### Search Strategy
- **Initial Query (Implicit):** User provided `RAG.md` which has general RAG info.
- **Targeted Retrieval:** Focused on `https://github.com/NirDiamant/RAG_Techniques` as a comprehensive source of advanced RAG techniques.
- **Focus Areas:** Identifying RAG techniques, frameworks, evaluation methods, and architectural patterns relevant to building a personal, evolving knowledge base for an AI agent.
### Key Resources
1. **`RAG.md` (User-provided):** General overview of RAG, basic steps, common frameworks, databases, and evaluation metrics.
2. **`https://github.com/NirDiamant/RAG_Techniques` (Retrieved via Puppeteer):** An extensive collection of advanced RAG techniques, tutorials, and links to further resources.
### Critical Findings from `NirDiamant/RAG_Techniques` (Relevant to Personal KB)
The repository lists numerous techniques. The following are initially highlighted for their potential relevance to a personal knowledge base that needs to be accurate, adaptable, and efficiently queried by an AI agent.
**I. Foundational Techniques:**
* **Basic RAG:** Essential starting point.
* **Optimizing Chunk Sizes:** Crucial for balancing context and retrieval efficiency. A personal KB might have diverse document lengths.
* **Proposition Chunking:** Breaking text into meaningful sentences could be highly beneficial for factual recall from personal notes or scraped articles.
**II. Query Enhancement:**
* **Query Transformations (Rewriting, Step-back, Sub-query):** Useful if the AI agent needs to formulate complex queries or if initial queries are too narrow.
* **HyDE (Hypothetical Document Embedding):** Could improve retrieval for nuanced personal queries by generating a hypothetical answer first.
* **HyPE (Hypothetical Prompt Embedding):** Precomputing hypothetical questions at indexing could speed up retrieval and improve alignment for a personal KB where query patterns might emerge.
**III. Context Enrichment:**
* **Contextual Chunk Headers:** Adding document/section context to chunks can improve retrieval accuracy, especially for diverse personal documents.
* **Relevant Segment Extraction:** Dynamically constructing multi-chunk segments could provide more complete context to the LLM from personal notes or longer articles.
* **Semantic Chunking:** More meaningful than fixed-size, good for understanding topics within personal data.
* **Contextual Compression:** Useful for fitting more relevant information into the LLM's context window, especially if personal documents are verbose.
* **Document Augmentation (Question Generation):** Could enhance retrieval for a personal KB by pre-generating potential questions the user/agent might ask.
**IV. Advanced Retrieval Methods:**
* **Fusion Retrieval (Keyword + Vector):** Could be powerful for a personal KB, allowing both semantic and exact-match searches (e.g., finding a specific term in notes). This is distinct from RAG-Fusion which focuses on multiple generated queries.
* **Reranking (LLM-based, Cross-Encoder):** Important for ensuring the most relevant personal information is prioritized.
* **Multi-faceted Filtering (Metadata, Similarity, Content, Diversity):** Essential for a personal KB to filter by date, source, type of note, or to avoid redundant information.
* **Hierarchical Indices:** Summaries + detailed chunks could be useful for navigating a large personal KB.
**V. Iterative and Adaptive Techniques:**
* **Retrieval with Feedback Loops:** If the agent can provide feedback, this could continuously improve the personal RAG's performance.
* **Adaptive Retrieval:** Tailoring strategies based on query types (e.g., "summarize my notes on X" vs. "find specific detail Y").
* **Iterative Retrieval:** Useful for complex questions requiring information from multiple personal documents.
**VI. Evaluation:**
* **DeepEval, GroUSE:** While potentially complex for a purely personal setup, understanding these metrics can guide manual evaluation or simpler automated checks. Key aspects for personal KB: Faithfulness (is it true to my notes?), Contextual Relevancy.
**VII. Advanced Architectures:**
* **Graph RAG (General & Microsoft's):** Potentially very powerful for connecting disparate pieces of information in a personal KB (e.g., linking project notes, contacts, and related articles). Might be complex to set up initially.
* **RAPTOR (Recursive Abstractive Processing):** Tree-organized retrieval could be excellent for hierarchical personal notes or structured scraped content.
* **Self-RAG:** Dynamically deciding to retrieve or not could make the agent more efficient with its personal KB.
* **Corrective RAG (CRAG):** Evaluating and correcting retrieval, possibly using web search if personal KB is insufficient, aligns well with an AI agent's needs.
* **Integrating Structured Personal Data (Conceptual):** While the primary focus of RAG is often unstructured text, a truly unified personal knowledge base (PUKB) should consider how to incorporate structured personal data like calendar events, to-do lists, and contacts. Conceptual approaches include:
1. **Conversion to Natural Language for Vector Search:**
* **Approach:** Transform structured items into descriptive sentences or paragraphs. For example, a calendar event "Project Alpha Sync, 2025-05-15 10:00 AM" could become "A meeting for Project Alpha Sync is scheduled on May 15, 2025, at 10:00 AM." This text is then embedded and indexed alongside other unstructured data.
* **Pros:** Allows for a single, unified search interface using the existing vector RAG pipeline; enables semantic querying across all data types.
* **Cons:** Potential loss of precise structured query capabilities (e.g., exact date range filtering might be less reliable than on structured fields); overhead of accurately converting structured data to natural language; may require careful template design or LLM prompting for conversion.
2. **Hybrid Search (Vector DB + Dedicated Structured Store):**
* **Approach:** Maintain structured personal data in a simple, local structured datastore (e.g., an SQLite database, or by directly parsing standard PIM files like iCalendar `.ics` or vCard `.vcf`). The AI agent, based on query analysis (e.g., detecting keywords like "calendar," "task," "contact," or specific date/time phrases), would decide to query the vector DB, the structured store, or both. Results would then be synthesized.
* **Pros:** Preserves the full fidelity and precise queryability of structured data; potentially more efficient for purely structured queries (e.g., "What are my appointments next Monday?").
* **Cons:** Increases agent complexity (query routing, results fusion); requires maintaining and querying two distinct types of data stores; may necessitate Text-to-SQL or Text-to-API-call capabilities for the agent to interact with the structured store based on natural language.
3. **Knowledge Graph Augmentation (Advanced):**
* **Approach:** Model structured entities (people, events, tasks, projects) and their interrelationships within a local knowledge graph. Link these graph nodes to relevant unstructured text chunks in the vector DB.
* **Pros:** Enables complex, multi-hop queries that traverse relationships (e.g., "Show me notes related to people I met with last week concerning Project X"). Provides a richer, interconnected view of personal knowledge.
* **Cons:** Significantly more complex to design, implement, and maintain. Steeper learning curve. Best considered as a future enhancement.
* **Initial PUKB Strategy:** For initial development, converting structured data to natural language for vector search offers the simplest path to unification. A hybrid approach can be a powerful V2 enhancement for more precise structured queries.
**VIII. Special Technique:**
* **Sophisticated Controllable Agent:** The ultimate goal, using a graph as a "brain" for complex RAG tasks on personal data.
### Technical Considerations for Personal KB
* **Data Ingestion:** How easily can new notes, scraped web content (`crawl4ai`), PDFs, etc., be added and indexed?
* **Update Mechanisms, Frequency, and Synchronization for Local Files:** Personal knowledge, especially from local files, evolves. The PUKB must efficiently detect changes and update its vector index to maintain relevance and accuracy. Strategies include:
* **Change Detection Methods for Local Files:**
1. **File System Watchers (Real-time/Near Real-time):**
* **Tool:** Python's `watchdog` library is a common cross-platform solution.
* **Mechanism:** Monitors specified directories for events like file creation, deletion, modification, and moves.
* **Pros:** Provides immediate or near-immediate updates to the PUKB.
* **Cons:** Can be resource-intensive if watching a vast number of files/folders; setup might have platform-specific nuances; requires a persistent process.
2. **Checksum/Hash Comparison (Periodic Scan):**
* **Tool:** Python's `hashlib` (e.g., for SHA256).
* **Mechanism:** During initial ingestion, store a hash of each local file's content in its metadata (e.g., `file_hash`). Periodically (e.g., on agent startup, scheduled background task), scan monitored directories. For each file, recalculate its hash and compare it to the stored hash. If different, or if the file is new, trigger re-ingestion. If a file path from the DB is no longer found, its chunks can be marked as stale or deleted.
* **Pros:** Robust against content changes; platform-independent logic.
* **Cons:** Not real-time; scans can be time-consuming for very large knowledge bases if not optimized (e.g., only hashing files whose `last_modified_date_os` has changed since last scan).
3. **Last Modified Timestamp (Periodic Scan):**
* **Tool:** Python's `os.path.getmtime()`.
* **Mechanism:** Store the `last_modified_date_os` in metadata. Periodically scan and compare the current timestamp with the stored one. If different, trigger re-ingestion.
* **Pros:** Simpler and faster to check than full hashing for an initial pass.
* **Cons:** Less robust than hashing, as timestamps can sometimes be misleading or not change even if content does (rare, but possible with some tools/operations).
4. **Manual Trigger:**
* **Mechanism:** The user explicitly commands the AI agent or PUKB system to re-index a specific file or folder.
* **Pros:** Simple, user-controlled, good for ad-hoc updates.
* **Cons:** Relies on the user to remember and initiate; not suitable for automatic synchronization.
* **Recommended Hybrid Approach:** Combine a periodic scan (e.g., daily or on startup) using last modified timestamps as a quick check, followed by hash comparison for files whose timestamps have changed, with an optional file system watcher for designated 'hot' or frequently updated directories if near real-time sync is critical for those.
* **Efficient Vector DB Update Strategy (Document-Level Re-indexing):**
* When a change in a local file is detected (or a file is deleted):
1. Identify the `parent_document_id` associated with the changed/deleted file (e.g., using `original_path` from metadata).
2. Delete all existing chunks from the vector database that share this `parent_document_id`.
3. If the file was modified (not deleted), re-ingest it: perform full text extraction, preprocessing, chunking, and embedding for the updated content.
4. Add the new chunks and their updated metadata (including new `file_hash` and `last_modified_date_os`) to the vector database.
* This document-level re-indexing is far more efficient than a full re-index of the entire knowledge base, especially for local file changes.
* **Frequency of Updates:**
* For watchers: Near real-time.
* For periodic scans: Configurable (e.g., on application start, hourly, daily). Balance between freshness and resource usage.
* Web-scraped content (`crawl4ai`): Updates depend on re-crawling schedules, which is a separate consideration from local file sync.
* **Scalability:** While "personal," the KB could grow significantly. The chosen techniques and database should handle this.
* **Privacy & Security in a Unified PUKB:** Ensuring the privacy and security of potentially sensitive personal or mixed (work/personal) data is paramount. A multi-layered approach is recommended:
* **Data Segregation within the Vector DB:**
* **ChromaDB:** Supports a hierarchy of Tenants -> Databases -> Collections. For PUKB, different personas or data categories (e.g., 'personal_notes', 'work_project_X', 'sensitive_health_info') can be stored in separate **Collections** or even separate **Databases** under a single tenant. Application logic is crucial for ensuring queries are directed to the correct collection(s) based on user context or active persona.
* **SQLite-VSS:** Data segregation relies on standard SQLite practices. This typically involves adding a `persona_id` column to vector tables and ensuring all application queries filter by this ID. Alternatively, separate SQLite database files per persona offer stronger isolation.
* **Qdrant (If considered for scalability):** Recommends using a single collection with **payload-based partitioning**. A `persona_id` field in each vector's payload would be used for filtering, ensuring each persona only accesses their data.
* **Weaviate (If considered for scalability):** Offers robust **built-in multi-tenancy**. Each persona could be a distinct tenant, providing strong data isolation at the database level, with data stored on separate shards.
* **Encryption:**
* **At Rest:**
* **OS-Level Full-Disk Encryption:** A baseline security measure for the machine hosting the PUKB.
* **Database-Level Encryption:** If the chosen vector DB supports it (e.g., encrypting database files or directories). For embedded DBs like ChromaDB (persisted) or SQLite, encrypting the parent directory or the files themselves using OS tools or libraries like `cryptography` can be an option.
* **Chunk-Level Encryption (Application-Side):** For highly sensitive data, consider encrypting individual chunks *before* storing them in the vector DB using libraries like Python's `cryptography`. Decryption would occur after retrieval and before sending to the LLM. Secure key management (e.g., OS keychain, environment variables for simpler setups, or dedicated secrets manager for advanced use) is critical.
* **In Transit:** Primarily relevant if any PUKB component (LLM, embedding model, or the DB itself if run as a separate server) is accessed over a network. Ensure all communications use TLS/HTTPS.
* **Role-Based Access Control (RBAC) / Persona-Based Filtering (Application Layer):**
* If the PUKB is used by an agent supporting multiple 'personas' (e.g., 'personal', 'work_project_A'), the application (agent) must enforce access control.
* The agent maintains an 'active persona' context.
* All queries to the PUKB (vector DB, structured stores) MUST be filtered based on this active persona, using the segregation mechanisms of the chosen DB (e.g., querying specific Chroma collections, filtering by `persona_id` in Qdrant/SQLite-VSS, or targeting a specific Weaviate tenant).
* **Implications of Local vs. Cloud LLMs/Embedding Models:**
* **Local Models (Ollama, Sentence-Transformers, etc.):** Offer the highest privacy as data (prompts, chunks for embedding) does not leave the user's machine. Consider performance and model capability trade-offs.
* **Cloud Models (OpenAI, Anthropic, etc.):** Data is sent to third-party servers. Users must review and accept the provider's data usage, retention, and privacy policies. There's an inherent risk of data exposure or use for model training (unless explicitly opted out and guaranteed by the provider).
* **Data Minimization:** Only ingest data that is necessary for the PUKB's purpose. For particularly sensitive information, evaluate if a summary, an anonymized version, or just metadata can be ingested instead of the full raw content.
* **Secure Deletion:** Implement a reliable process for deleting data. This involves removing chunks from the vector DB (and ensuring the DB reclaims space/updates indexes correctly) and deleting the original source file if requested. Some vector DBs might require specific commands or compaction processes for permanent deletion.
* **Input Sanitization:** If user queries are used to construct dynamic filters or other database interactions, ensure proper sanitization to prevent injection-style attacks, even in a local context.
* **Regular Backups:** Securely back up the PUKB (vector DB, configuration, local source files if not backed up elsewhere) to prevent data loss. Ensure backups are also encrypted.
* **Ease of Implementation/Maintenance:** For a personal system, overly complex solutions might be burdensome.
* **Query Types:** The system should handle diverse queries: factual recall, summarization, comparison, open-ended questions.
* **Integration with AI Agent & Query Handling for Unified KB:** The RAG output must be easily consumable, and the agent needs sophisticated logic to interact effectively with a unified PUKB containing diverse data types. This involves:
* **Query Intent Recognition:** The agent should employ Natural Language Processing (NLP) techniques (e.g., keyword analysis, intent classification models, or LLM-based analysis) to understand the user's query. For example, distinguishing between:
* `\"Summarize my notes on Project Titan.\"` (targets local notes, specific project)
* `\"What are the latest web articles on quantum computing?\"` (targets web scrapes, specific topic, recency)
* `\"Find emails from John Doe about the Q3 budget.\"` (targets emails, specific sender, topic)
* `\"What's on my calendar for next Monday?\"` (targets structured calendar data)
* **Metadata-Driven Query Routing & Filtering:** Based on the recognized intent, the agent dynamically constructs queries for the vector database, leveraging the Unified Metadata Schema:
* **Source Filtering:** If intent clearly points to a source (e.g., \"my notes\"), the agent adds a filter like `metadata.source_type IN ['manual_note', 'local_md']`.
* **Topic/Keyword Filtering:** Standard semantic search combined with keyword filters on `document_title` or `chunk_text` (if the vector DB supports hybrid search or if keywords are part of the metadata).
* **Date Filtering:** For time-sensitive queries (e.g., \"latest articles,\" \"notes from last week\"), use `ingestion_date`, `creation_date_os`, or `last_modified_date_os` fields.
* **Tag Filtering:** `user_tags` become powerful for personalized retrieval (e.g., `\"Find documents tagged 'urgent' and 'project_alpha'\"`).
* **Author/Sender Filtering:** For queries like `\"emails from Jane\"` or `\"documents by Smith\"`.
* **Handling Ambiguity:** If a query is ambiguous (e.g., `\"Tell me about Project X\"` could be notes, emails, or web data):
* The agent might initially query across a broader set of relevant `source_type`s and present categorized results.
* Alternatively, it could ask for clarification: `\"Are you looking for your notes, web articles, or emails regarding Project X?\"`
* **Structured Data Querying (Conceptual - see 'Integrating Structured Personal Data'):** If the query targets structured data (e.g., `\"Show my tasks due today\"`), the agent would need to route this to the appropriate structured data store/handler (e.g., an SQLite DB, iCalendar parser) instead of, or in addition to, the vector DB.
* **Consuming RAG Output:** The agent receives retrieved chunks (with their metadata) and passes them as context to an LLM for answer generation, summarization, or other tasks. The metadata itself can be valuable context for the LLM.
* **Unified Metadata Schema:** A flexible and comprehensive metadata schema is crucial for managing diverse data sources within the PUKB, enabling robust filtering, source tracking, and providing context to the LLM.
* **Core Metadata Fields (Applicable to all chunks):**
* `chunk_id`: (String) Unique identifier for this specific chunk (e.g., UUID).
* `parent_document_id`: (String) Unique identifier for the original source document/file/email/note this chunk belongs to (e.g., hash of file path, URL, or a UUID for the document).
* `source_type`: (String, Enum-like) Type of the original source. Examples: 'web_crawl4ai', 'local_txt', 'local_md', 'local_pdf_text', 'local_pdf_image_ocr', 'local_docx', 'local_image_ocr', 'local_email_eml', 'local_email_mbox', 'local_email_msg', 'local_code_snippet', 'manual_note', 'structured_calendar_event', 'structured_todo_item', 'structured_contact'.
* `document_title`: (String) Title of the source document (e.g., web page `<title>`, filename, email subject, first heading in a note).
* `original_path`: (String, Nullable) Absolute file system path for local files.
* `url`: (String, Nullable) Original URL for web-scraped content.
* `file_hash`: (String, Nullable) Hash (e.g., SHA256) of the original local file content, for change detection.
* `creation_date_os`: (ISO8601 Timestamp, Nullable) Original creation date of the file/document from the OS or source metadata.
* `last_modified_date_os`: (ISO8601 Timestamp, Nullable) Original last modified date of the file/document from the OS or source metadata.
* `ingestion_date`: (ISO8601 Timestamp) Date and time when the content was ingested into the PUKB.
* `author`: (String, Nullable) Author(s) of the document, if available/extractable.
* `user_tags`: (List of Strings, Nullable) User-defined tags or keywords associated with the document or chunk.
* `chunk_sequence_number`: (Integer, Nullable) If the document is split into multiple chunks, this indicates the order of the chunk within the document.
* `text_preview`: (String, Nullable) A short preview (e.g., first 200 chars) of the chunk's text content for quick inspection (optional, can increase storage).
* **Source-Specific Optional Fields (Examples):** These can be stored as a nested JSON object (e.g., `source_specific_details`) or flattened with prefixes if the vector DB prefers.
* For `source_type` starting with `'local_email_'`: `email_sender`, `email_recipients_to`, `email_recipients_cc`, `email_recipients_bcc`, `email_date_sent`, `email_message_id`.
* For `source_type: 'web_crawl4ai'`: `web_domain`, `crawl_depth` (if applicable).
* For `source_type: 'local_code_snippet'`: `code_language`.
* For `source_type: 'structured_calendar_event'`: `event_start_time`, `event_end_time`, `event_location`, `event_attendees`.
* **Population:** Metadata is populated during the ingestion pipeline. `chunk_id` and `parent_document_id` are generated. OS/file metadata is gathered for local files. Web metadata comes from crawl results. Email headers provide email-specific fields. `ingestion_date` is timestamped. `user_tags` can be added later.
### Best Practices Identified (Preliminary)
* **Modular Design:** Separate components for ingestion, chunking, embedding, storage, retrieval, and generation.
* **Experimentation:** Chunking strategy, embedding models, and retrieval methods often require experimentation.
* **Evaluation is Key:** Even for a personal system, periodically checking relevance and accuracy is important.
* **Start Simple:** Begin with a basic RAG and iteratively add advanced techniques.
### Common Pitfalls to Avoid for Personal KB
* **Over-chunking or Under-chunking:** Finding the right balance is critical.
* **Stale Index:** Not updating the RAG with new information regularly.
* **Ignoring Metadata:** Not using source, date, tags for filtering and context.
* **Choosing an Overly Complex System:** Starting with something too difficult to maintain for personal use.
* **Vendor Lock-in:** If using cloud services, consider portability.
### Impact on Approach for Personal KB
The `NirDiamant/RAG_Techniques` repository provides a rich set of options. For a personal knowledge base, the initial focus should be on:
1. **Solid Foundational RAG:** Good chunking (semantic or proposition), reliable embedding model.
2. **Effective Retrieval:** Fusion retrieval (keyword + semantic) and reranking seem highly valuable.
3. **Contextual Understanding:** Techniques like contextual chunk headers and relevant segment extraction.
4. **Manageable Complexity:** Prioritize techniques that can be implemented and maintained without excessive effort for a personal system. GraphRAG and RAPTOR are powerful but might be later-stage enhancements.
5. **Data Ingestion:** Needs to be seamless with `crawl4ai` outputs and manual note entry.
### Frameworks and Tools Mentioned (from NirDiamant/RAG_Techniques & RAG.md)
* **RAG Frameworks:**
* LangChain (has `crawl4ai` loader)
* LlamaIndex
* Haystack
* Semantic Kernel
* Dify, Cognita, Verba, Mastra, Letta, Flowise, Swiftide, CocoIndex
* **Evaluation Tools:**
* DeepEval
* GroUSE
* LangFuse, Ragas, LangSmith (more for production LLM apps but principles apply)
* **Vector Databases (also from `RAG.md`):**
* **Open Source / Local-friendly:** ChromaDB, Milvus (can be local), Qdrant, Weaviate (can be local), pgvector (PostgreSQL extension), FAISS (library), LlamaIndex (in-memory default), SQLite-VSS.
* **Cloud/Managed:** Pinecone, MongoDB Atlas, Vespa, Elasticsearch, OpenSearch, Oracle AI Vector Search, Azure Cosmos DB, Couchbase.
* **Graph-based:** Neo4j (can store vectors).
This initial analysis of the `NirDiamant/RAG_Techniques` repository provides a strong foundation. The next steps will involve deeper dives into the most promising techniques, vector databases, and frameworks suitable for a personal knowledge base.
## Phase 2: Unified Ingestion Strategy for Local Files
This section details the tools and methods for ingesting various local file types into the Personal Unified Knowledge Base (PUKB), prioritizing local-first, open-source solutions.
### 1. Plain Text (.txt)
* **Tool(s):** Python's built-in `open()` function.
* **Workflow:**
1. Use `with open('filepath.txt', 'r', encoding='utf-8') as f:` to open the file (specify encoding if known, UTF-8 is a good default).
2. Read content using `f.read()`.
3. Basic cleaning: Strip leading/trailing whitespace, normalize newlines if necessary.
### 2. Markdown (.md)
* **Tool(s):**
* Python's built-in `open()` for reading raw Markdown text.
* `markdown` library (e.g., `pip install Markdown`) if conversion to HTML is desired as an intermediate step (less common for direct RAG ingestion of Markdown).
* **Workflow (Raw Text):**
1. Use `with open('filepath.md', 'r', encoding='utf-8') as f:` to open.
2. Read content using `f.read()`. This raw Markdown is often suitable for direct chunking.
3. Basic cleaning: Similar to .txt files.
### 3. PDF (Text-Based)
* **Tool(s):** `PyMuPDF` (also known as `fitz`, `pip install PyMuPDF`). `pypdf2` (`pip install pypdf2`) is an alternative.
* **Workflow (`PyMuPDF`):**
1. Import `fitz`.
2. Open PDF: `doc = fitz.open('filepath.pdf')`.
3. Initialize an empty string for all text: `full_text = ""`.
4. Iterate through pages: `for page_num in range(len(doc)): page = doc.load_page(page_num); full_text += page.get_text()`.
5. Close document: `doc.close()`.
6. Clean extracted text: Remove excessive newlines, ligatures, or broken words if present.
* **Feeding to Advanced Chunkers:**
* **Semantic Chunking:** The preprocessed `raw_markdown` can be directly fed. Semantic chunkers often leverage Markdown structure (headings, paragraphs, lists) to identify meaningful boundaries for chunks.
* **Proposition Chunking:** The preprocessed `raw_markdown` is suitable. The proposition extractor (often LLM-based) will then parse this text to identify and extract atomic factual statements.
### 4. PDF (Scanned/Image-Based OCR)
* **Tool(s):**
* Option A: `OCRmyPDF` (`pip install ocrmypdf`, requires Tesseract installed system-wide).
* Option B: `PyMuPDF` (to extract images) + `pytesseract` (`pip install pytesseract`, requires Tesseract) + `Pillow` (`pip install Pillow`). `EasyOCR` (`pip install easyocr`) is an alternative to `pytesseract`.
* **Workflow (Option A - `OCRmyPDF`):**
1. Use command line: `ocrmypdf input.pdf output_ocr.pdf`.
2. Process `output_ocr.pdf` as a text-based PDF using `PyMuPDF` (see above) to extract the OCRed text layer.
* **Workflow (Option B - `PyMuPDF` + `pytesseract`):**
* **Feeding to Advanced Chunkers:**
* **Semantic Chunking:** The clean text extracted from `cleaned_html` (after HTML-to-text conversion and further cleaning) is fed to the semantic chunker. The quality of semantic segmentation will depend on how well the original HTML structure (paragraphs, sections) was translated into logical text blocks during the HTML-to-text conversion.
* **Proposition Chunking:** The clean text extracted from `cleaned_html` is suitable. The proposition extractor will parse this text for factual statements.
1. Open PDF with `PyMuPDF`: `doc = fitz.open('filepath.pdf')`.
2. Iterate through pages. For each page:
* Extract images: `pix = page.get_pixmap()`. Convert `pix` to a `Pillow` Image object.
* Perform OCR: `text_on_page = pytesseract.image_to_string(pil_image)`.
* Append `text_on_page` to `full_text`.
3. Clean OCRed text: This often requires more significant cleaning for OCR errors, layout artifacts, etc.
### 5. DOCX (Microsoft Word)
* **Tool(s):** `python-docx` (`pip install python-docx`).
* **Workflow:**
1. Import `Document` from `docx`.
2. Open document: `doc = Document('filepath.docx')`.
3. Initialize an empty string: `full_text = ""`.
4. Iterate through paragraphs: `for para in doc.paragraphs: full_text += para.text + '\n'`.
5. (Optional) Extract text from tables, headers, footers if needed, using respective `python-docx` APIs. `python-docx2txt` might simplify this.
### 6. Common Image Formats (PNG, JPG, etc. for OCR)
* **Tool(s):** `Pillow` (`pip install Pillow`) + `pytesseract` (`pip install pytesseract`). `EasyOCR` as an alternative.
* **Workflow:**
1. Import `Image` from `PIL` and `pytesseract`.
2. Open image: `img = Image.open('imagepath.png')`.
3. Extract text: `text_content = pytesseract.image_to_string(img)`.
4. Clean OCRed text.
### 7. Email (.eml)
* **Tool(s):** Python's built-in `email` module (specifically `email.parser.Parser` or `email.message_from_file`). `eml_parser` for a higher-level API.
* **Workflow (built-in `email`):**
1. Import `email.parser`.
2. Open file: `with open('filepath.eml', 'r') as f: msg = email.parser.Parser().parse(f)`.
3. Extract body: Iterate `msg.walk()`. For each part, check `part.get_content_type()`.
* If `text/plain`, get payload: `body = part.get_payload(decode=True).decode(part.get_content_charset() or 'utf-8')`.
* If `text/html`, get payload and strip HTML tags (e.g., using `BeautifulSoup`).
4. Extract other relevant fields: `msg['Subject']`, `msg['From']`, `msg['To']`, `msg['Date']`.
### 8. Email (.msg - Microsoft Outlook)
* **Tool(s):** `extract_msg` (`pip install extract-msg`).
* **Workflow:**
1. Import `Message` from `extract_msg`.
2. Open file: `msg = Message('filepath.msg')`.
3. Access properties: `body = msg.body`, `subject = msg.subject`, `sender = msg.sender`, `to = msg.to`, `date = msg.date`.
4. The `body` usually contains the main text content.
### 9. Email (mbox)
* **Tool(s):** Python's built-in `mailbox` module.
* **Workflow:**
1. Import `mailbox`.
2. Open mbox file: `mbox_archive = mailbox.mbox('filepath.mbox')`.
3. Iterate through messages: `for message in mbox_archive:`.
4. Each `message` is an `email.message.Message` instance. Process it similarly to .eml files (see above) to extract text from payloads.
5. Extract other relevant fields: `message['Subject']`, etc.
### 10. Code Snippets (e.g., .py, .js, .java)
* **Tool(s):** Python's built-in `open()`. Language-specific parsers like `ast` for Python for deeper analysis (optional).
* **Workflow (Raw Text):**
1. Treat as plain text: `with open('filepath.py', 'r', encoding='utf-8') as f: code_text = f.read()`.
2. This raw code text can be directly used for embedding and RAG.
3. Further preprocessing might involve stripping comments or formatting, depending on the desired RAG behavior.
---
## Preprocessing and Advanced Chunking for Diverse Data
Effective RAG relies on high-quality chunks. This section details preprocessing steps for various data types before applying advanced chunking techniques.
### Part 1: Preprocessing Specifics by Data Type
(Assumes initial text extraction as per "Phase 2: Unified Ingestion Strategy for Local Files")
* **`crawl4ai` - `markdown.raw_markdown` Output:**
* **Goal:** Ensure clean, consistent Markdown.
* **Steps:**
1. Normalize newlines (e.g., ensure all are `\n`).
2. Trim leading/trailing whitespace from the entire document and from individual lines.
3. Collapse multiple consecutive blank lines into a maximum of one or two.
4. Optionally, strip or handle any rare HTML remnants if `crawl4ai`'s conversion wasn't perfect (though `raw_markdown` should be fairly clean).
* **`crawl4ai` - `cleaned_html` Output:**
* **Goal:** Convert to well-structured plain text.
* **Steps:**
1. Use a robust HTML-to-text converter (e.g., `BeautifulSoup(html_content, 'html.parser').get_text(separator='\n', strip=True)` or `html2text`).
2. Ensure paragraph breaks and list structures from HTML are preserved as newlines or appropriate text formatting.
3. Remove any residual non-content elements (e.g., rare JavaScript snippets, CSS, boilerplate not fully caught by `crawl4ai`).
4. Normalize whitespace (trim, collapse multiple spaces).
* **Local Plain Text (.txt) & Markdown (.md):**
* **Goal:** Clean and standardize plain text.
* **Steps:**
1. Normalize newlines.
2. Trim leading/trailing whitespace.
3. Collapse multiple blank lines.
4. For Markdown, ensure its syntax is preserved for chunkers that might leverage it.
* **PDF (Text-Based, output from `PyMuPDF` etc.):**
* **Goal:** Clean artifacts from PDF text extraction.
* **Steps:**
1. Replace common ligatures (e.g., "fi" to "fi", "fl" to "fl").
2. Attempt to rejoin hyphenated words broken across lines (can be challenging; may use heuristics or dictionary-based approaches).
3. Identify and remove repetitive headers/footers if not handled by the extraction library (e.g., using pattern matching or positional analysis if page layout is consistent).
4. Normalize whitespace and remove excessive blank lines.
* **PDF (Scanned/Image-Based, OCR output from `pytesseract`, `EasyOCR`):**
* **Goal:** Correct OCR errors and improve readability.
* **Steps:**
1. Apply spelling correction (e.g., using `pyspellchecker` or a similar library).
2. Filter out OCR noise or gibberish (e.g., sequences of non-alphanumeric characters unlikely to be valid text, very short isolated "words").
3. Attempt to reconstruct paragraph/section structure if lost during OCR (e.g., based on analyzing vertical spacing if available from OCR engine, or using NLP techniques to group related sentences).
4. Normalize all forms of whitespace.
* **DOCX (Output from `python-docx`):**
* **Goal:** Clean text extracted from Word documents.
* **Steps:**
1. Normalize whitespace (trim, collapse multiple spaces/newlines).
2. Remove artifacts from complex Word formatting if they translate poorly to plain text.
3. Handle lists and table text appropriately if extracted.
* **Image OCR Output (from standalone images):**
* **Goal:** Similar to scanned PDF OCR.
* **Steps:**
1. Spelling correction.
2. Noise removal.
3. Whitespace normalization.
* **Emails (Text extracted from .eml, .msg, .mbox):**
* **Goal:** Isolate main content and standardize.
* **Steps:**
1. Remove or clearly demarcate headers (From, To, Subject, Date), footers, and common disclaimers.
2. Normalize quoting styles for replies (e.g., convert `>` prefixes consistently, or attempt to strip quoted history if only the latest message is desired).
3. If original was HTML, ensure clean conversion to text, preserving paragraph structure.
4. Standardize signature blocks or remove them.
5. Normalize whitespace.
* **Code Snippets (Raw text):**
* **Goal:** Prepare code for embedding, preserving semantic structure.
* **Steps:**
1. Normalize newlines.
2. Ensure consistent indentation (e.g., convert tabs to spaces, or vice-versa, though often best left as-is if the chunker can use indentation for semantic grouping).
3. Decide on handling comments: strip them, or preserve them as they provide valuable context for RAG.
4. For some languages, normalizing case for keywords might be considered, but generally not required for modern embedding models.
### Part 2: Application of Advanced Chunking Techniques to Preprocessed Data
Once the data from various sources has been preprocessed into clean text, the following advanced chunking strategies can be applied:
* **Semantic Chunking:**
* **Principle:** Divides text into chunks based on semantic similarity or topic coherence, rather than fixed sizes. Often uses embedding models to measure sentence/paragraph similarity.
* **Application to Diverse Data:**
* **`crawl4ai` (Web Articles, Docs):** Effectively groups paragraphs or sections discussing the same sub-topic. Can leverage existing HTML structure (like `<section>`, `<h2>`) if translated well into the text.
* **Local Notes (.txt, .md):** Identifies coherent thoughts or topics within longer notes. Markdown headings can provide strong hints.
* **PDF/DOCX (Reports, Chapters):** Groups related paragraphs within sections, even if original formatting cues are subtle in the extracted text.
* **Emails:** Can separate distinct topics if an email covers multiple subjects, or keep a single coherent discussion together.
* **Code Snippets:** Can group entire functions, classes, or logically related blocks of code, especially if comments and docstrings are included in the preprocessed text.
* **Conceptual Example (Python-like pseudo-code):**
```python
# from semantic_text_splitter import CharacterTextSplitter # Example library
# text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
# "sentence-transformers/all-MiniLM-L6-v2", # Example embedding model
# chunk_size=512, # Target, but will split semantically
# chunk_overlap=50
# )
# chunks = text_splitter.split_text(preprocessed_text_from_any_source)
```
* **Proposition Chunking:**
* **Principle:** Breaks down text into atomic, factual statements or propositions. This often involves using an LLM to rephrase or extract these propositions.
* **Application to Diverse Data:**
* **`crawl4ai` (Factual Articles):** Ideal for extracting key facts, figures, and claims from news or informational web pages.
* **Local Notes (.txt, .md):** Converts personal notes or meeting minutes into a list of distinct facts, ideas, or action items.
* **PDF/DOCX (Dense Documents):** Extracts core assertions from academic papers, technical manuals, or reports.
* **Emails:** Isolates key information, decisions, or requests from email conversations.
* **Code Snippets:** Less about the code logic itself, but can extract factual statements from docstrings or high-level comments (e.g., "Function `calculate_sum` returns the total of a list.").
* **Conceptual Example (Python-like pseudo-code):**
```python
# llm_client = YourLLMClient() # Interface to an LLM
# prompt = f"Extract all distinct factual propositions from the following text. Each proposition should be a complete sentence and stand alone:\n\n{preprocessed_text_from_any_source}"
# response = llm_client.generate(prompt)
# propositions = response.splitlines() # Assuming LLM returns one prop per line
# # Each proposition in 'propositions' is a chunk
```
* **Contextual Chunk Headers:**
* **Principle:** Adds a small piece of contextual metadata (e.g., source, title, section) as a prefix to each chunk's text before embedding. This helps the retrieval system and LLM understand the chunk's origin and context.
* **Application to Diverse Data (Header Examples):**
* **`crawl4ai` Output:** `"[Source: Web Article | URL: {url} | Title: {page_title} | Section: {nearest_heading_text}]\n{chunk_text}"`
* **Local Markdown Note:** `"[Source: Local Note | File: {filename} | Path: {relative_path} | Title: {document_title_from_h1_or_filename}]\n{chunk_text}"`
* **PDF Document:** `"[Source: PDF Document | File: {filename} | Title: {pdf_title_metadata} | Page: {page_number}]\n{chunk_text}"`
* **Email:** `"[Source: Email | From: {sender} | Subject: {subject} | Date: {date}]\n{chunk_text}"`
* **Code Snippet:** `"[Source: Code File | File: {filename} | Language: {language} | Context: {function/class_name}]\n{chunk_text}"`
* **Implementation:** This is typically done by constructing the header string from the document's metadata (see Unified Metadata Schema) and prepending it to the chunk content before it's passed to the embedding model.
## Phase 3: Deep Dive into Specific RAG Techniques
### 1. Semantic Chunking
... (content as before) ...
### 2. Fusion Retrieval (Hybrid Search) & Reranking (Initial Summary)
... (content as before) ...
### 3. RAG-Fusion (Query Generation & Reciprocal Rank Fusion)
... (content as before) ...
### 4. Reranking with Cross-Encoders and LLMs
... (content as before) ...
### 5. RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval)
... (content as before) ...
### 6. Corrective RAG (CRAG)
... (content as before) ...
---
## Phase 4: Deep Dive into Vector Databases (for Local RAG)
This section will explore various vector database options suitable for a local personal knowledge base, focusing on ease of setup, Python integration, performance for typical personal KB sizes, and maintenance.
### 1. ChromaDB
... (content as before) ...
### 2. FAISS (Facebook AI Similarity Search)
... (content as before) ...
### 3. Qdrant
... (content as before) ...
### 4. Weaviate
... (content as before) ...
### 5. SQLite-VSS (SQLite Vector Similarity Search)
... (content as before) ...
---
## Phase 5: Deep Dive into RAG Frameworks
This section will explore RAG orchestration frameworks, focusing on their suitability for building a custom RAG pipeline for a personal knowledge base.
### 1. LangChain vs. LlamaIndex
... (content as before) ...
### 2. Haystack (by deepset)
... (content as before) ...
### 3. Semantic Kernel (by Microsoft)
... (content as before) ...
---
## Phase 6: Synthesis and Recommendations for Personal Knowledge Base RAG
This section synthesizes the research on RAG techniques, vector databases, and frameworks to provide recommendations for building a personal knowledge base for an AI agent, with `crawl4ai` as a primary data ingestion tool.
### I. Core Goal Recap & Decision on Unified vs. Separate RAG
* **Primary Goal:** Create a robust, adaptable, and efficiently queryable personal knowledge base for an AI agent. This KB will store diverse information, including scraped web content, project notes, and potentially code.
* **Secondary Goal:** Consider if this personal KB can be unified with a work-related KB or if they should remain separate.
* **Recommendation:** For initial development and given the "personal" focus, **start with a separate RAG for the personal knowledge base.**
* **Reasoning:**
* **Simplicity:** Reduces initial complexity in terms of data siloing, access control, and differing update cadences.
* **Privacy:** Keeps personal data distinctly separate, which is crucial.
* **Focus:** Allows tailoring the RAG system (chunking, embedding, retrieval strategies) specifically to the nature of personal data and queries.
* **Future Unification:** A well-designed personal RAG can potentially be integrated or federated with a work RAG later if designed with modularity in mind. The core challenge would be managing context and preventing data leakage between the two.
### II. Recommended RAG Framework
Based on the research, here's a comparison and recommendation:
| Feature/Aspect | LangChain | LlamaIndex | Haystack | Semantic Kernel |
| :---------------------- | :------------------------------------------ | :--------------------------------------------- | :--------------------------------------------- | :------------------------------------------ |
| **Primary Focus** | General LLM App Dev, Agents, Chains | Data Framework for LLM Apps, RAG focus | LLM Orchestration, Production RAG, Pipelines | Agentic AI, Planning, Plugins |
| **Ease of Basic RAG** | Moderate | High (RAG-centric abstractions) | Moderate (explicit pipeline setup) | Moderate (RAG is a pattern to build) |
| **Advanced RAG** | High (many components, flexible) | High (many advanced RAG modules) | High (flexible pipelines, diverse components) | Moderate (via plugins, memory) |
| **`crawl4ai` Integration** | Native `Crawl4aiLoader` | Custom loader needed (easy to adapt) | Custom integration needed (easy to adapt) | Custom plugin/memory integration needed |
| **Local Deployment** | Excellent (many local model/DB integrations) | Excellent (many local model/DB integrations) | Excellent (many local model/DB integrations) | Good (supports local models, memory stores) |
| **Modularity** | High (Chains, LCEL) | High (Engines, Indices, Retrievers) | Very High (Pipelines, Components) | Very High (Kernel, Plugins) |
| **Community/Docs** | Very Large | Large, RAG-focused | Good, Growing | Good, Microsoft-backed |
| **Personal KB Fit** | Good, flexible | **Excellent**, RAG-first design simplifies setup | Very Good, robust for evolving needs | Good, esp. if agent needs more than RAG |
**Primary Recommendation: LlamaIndex**
* **Reasons:**
* **RAG-Centric Design:** LlamaIndex is built from the ground up for connecting custom data to LLMs, making RAG its core competency. This often leads to more intuitive and quicker setup for RAG-specific tasks.
* **Ease of Use for Core RAG:** High-level abstractions for indexing, retrieval, and query engines simplify building a basic-to-intermediate personal RAG.
* **Advanced RAG Features:** Offers a rich set of modules for advanced RAG techniques (e.g., various retrievers, node postprocessors/rerankers, query transformations) that can be incrementally added.
* **Strong Local Ecosystem:** Excellent support for local embedding models, LLMs (via Ollama, HuggingFace), and local vector stores.
* **Data Ingestion Flexibility:** While no direct `crawl4ai` loader exists *yet*, creating one or simply passing `crawl4ai`'s text output to LlamaIndex's `Document` objects for indexing is straightforward.
* **Python Native:** Aligns well with `crawl4ai` and common data science/AI workflows.
**Runner-Up Recommendation: LangChain**
* **Reasons:**
* **Mature and Flexible:** A very versatile framework with a vast ecosystem of integrations.
* **Native `Crawl4aiLoader`:** Simplifies the initial data ingestion step from `crawl4ai`.
* **Strong for Complex Chains/Agents:** If the AI agent's capabilities extend significantly beyond RAG, LangChain's strengths in building complex chains and agents become more prominent.
* **Large Community:** Extensive documentation, tutorials, and community support.
* Can achieve everything LlamaIndex can for RAG, but sometimes with more boilerplate or a less RAG-specific API.
**Why not Haystack or Semantic Kernel as primary for *this specific* goal?**
* **Haystack:** Very powerful and modular, excellent for production and complex pipelines. However, for a *personal* KB, its explicitness in pipeline definition might be slightly more setup than LlamaIndex for common RAG patterns. It's a strong contender if the personal RAG is expected to become very complex or integrate deeply with other Haystack-specific tooling.
* **Semantic Kernel:** Its primary strength lies in agentic AI, planning, and function calling. While RAG is achievable via its "Memory" and plugin system, it's more of a capability you build *into* a Semantic Kernel agent rather than the framework's central focus. If the AI agent's core is complex task execution and RAG is just one tool, SK is excellent. If RAG *is* the core, LlamaIndex or LangChain might be more direct.
### III. Recommended Local Vector Database
**Primary Recommendation: ChromaDB**
* **Reasons:**
* **Ease of Use & Setup:** `pip install chromadb`, runs in-memory by default or can persist to disk easily. Very developer-friendly.
* **Python-Native:** Designed with Python applications in mind.
* **Good Integration:** Well-supported by LlamaIndex and LangChain.
* **Sufficient for Personal KB:** Scales well enough for typical personal knowledge base sizes.
* **Metadata Filtering:** Supports filtering by metadata, which is crucial for a personal KB (e.g., by source, date, tags).
**Runner-Up Recommendation: SQLite-VSS**
* **Reasons:**
* **Simplicity of SQLite:** If already using SQLite for other application data, adding vector search via an extension is very convenient. No separate database server to manage.
* **Good Enough Performance:** For many personal KB use cases, performance will be adequate.
* **Growing Ecosystem:** Gaining traction and support in frameworks.
**Why not others for initial setup?**
* **FAISS:** A library, not a full database. Requires more manual setup for persistence and serving, though powerful for raw similarity search. Often used *under the hood* by other vector DBs or frameworks.
* **Qdrant/Weaviate:** More feature-rich and scalable, potentially overkill for a basic personal KB's initial setup. They are excellent choices if the KB grows very large or requires more advanced features not easily met by ChromaDB. They can be considered for a "version 2" of the personal RAG.
* **Scalability Considerations for a Unified PUKB:** Given that a "Personal Unified Knowledge Base" might grow significantly with diverse local files (text, PDFs, images, emails, code) and web scrapes, the scalability of the chosen vector database becomes more pertinent.
* **ChromaDB & SQLite-VSS:** While excellent for starting and for moderately sized KBs due to their ease of setup, their performance might degrade with many millions of diverse vectors or very complex metadata filtering if the PUKB becomes extremely large. SQLite-VSS, being embedded, also shares resources with the main application.
* **Qdrant & Weaviate:** These are designed for larger scale and offer more advanced features like optimized filtering, quantization, and potentially better performance under heavy load or with massive datasets. They typically require running as separate services (often via Docker), which adds a small layer of setup complexity compared to embedded solutions.
* **Recommendation Adjustment:** For a PUKB envisioned to be *very large and diverse from the outset*, or if initial prototypes with ChromaDB/SQLite-VSS show performance bottlenecks with representative data volumes, considering Qdrant or Weaviate *earlier* in the development lifecycle (perhaps as an alternative for Phase 1 or a direct step into Phase 2 for database setup) would be prudent. The trade-off is initial simplicity versus future-proofing for scale and advanced features. Migration later is possible but involves effort.
### IV. Recommended Initial RAG Techniques
Start with a solid foundation and iteratively add complexity:
1. **Data Ingestion (`crawl4ai` + Manual):**
* Use `crawl4ai` to scrape web content.
* Develop a simple way to add manual notes (e.g., Markdown files in a directory).
2. **Chunking Strategy:**
* **Start with:** Recursive Character Text Splitting (available in LlamaIndex/LangChain) with sensible chunk size (e.g., 500-1000 tokens) and overlap (e.g., 50-100 tokens).
* **Consider for V2:** Semantic Chunking or Proposition-based chunking for more meaningful segments, especially for diverse personal data.
3. **Embedding Model:**
* **Start with:** A good open-source sentence-transformer model (e.g., `all-MiniLM-L6-v2` for a balance of speed and quality, or a model from the MTEB leaderboard). LlamaIndex/LangChain make these easy to use.
* **Alternative:** If API costs are not a concern, OpenAI's `text-embedding-3-small` is a strong performer.
4. **Vector Storage:**
* ChromaDB (persisted to disk).
5. **Retrieval Strategy:**
* **Start with:** Basic semantic similarity search (top-k retrieval).
* **Add Early:** Metadata filtering (e.g., filter by source URL, date added, custom tags).
6. **Reranking:**
* **Consider adding soon after basic retrieval:** A simple reranker like a Cross-Encoder (e.g., `ms-marco-MiniLM-L-6-v2`) to improve the relevance of top-k results before sending to the LLM. LlamaIndex has `SentenceTransformerRerank`.
7. **LLM for Generation:**
* A local LLM via Ollama (e.g., Llama 3, Mistral) or a smaller, efficient model.
* Or an API-based model if preferred (OpenAI, Anthropic).
8. **Prompting:**
* Standard RAG prompt: "Use the following context to answer the question. Context: {context_str} Question: {query_str} Answer:"
### V. Phased Implementation Approach
1. **Phase 1: Core RAG Pipeline Setup**
* Choose framework (LlamaIndex recommended).
* Set up `crawl4ai` for data ingestion from a few sample websites.
* Implement basic chunking, embedding (local model), and storage in ChromaDB.
* Build a simple query engine for semantic search.
* Integrate a local LLM for generation.
* Test with basic queries.
2. **Phase 2: Enhancing Retrieval Quality**
* Implement metadata storage and filtering.
* Add a reranking step (e.g., Cross-Encoder).
* Experiment with different chunking strategies and embedding models.
* Develop a simple way to add/update manual notes.
3. **Phase 3: Advanced Features & Agent Integration**
* Explore more advanced retrieval techniques if needed (e.g., HyDE, fusion retrieval if keyword search is important).
* Consider query transformations.
* Integrate the RAG system as a tool for the AI agent.
* Develop a basic UI or CLI for interaction.
* Start thinking about evaluation (even if manual).
4. **Phase 4: Long-Term Enhancements**
* Explore techniques like CRAG or RAPTOR if complexity is justified.
* Implement more robust update/synchronization mechanisms for the knowledge base.
* Consider GraphRAG if relationships between personal data points become important.
### VI. Addressing Work vs. Personal Knowledge Base
* The recommendation to start with a separate personal KB allows focused development.
* If a unified system is desired later:
* **Data Separation:** Use distinct metadata tags (e.g., `source_type: "work"` vs. `source_type: "personal"`) within a single vector store.
* **Query-Time Filtering:** Ensure queries to the "work" aspect only retrieve from work-tagged documents, and vice-versa. This is critical.
* **Agent Context:** The AI agent must be aware of which "mode" it's in (work or personal) to apply the correct filters.
* This adds complexity but is feasible with careful design in frameworks like LlamaIndex or LangChain.
This synthesis provides a roadmap. The next practical step would be to start implementing Phase 1.

73
router/__init__.py Normal file
View file

@ -0,0 +1,73 @@
"""AI Router - Clean abstraction for AI providers."""
from .base import AIRouter, RouterResponse
from .config import RouterConfig, config
from .exceptions import (
RouterError,
AuthenticationError,
RateLimitError,
ModelNotFoundError,
InvalidRequestError,
TimeoutError,
ContentFilterError,
QuotaExceededError,
ProviderError,
ConfigurationError,
)
from .gemini import Gemini
from .openai_compatible import OpenAI, OpenAICompatible
from .cohere import Cohere
from .rerank import Rerank, CohereRerank, RerankDocument, RerankResult, RerankResponse
from .embed import (
Embed,
CohereEmbed,
GeminiEmbed,
OllamaEmbed,
CohereEmbedding,
GeminiEmbedding,
OllamaEmbedding,
EmbeddingResponse,
create_embedding,
)
__version__ = "0.1.0"
__all__ = [
# Base classes
"AIRouter",
"RouterResponse",
"RouterConfig",
"config",
# Generation Providers
"Gemini",
"OpenAI",
"OpenAICompatible",
"Cohere",
# Reranking
"Rerank",
"CohereRerank",
"RerankDocument",
"RerankResult",
"RerankResponse",
# Embeddings
"Embed",
"CohereEmbed",
"GeminiEmbed",
"OllamaEmbed",
"CohereEmbedding",
"GeminiEmbedding",
"OllamaEmbedding",
"EmbeddingResponse",
"create_embedding",
# Exceptions
"RouterError",
"AuthenticationError",
"RateLimitError",
"ModelNotFoundError",
"InvalidRequestError",
"TimeoutError",
"ContentFilterError",
"QuotaExceededError",
"ProviderError",
"ConfigurationError",
]

178
router/base.py Normal file
View file

@ -0,0 +1,178 @@
"""Base abstract class for AI router implementations."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union
import time
@dataclass
class RouterResponse:
"""Standardized response object for all AI providers."""
content: str
model: str
provider: str
latency: float = 0.0
input_tokens: Optional[int] = None
output_tokens: Optional[int] = None
total_tokens: Optional[int] = None
cost: Optional[float] = None
finish_reason: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
raw_response: Optional[Any] = None
def __str__(self) -> str:
return self.content
class AIRouter(ABC):
"""Abstract base class for AI provider routers."""
def __init__(
self,
model: str,
api_key: Optional[str] = None,
**kwargs: Any
) -> None:
"""Initialize the router with model and configuration.
Args:
model: Model identifier for the provider
api_key: API key for authentication (optional if set in environment)
**kwargs: Additional provider-specific configuration
"""
self.model = model
self.api_key = api_key
self.config = kwargs
self.provider = self.__class__.__name__.lower()
@abstractmethod
def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
"""Prepare provider-specific request parameters.
Args:
prompt: User prompt
**kwargs: Additional parameters
Returns:
Dictionary of request parameters
"""
pass
@abstractmethod
def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
"""Parse provider response into RouterResponse.
Args:
raw_response: Raw response from provider
latency: Request latency in seconds
Returns:
Standardized RouterResponse
"""
pass
@abstractmethod
def _make_request(self, request_params: Dict[str, Any]) -> Any:
"""Make synchronous request to provider.
Args:
request_params: Provider-specific request parameters
Returns:
Raw provider response
"""
pass
@abstractmethod
async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
"""Make asynchronous request to provider.
Args:
request_params: Provider-specific request parameters
Returns:
Raw provider response
"""
pass
def call(self, prompt: str, **kwargs: Any) -> RouterResponse:
"""Make a synchronous call to the AI provider.
Args:
prompt: User prompt
**kwargs: Additional parameters (model, temperature, max_tokens, etc.)
Returns:
RouterResponse with the result
"""
# Merge default config with call-specific parameters
params = {**self.config, **kwargs}
# Prepare request
request_params = self._prepare_request(prompt, **params)
# Make request and measure latency
start_time = time.time()
raw_response = self._make_request(request_params)
latency = time.time() - start_time
# Parse and return response
return self._parse_response(raw_response, latency)
async def acall(self, prompt: str, **kwargs: Any) -> RouterResponse:
"""Make an asynchronous call to the AI provider.
Args:
prompt: User prompt
**kwargs: Additional parameters (model, temperature, max_tokens, etc.)
Returns:
RouterResponse with the result
"""
# Merge default config with call-specific parameters
params = {**self.config, **kwargs}
# Prepare request
request_params = self._prepare_request(prompt, **params)
# Make request and measure latency
start_time = time.time()
raw_response = await self._make_async_request(request_params)
latency = time.time() - start_time
# Parse and return response
return self._parse_response(raw_response, latency)
def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
"""Stream responses from the AI provider.
Args:
prompt: User prompt
**kwargs: Additional parameters
Yields:
RouterResponse chunks
Raises:
NotImplementedError: If streaming is not supported
"""
raise NotImplementedError(f"{self.provider} does not support streaming yet")
async def astream(
self, prompt: str, **kwargs: Any
) -> AsyncGenerator[RouterResponse, None]:
"""Asynchronously stream responses from the AI provider.
Args:
prompt: User prompt
**kwargs: Additional parameters
Yields:
RouterResponse chunks
Raises:
NotImplementedError: If async streaming is not supported
"""
raise NotImplementedError(f"{self.provider} does not support async streaming yet")
yield # Required for async generator type hint

259
router/cohere.py Normal file
View file

@ -0,0 +1,259 @@
"""Cohere provider implementation."""
from typing import Any, Dict, Optional, AsyncGenerator, Generator
import time
from .base import AIRouter, RouterResponse
from .config import config
from .exceptions import (
ConfigurationError,
map_provider_error
)
class Cohere(AIRouter):
"""Router for Cohere models."""
def __init__(
self,
model: str = "command-r-plus",
api_key: Optional[str] = None,
**kwargs: Any
) -> None:
"""Initialize Cohere router.
Args:
model: Cohere model to use (command-r-plus, command-r, etc.)
api_key: Cohere API key (optional if set in environment)
**kwargs: Additional configuration
"""
# Get API key from config if not provided
if not api_key:
api_key = config.get_api_key("cohere")
if not api_key:
raise ConfigurationError(
"Cohere API key not found. Set COHERE_API_KEY environment variable.",
provider="cohere"
)
super().__init__(model=model, api_key=api_key, **kwargs)
# Initialize Cohere client
try:
import cohere
# For v5.15.0, use the standard Client
self.client = cohere.Client(api_key)
self.async_client = cohere.AsyncClient(api_key)
except ImportError:
raise ConfigurationError(
"cohere package not installed. Install with: pip install cohere",
provider="cohere"
)
def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
"""Prepare Cohere request parameters.
Args:
prompt: User prompt
**kwargs: Additional parameters
Returns:
Request parameters
"""
# Build request parameters for v5.15.0 API
params = {
"model": kwargs.get("model", self.model),
"message": prompt, # v5.15.0 uses 'message' not 'messages'
}
# Add optional parameters
if "temperature" in kwargs:
params["temperature"] = kwargs["temperature"]
elif hasattr(config, "default_temperature"):
params["temperature"] = config.default_temperature
if "max_tokens" in kwargs:
params["max_tokens"] = kwargs["max_tokens"]
elif hasattr(config, "default_max_tokens") and config.default_max_tokens:
params["max_tokens"] = config.default_max_tokens
# Cohere uses 'p' instead of 'top_p'
if "p" in kwargs:
params["p"] = kwargs["p"]
elif "top_p" in kwargs:
params["p"] = kwargs["top_p"]
elif hasattr(config, "default_top_p"):
params["p"] = config.default_top_p
# Cohere uses 'k' instead of 'top_k'
if "k" in kwargs:
params["k"] = kwargs["k"]
elif "top_k" in kwargs:
params["k"] = kwargs["top_k"]
# Other Cohere-specific parameters for v5.15.0
for key in ["chat_history", "preamble", "conversation_id", "prompt_truncation", "connectors", "search_queries_only", "documents", "tools", "tool_results", "stop_sequences", "seed"]:
if key in kwargs:
params[key] = kwargs[key]
return params
def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
"""Parse Cohere response into RouterResponse.
Args:
raw_response: Raw Cohere response
latency: Request latency
Returns:
RouterResponse
"""
# Extract text content - v5.15.0 has 'text' attribute directly
content = getattr(raw_response, "text", "")
# Get token counts from meta/usage
input_tokens = None
output_tokens = None
# Try different token count locations based on v5.15.0 structure
if hasattr(raw_response, "meta") and hasattr(raw_response.meta, "billed_units"):
billed = raw_response.meta.billed_units
input_tokens = getattr(billed, "input_tokens", None)
output_tokens = getattr(billed, "output_tokens", None)
elif hasattr(raw_response, "usage") and hasattr(raw_response.usage, "billed_units"):
billed = raw_response.usage.billed_units
input_tokens = getattr(billed, "input_tokens", None)
output_tokens = getattr(billed, "output_tokens", None)
# Calculate cost if available
cost = None
if input_tokens and output_tokens and config.track_costs:
cost = config.calculate_cost(self.model, input_tokens, output_tokens)
# Extract finish reason
finish_reason = "stop"
if hasattr(raw_response, "finish_reason"):
finish_reason = raw_response.finish_reason
return RouterResponse(
content=content,
model=self.model,
provider="cohere",
latency=latency,
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=(input_tokens + output_tokens) if input_tokens and output_tokens else None,
cost=cost,
finish_reason=finish_reason,
metadata={
"id": getattr(raw_response, "id", None),
"generation_id": getattr(raw_response, "generation_id", None),
"citations": getattr(raw_response, "citations", None),
"documents": getattr(raw_response, "documents", None),
"search_results": getattr(raw_response, "search_results", None),
"search_queries": getattr(raw_response, "search_queries", None),
"tool_calls": getattr(raw_response, "tool_calls", None),
},
raw_response=raw_response
)
def _make_request(self, request_params: Dict[str, Any]) -> Any:
"""Make synchronous request to Cohere.
Args:
request_params: Request parameters
Returns:
Raw Cohere response
"""
try:
response = self.client.chat(**request_params)
return response
except Exception as e:
raise map_provider_error("cohere", e)
async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
"""Make asynchronous request to Cohere.
Args:
request_params: Request parameters
Returns:
Raw Cohere response
"""
try:
response = await self.async_client.chat(**request_params)
return response
except Exception as e:
raise map_provider_error("cohere", e)
def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
"""Stream responses from Cohere.
Args:
prompt: User prompt
**kwargs: Additional parameters
Yields:
RouterResponse chunks
"""
params = {**self.config, **kwargs}
request_params = self._prepare_request(prompt, **params)
try:
start_time = time.time()
stream = self.client.chat_stream(**request_params)
for event in stream:
# v5.15.0 uses event_type and text directly
if hasattr(event, "event_type") and event.event_type == "text-generation":
content = getattr(event, "text", "")
if content:
yield RouterResponse(
content=content,
model=self.model,
provider="cohere",
latency=time.time() - start_time,
finish_reason=None,
metadata={"event_type": event.event_type},
raw_response=event
)
except Exception as e:
raise map_provider_error("cohere", e)
async def astream(
self, prompt: str, **kwargs: Any
) -> AsyncGenerator[RouterResponse, None]:
"""Asynchronously stream responses from Cohere.
Args:
prompt: User prompt
**kwargs: Additional parameters
Yields:
RouterResponse chunks
"""
params = {**self.config, **kwargs}
request_params = self._prepare_request(prompt, **params)
try:
start_time = time.time()
stream = self.async_client.chat_stream(**request_params)
async for event in stream:
# v5.15.0 uses event_type and text directly
if hasattr(event, "event_type") and event.event_type == "text-generation":
content = getattr(event, "text", "")
if content:
yield RouterResponse(
content=content,
model=self.model,
provider="cohere",
latency=time.time() - start_time,
finish_reason=None,
metadata={"event_type": event.event_type},
raw_response=event
)
except Exception as e:
raise map_provider_error("cohere", e)

259
router/config.py Normal file
View file

@ -0,0 +1,259 @@
"""Configuration management for AI router."""
import os
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
@dataclass
class RouterConfig:
"""Configuration for AI router."""
# API Keys
openai_api_key: Optional[str] = field(default_factory=lambda: os.getenv("OPENAI_API_KEY"))
gemini_api_key: Optional[str] = field(default_factory=lambda: os.getenv("GEMINI_API_KEY"))
google_api_key: Optional[str] = field(default_factory=lambda: os.getenv("GOOGLE_API_KEY"))
cohere_api_key: Optional[str] = field(default_factory=lambda: os.getenv("COHERE_API_KEY"))
# Ollama configuration
ollama_base_url: str = field(default_factory=lambda: os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"))
# Available Ollama embedding models
ollama_embedding_models: List[str] = field(default_factory=lambda: [
"mxbai-embed-large:latest",
"nomic-embed-text:latest",
])
# Default model settings
default_temperature: float = 0.7
default_max_tokens: Optional[int] = None
default_top_p: float = 1.0
default_timeout: int = 160 # seconds
# Retry configuration
max_retries: int = 3
retry_delay: float = 1.0
retry_backoff: float = 2.0
# Cost tracking
track_costs: bool = False
# Cost per million tokens (prices in USD)
cost_per_1m_input_tokens: Dict[str, float] = field(default_factory=lambda: {
# OpenAI Compatible Models
"gpt-4o": 2.5,
"azure/gpt-4.1-new": 2.0,
"o3": 2.0,
"o4-mini": 1.1,
"claude-opus-4": 15.0,
"claude-sonnet-4": 3.0,
# Gemini Models
"gemini-2.5-pro": 1.25,
"gemini-2.5-flash": 0.3,
"gemini-2.0-flash-001": 0.3, # Alias for current default
# Cohere Models
"command-a": 2.5,
"command-r-plus": 2.5,
})
cost_per_1m_output_tokens: Dict[str, float] = field(default_factory=lambda: {
# OpenAI Compatible Models
"gpt-4o": 10.0,
"azure/gpt-4.1-new": 8.0,
"o3": 8.0,
"o4-mini": 4.4,
"claude-opus-4": 75.0,
"claude-sonnet-4": 15.0,
# Gemini Models
"gemini-2.5-pro": 10.0,
"gemini-2.5-flash": 2.5,
"gemini-2.0-flash-001": 2.5, # Alias for current default
# Cohere Models
"command-a": 10.0,
"command-r-plus": 10.0,
})
# Cached input pricing for specific models (per million tokens)
cost_per_1m_cached_input_tokens: Dict[str, float] = field(default_factory=lambda: {
"azure/gpt-4.1-new": 0.5,
"o3": 0.5,
"o4-mini": 0.28,
})
# Reranking costs (per 1k searches)
cost_per_1k_rerank_searches: Dict[str, float] = field(default_factory=lambda: {
"rerank-3.5": 2.0,
"rerank-english-v3.0": 2.0,
"rerank-multilingual-v3.0": 2.0,
})
# Embedding costs (per million tokens)
cost_per_1m_embed_tokens: Dict[str, float] = field(default_factory=lambda: {
"embed-english-v3.0": 0.12,
"embed-multilingual-v3.0": 0.12,
"embed-english-light-v3.0": 0.12,
"text-embedding-004": 0.12, # Google's embedding model
# Ollama models are free (local)
"mxbai-embed-large:latest": 0.0,
"nomic-embed-text:latest": 0.0,
"nomic-embed-text:137m-v1.5-fp16": 0.0,
})
# Image embedding costs (per million image tokens)
cost_per_1m_embed_image_tokens: Dict[str, float] = field(default_factory=lambda: {
"embed-english-v3.0": 0.47,
"embed-multilingual-v3.0": 0.47,
})
# Logging
log_requests: bool = False
log_responses: bool = False
log_errors: bool = True
def get_api_key(self, provider: str) -> Optional[str]:
"""Get API key for a specific provider.
Args:
provider: Provider name (openai, gemini, cohere, etc.)
Returns:
API key if available
"""
provider = provider.lower()
if provider == "openai":
return self.openai_api_key
elif provider in ["gemini", "google"]:
return self.gemini_api_key or self.google_api_key
elif provider == "cohere":
return self.cohere_api_key
return None
def calculate_cost(
self,
model: str,
input_tokens: int,
output_tokens: int,
cached_input: bool = False
) -> Optional[float]:
"""Calculate cost for a request.
Args:
model: Model identifier
input_tokens: Number of input tokens
output_tokens: Number of output tokens
cached_input: Whether input tokens are cached (for compatible models)
Returns:
Total cost in USD, or None if cost data not available
"""
if not self.track_costs:
return None
# Check if using cached pricing
if cached_input and model in self.cost_per_1m_cached_input_tokens:
input_cost_per_1m = self.cost_per_1m_cached_input_tokens.get(model)
else:
input_cost_per_1m = self.cost_per_1m_input_tokens.get(model)
output_cost_per_1m = self.cost_per_1m_output_tokens.get(model)
if input_cost_per_1m is None or output_cost_per_1m is None:
return None
input_cost = (input_tokens / 1_000_000) * input_cost_per_1m
output_cost = (output_tokens / 1_000_000) * output_cost_per_1m
return round(input_cost + output_cost, 6)
def calculate_rerank_cost(
self,
model: str,
num_searches: int
) -> Optional[float]:
"""Calculate cost for reranking.
Args:
model: Rerank model identifier
num_searches: Number of searches performed
Returns:
Total cost in USD, or None if cost data not available
"""
if not self.track_costs:
return None
cost_per_1k = self.cost_per_1k_rerank_searches.get(model)
if cost_per_1k is None:
return None
return round((num_searches / 1000) * cost_per_1k, 6)
def calculate_embed_cost(
self,
model: str,
num_tokens: int,
is_image: bool = False
) -> Optional[float]:
"""Calculate cost for embeddings.
Args:
model: Embedding model identifier
num_tokens: Number of tokens to embed
is_image: Whether these are image tokens
Returns:
Total cost in USD, or None if cost data not available
"""
if not self.track_costs:
return None
if is_image:
cost_per_1m = self.cost_per_1m_embed_image_tokens.get(model)
else:
cost_per_1m = self.cost_per_1m_embed_tokens.get(model)
if cost_per_1m is None:
return None
return round((num_tokens / 1_000_000) * cost_per_1m, 6)
@classmethod
def from_env(cls) -> "RouterConfig":
"""Create config from environment variables.
Returns:
RouterConfig instance
"""
return cls()
def to_dict(self) -> Dict[str, Any]:
"""Convert config to dictionary.
Returns:
Dictionary representation
"""
return {
"openai_api_key": "***" if self.openai_api_key else None,
"gemini_api_key": "***" if self.gemini_api_key else None,
"google_api_key": "***" if self.google_api_key else None,
"cohere_api_key": "***" if self.cohere_api_key else None,
"default_temperature": self.default_temperature,
"default_max_tokens": self.default_max_tokens,
"default_top_p": self.default_top_p,
"default_timeout": self.default_timeout,
"max_retries": self.max_retries,
"retry_delay": self.retry_delay,
"retry_backoff": self.retry_backoff,
"track_costs": self.track_costs,
"log_requests": self.log_requests,
"log_responses": self.log_responses,
"log_errors": self.log_errors,
}
# Global config instance
config = RouterConfig.from_env()

742
router/embed.py Normal file
View file

@ -0,0 +1,742 @@
"""Embedding router implementation for multiple providers."""
from typing import List, Dict, Any, Optional, Union, Literal
from dataclasses import dataclass, field
import time
from abc import ABC, abstractmethod
from .config import config
from .exceptions import (
ConfigurationError,
map_provider_error
)
@dataclass
class EmbeddingResponse:
"""Response from embedding operation."""
embeddings: List[List[float]] # List of embedding vectors
model: str
provider: str
latency: float
dimension: int # Dimension of embeddings
num_inputs: int # Number of inputs embedded
total_tokens: Optional[int] = None
cost: Optional[float] = None
metadata: Dict[str, Any] = field(default_factory=dict)
raw_response: Optional[Any] = None
class BaseEmbedding(ABC):
"""Base class for embedding routers."""
def __init__(self, model: str, api_key: str, **kwargs: Any) -> None:
"""Initialize embedding router.
Args:
model: Model identifier
api_key: API key
**kwargs: Additional configuration
"""
self.model = model
self.api_key = api_key
self.config = kwargs
@abstractmethod
def embed(
self,
texts: Union[str, List[str]],
**kwargs: Any
) -> EmbeddingResponse:
"""Generate embeddings for texts.
Args:
texts: Single text or list of texts to embed
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
pass
@abstractmethod
async def aembed(
self,
texts: Union[str, List[str]],
**kwargs: Any
) -> EmbeddingResponse:
"""Asynchronously generate embeddings for texts.
Args:
texts: Single text or list of texts to embed
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
pass
class CohereEmbedding(BaseEmbedding):
"""Router for Cohere embedding models."""
def __init__(
self,
model: str = "embed-english-v3.0",
api_key: Optional[str] = None,
**kwargs: Any
) -> None:
"""Initialize Cohere embedding router.
Args:
model: Embedding model (embed-english-v3.0, embed-multilingual-v3.0, etc.)
api_key: Cohere API key (optional if set in environment)
**kwargs: Additional configuration
"""
# Get API key from config if not provided
if not api_key:
api_key = config.get_api_key("cohere")
if not api_key:
raise ConfigurationError(
"Cohere API key not found. Set COHERE_API_KEY environment variable.",
provider="cohere"
)
super().__init__(model, api_key, **kwargs)
# Initialize Cohere client
try:
import cohere
# For v5.15.0, use the standard Client
self.client = cohere.Client(api_key)
self.async_client = cohere.AsyncClient(api_key)
except ImportError:
raise ConfigurationError(
"cohere package not installed. Install with: pip install cohere",
provider="cohere"
)
def embed(
self,
texts: Union[str, List[str]],
input_type: Optional[Literal["search_document", "search_query", "classification", "clustering"]] = None,
truncate: Optional[Literal["NONE", "START", "END"]] = None,
**kwargs: Any
) -> EmbeddingResponse:
"""Generate embeddings for texts.
Args:
texts: Single text or list of texts to embed
input_type: Purpose of embeddings (affects vector space)
truncate: How to handle inputs longer than max tokens
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
params = self._prepare_request(texts, input_type, truncate, **kwargs)
try:
start_time = time.time()
response = self.client.embed(**params)
latency = time.time() - start_time
return self._parse_response(response, latency, len(texts))
except Exception as e:
raise map_provider_error("cohere", e)
async def aembed(
self,
texts: Union[str, List[str]],
input_type: Optional[Literal["search_document", "search_query", "classification", "clustering"]] = None,
truncate: Optional[Literal["NONE", "START", "END"]] = None,
**kwargs: Any
) -> EmbeddingResponse:
"""Asynchronously generate embeddings for texts.
Args:
texts: Single text or list of texts to embed
input_type: Purpose of embeddings
truncate: How to handle long inputs
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
params = self._prepare_request(texts, input_type, truncate, **kwargs)
try:
start_time = time.time()
response = await self.async_client.embed(**params)
latency = time.time() - start_time
return self._parse_response(response, latency, len(texts))
except Exception as e:
raise map_provider_error("cohere", e)
def _prepare_request(
self,
texts: List[str],
input_type: Optional[str],
truncate: Optional[str],
**kwargs: Any
) -> Dict[str, Any]:
"""Prepare embed request parameters."""
params = {
"model": kwargs.get("model", self.model),
"texts": texts,
}
if input_type:
params["input_type"] = input_type
if truncate:
params["truncate"] = truncate
return params
def _parse_response(
self,
raw_response: Any,
latency: float,
num_inputs: int
) -> EmbeddingResponse:
"""Parse Cohere embed response."""
# Extract embeddings
embeddings = []
if hasattr(raw_response, "embeddings"):
embeddings = raw_response.embeddings
# Get dimension from first embedding
dimension = len(embeddings[0]) if embeddings else 0
# Get token count from response metadata
total_tokens = None
if hasattr(raw_response, "meta") and hasattr(raw_response.meta, "billed_units"):
total_tokens = getattr(raw_response.meta.billed_units, "input_tokens", None)
# Calculate cost
cost = None
if config.track_costs and total_tokens:
cost = config.calculate_embed_cost(self.model, total_tokens)
return EmbeddingResponse(
embeddings=embeddings,
model=self.model,
provider="cohere",
latency=latency,
dimension=dimension,
num_inputs=num_inputs,
total_tokens=total_tokens,
cost=cost,
metadata={
"id": getattr(raw_response, "id", None),
"response_type": getattr(raw_response, "response_type", None),
},
raw_response=raw_response
)
class GeminiEmbedding(BaseEmbedding):
"""Router for Google Gemini embedding models."""
def __init__(
self,
model: str = "text-embedding-004",
api_key: Optional[str] = None,
**kwargs: Any
) -> None:
"""Initialize Gemini embedding router.
Args:
model: Embedding model (text-embedding-004, etc.)
api_key: Google API key (optional if set in environment)
**kwargs: Additional configuration
"""
# Get API key from config if not provided
if not api_key:
api_key = config.get_api_key("gemini")
if not api_key:
raise ConfigurationError(
"Gemini API key not found. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.",
provider="gemini"
)
super().__init__(model, api_key, **kwargs)
# Initialize Gemini client using new google.genai library
try:
from google import genai
from google.genai import types
self.genai = genai
self.types = types
self.client = genai.Client(api_key=api_key)
except ImportError:
raise ConfigurationError(
"google-genai package not installed. Install with: pip install google-genai",
provider="gemini"
)
def embed(
self,
texts: Union[str, List[str]],
task_type: Optional[str] = None,
title: Optional[str] = None,
**kwargs: Any
) -> EmbeddingResponse:
"""Generate embeddings for texts.
Args:
texts: Single text or list of texts to embed
task_type: Task type for embeddings
title: Optional title for context
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
try:
start_time = time.time()
# The Google genai SDK supports batch embedding
# Pass all texts at once for better performance
params = {
"model": kwargs.get("model", self.model),
"contents": texts, # Can pass multiple texts
}
# Add optional config
config_params = {}
if task_type:
config_params["task_type"] = task_type
if title:
config_params["title"] = title
if config_params:
params["config"] = self.types.EmbedContentConfig(**config_params)
response = self.client.models.embed_content(**params)
# Extract embeddings from response
# The response always has an 'embeddings' attribute (even for single text)
embeddings = []
# Check if embeddings exist
if response.embeddings is None:
raise ValueError("No embeddings returned in response")
# response.embeddings is always present
for emb in response.embeddings:
# ContentEmbedding objects have a 'values' attribute containing the float list
if hasattr(emb, "values") and emb.values is not None:
embeddings.append(list(emb.values))
elif hasattr(emb, "__iter__"):
# If the embedding is directly iterable (list-like)
try:
embeddings.append(list(emb))
except Exception as e:
print(f"Warning: Could not extract embedding values from {type(emb)}: {e}")
else:
print(f"Warning: Unknown embedding format: {type(emb)}, attributes: {dir(emb)}")
latency = time.time() - start_time
# Token counting is not directly available in the response
# Set to None for now
total_tokens = None
return self._create_response(
embeddings=embeddings,
latency=latency,
num_inputs=len(texts),
total_tokens=total_tokens
)
except Exception as e:
raise map_provider_error("gemini", e)
async def aembed(
self,
texts: Union[str, List[str]],
task_type: Optional[str] = None,
title: Optional[str] = None,
**kwargs: Any
) -> EmbeddingResponse:
"""Asynchronously generate embeddings for texts.
Args:
texts: Single text or list of texts to embed
task_type: Task type for embeddings
title: Optional title for context
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
try:
start_time = time.time()
# The Google genai SDK supports batch embedding
# Pass all texts at once for better performance
params = {
"model": kwargs.get("model", self.model),
"contents": texts, # Can pass multiple texts
}
# Add optional config
config_params = {}
if task_type:
config_params["task_type"] = task_type
if title:
config_params["title"] = title
if config_params:
params["config"] = self.types.EmbedContentConfig(**config_params)
response = await self.client.aio.models.embed_content(**params)
# Extract embeddings from response
# The response always has an 'embeddings' attribute (even for single text)
embeddings = []
# Check if embeddings exist
if response.embeddings is None:
raise ValueError("No embeddings returned in response")
# response.embeddings is always present
for emb in response.embeddings:
# ContentEmbedding objects have a 'values' attribute containing the float list
if hasattr(emb, "values") and emb.values is not None:
embeddings.append(list(emb.values))
elif hasattr(emb, "__iter__"):
# If the embedding is directly iterable (list-like)
try:
embeddings.append(list(emb))
except Exception as e:
print(f"Warning: Could not extract embedding values from {type(emb)}: {e}")
else:
print(f"Warning: Unknown embedding format: {type(emb)}, attributes: {dir(emb)}")
latency = time.time() - start_time
# Token counting is not directly available in the response
# Set to None for now
total_tokens = None
return self._create_response(
embeddings=embeddings,
latency=latency,
num_inputs=len(texts),
total_tokens=total_tokens
)
except Exception as e:
raise map_provider_error("gemini", e)
def _create_response(
self,
embeddings: List[List[float]],
latency: float,
num_inputs: int,
total_tokens: Optional[int] = None
) -> EmbeddingResponse:
"""Create embedding response."""
# Get dimension from first embedding
dimension = len(embeddings[0]) if embeddings else 0
# Calculate cost
cost = None
if config.track_costs and total_tokens:
cost = config.calculate_embed_cost(self.model, total_tokens)
return EmbeddingResponse(
embeddings=embeddings,
model=self.model,
provider="gemini",
latency=latency,
dimension=dimension,
num_inputs=num_inputs,
total_tokens=total_tokens,
cost=cost,
metadata={},
raw_response=None
)
class OllamaEmbedding(BaseEmbedding):
"""Router for Ollama local embedding models."""
def __init__(
self,
model: str = "nomic-embed-text:latest",
base_url: Optional[str] = None,
**kwargs: Any
) -> None:
"""Initialize Ollama embedding router.
Args:
model: Ollama embedding model (mxbai-embed-large, nomic-embed-text, etc.)
base_url: Ollama API base URL (default: http://localhost:11434)
**kwargs: Additional configuration
"""
# No API key needed for local Ollama
super().__init__(model, api_key="local", **kwargs)
# Get base URL from config or parameter
self.base_url = base_url or config.ollama_base_url
self.embeddings_url = f"{self.base_url}/api/embeddings"
# Check if Ollama is available
try:
import requests
self.requests = requests
except ImportError:
raise ConfigurationError(
"requests package not installed. Install with: pip install requests",
provider="ollama"
)
# For async support
try:
import httpx
self.httpx = httpx
except ImportError:
self.httpx = None # Async support is optional
def embed(
self,
texts: Union[str, List[str]],
**kwargs: Any
) -> EmbeddingResponse:
"""Generate embeddings for texts using Ollama.
Args:
texts: Single text or list of texts to embed
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
try:
start_time = time.time()
embeddings = []
# Ollama API currently handles one text at a time
for text in texts:
payload = {
"model": kwargs.get("model", self.model),
"prompt": text
}
try:
response = self.requests.post(
self.embeddings_url,
json=payload,
timeout=kwargs.get("timeout", 30)
)
response.raise_for_status()
data = response.json()
if "embedding" in data:
embeddings.append(data["embedding"])
else:
raise ValueError(f"No embedding in response: {data}")
except self.requests.exceptions.ConnectionError:
raise ConfigurationError(
f"Cannot connect to Ollama at {self.base_url}. "
"Make sure Ollama is running (ollama serve).",
provider="ollama"
)
except self.requests.exceptions.HTTPError as e:
if e.response.status_code == 404:
raise ValueError(
f"Model '{self.model}' not found. "
f"Pull it first with: ollama pull {self.model}"
)
raise
latency = time.time() - start_time
return self._create_response(
embeddings=embeddings,
latency=latency,
num_inputs=len(texts)
)
except Exception as e:
raise map_provider_error("ollama", e)
async def aembed(
self,
texts: Union[str, List[str]],
**kwargs: Any
) -> EmbeddingResponse:
"""Asynchronously generate embeddings for texts using Ollama.
Args:
texts: Single text or list of texts to embed
**kwargs: Additional parameters
Returns:
EmbeddingResponse
"""
if self.httpx is None:
raise ConfigurationError(
"httpx package not installed for async support. Install with: pip install httpx",
provider="ollama"
)
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
try:
start_time = time.time()
embeddings = []
async with self.httpx.AsyncClient() as client:
# Process texts one at a time (Ollama limitation)
for text in texts:
payload = {
"model": kwargs.get("model", self.model),
"prompt": text
}
try:
response = await client.post(
self.embeddings_url,
json=payload,
timeout=kwargs.get("timeout", 30)
)
response.raise_for_status()
data = response.json()
if "embedding" in data:
embeddings.append(data["embedding"])
else:
raise ValueError(f"No embedding in response: {data}")
except self.httpx.ConnectError:
raise ConfigurationError(
f"Cannot connect to Ollama at {self.base_url}. "
"Make sure Ollama is running (ollama serve).",
provider="ollama"
)
except self.httpx.HTTPStatusError as e:
if e.response.status_code == 404:
raise ValueError(
f"Model '{self.model}' not found. "
f"Pull it first with: ollama pull {self.model}"
)
raise
latency = time.time() - start_time
return self._create_response(
embeddings=embeddings,
latency=latency,
num_inputs=len(texts)
)
except Exception as e:
raise map_provider_error("ollama", e)
def _create_response(
self,
embeddings: List[List[float]],
latency: float,
num_inputs: int
) -> EmbeddingResponse:
"""Create embedding response."""
# Get dimension from first embedding
dimension = len(embeddings[0]) if embeddings else 0
# No cost for local models
cost = 0.0 if config.track_costs else None
return EmbeddingResponse(
embeddings=embeddings,
model=self.model,
provider="ollama",
latency=latency,
dimension=dimension,
num_inputs=num_inputs,
total_tokens=None, # Ollama doesn't provide token counts
cost=cost,
metadata={
"base_url": self.base_url,
"local": True
},
raw_response=None
)
# Factory function to create embedding instances
def create_embedding(
provider: str = "cohere",
model: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs: Any
) -> BaseEmbedding:
"""Create an embedding instance based on provider.
Args:
provider: Embedding provider (cohere, gemini)
model: Model to use (provider-specific)
api_key: API key (optional if set in environment)
**kwargs: Additional configuration
Returns:
BaseEmbedding instance
"""
provider = provider.lower()
if provider == "cohere":
return CohereEmbedding(
model=model or "embed-english-v3.0",
api_key=api_key,
**kwargs
)
elif provider in ["gemini", "google"]:
return GeminiEmbedding(
model=model or "text-embedding-004",
api_key=api_key,
**kwargs
)
elif provider == "ollama":
return OllamaEmbedding(
model=model or "nomic-embed-text:latest",
base_url=kwargs.pop("base_url", None), # Extract base_url from kwargs
**kwargs
)
else:
raise ValueError(f"Unknown embedding provider: {provider}")
# Convenience aliases
Embed = create_embedding
CohereEmbed = CohereEmbedding
GeminiEmbed = GeminiEmbedding
OllamaEmbed = OllamaEmbedding

172
router/exceptions.py Normal file
View file

@ -0,0 +1,172 @@
"""Custom exceptions for AI router."""
from typing import Optional, Any
class RouterError(Exception):
"""Base exception for all router errors."""
def __init__(
self,
message: str,
provider: Optional[str] = None,
original_error: Optional[Exception] = None,
**kwargs: Any
) -> None:
"""Initialize router error.
Args:
message: Error message
provider: Provider that raised the error
original_error: Original exception from provider
**kwargs: Additional error details
"""
super().__init__(message)
self.provider = provider
self.original_error = original_error
self.details = kwargs
def __str__(self) -> str:
"""String representation of error."""
base_msg = super().__str__()
if self.provider:
base_msg = f"[{self.provider}] {base_msg}"
if self.original_error:
base_msg += f" (Original: {self.original_error})"
return base_msg
class AuthenticationError(RouterError):
"""Raised when authentication fails."""
pass
class RateLimitError(RouterError):
"""Raised when rate limit is exceeded."""
def __init__(
self,
message: str,
retry_after: Optional[float] = None,
**kwargs: Any
) -> None:
"""Initialize rate limit error.
Args:
message: Error message
retry_after: Seconds to wait before retry
**kwargs: Additional error details
"""
super().__init__(message, **kwargs)
self.retry_after = retry_after
class ModelNotFoundError(RouterError):
"""Raised when requested model is not available."""
def __init__(
self,
message: str,
model: str,
available_models: Optional[list] = None,
**kwargs: Any
) -> None:
"""Initialize model not found error.
Args:
message: Error message
model: Requested model
available_models: List of available models
**kwargs: Additional error details
"""
super().__init__(message, **kwargs)
self.model = model
self.available_models = available_models or []
class InvalidRequestError(RouterError):
"""Raised when request parameters are invalid."""
pass
class TimeoutError(RouterError):
"""Raised when request times out."""
pass
class ContentFilterError(RouterError):
"""Raised when content is blocked by safety filters."""
pass
class QuotaExceededError(RouterError):
"""Raised when API quota is exceeded."""
pass
class ProviderError(RouterError):
"""Raised for provider-specific errors."""
pass
class ConfigurationError(RouterError):
"""Raised for configuration errors."""
pass
def map_provider_error(provider: str, error: Exception) -> RouterError:
"""Map provider-specific errors to router errors.
Args:
provider: Provider name
error: Original exception
Returns:
Appropriate RouterError subclass
"""
error_message = str(error)
error_type = type(error).__name__
# Common patterns across providers
if any(x in error_message.lower() for x in ["unauthorized", "invalid api key", "authentication"]):
return AuthenticationError(
f"Authentication failed: {error_message}",
provider=provider,
original_error=error
)
if any(x in error_message.lower() for x in ["rate limit", "too many requests", "quota exceeded"]):
return RateLimitError(
f"Rate limit exceeded: {error_message}",
provider=provider,
original_error=error
)
if any(x in error_message.lower() for x in ["model not found", "invalid model", "unknown model"]):
return ModelNotFoundError(
f"Model not found: {error_message}",
model="", # Extract from error if possible
provider=provider,
original_error=error
)
if any(x in error_message.lower() for x in ["timeout", "timed out"]):
return TimeoutError(
f"Request timed out: {error_message}",
provider=provider,
original_error=error
)
if any(x in error_message.lower() for x in ["content filter", "safety", "blocked"]):
return ContentFilterError(
f"Content blocked by safety filters: {error_message}",
provider=provider,
original_error=error
)
# Default to provider error
return ProviderError(
f"{error_type}: {error_message}",
provider=provider,
original_error=error
)

242
router/gemini.py Normal file
View file

@ -0,0 +1,242 @@
"""Google Gemini provider implementation."""
from typing import Any, Dict, Optional, AsyncGenerator, Generator
import time
from .base import AIRouter, RouterResponse
from .config import config
from .exceptions import (
AuthenticationError,
ConfigurationError,
map_provider_error
)
class Gemini(AIRouter):
"""Router for Google Gemini models."""
def __init__(
self,
model: str = "gemini-2.0-flash-001",
api_key: Optional[str] = None,
**kwargs: Any
) -> None:
"""Initialize Gemini router.
Args:
model: Gemini model to use (gemini-2.0-flash-001, gemini-1.5-pro, etc.)
api_key: Google API key (optional if set in environment)
**kwargs: Additional configuration
"""
# Get API key from config if not provided
if not api_key:
api_key = config.get_api_key("gemini")
if not api_key:
raise ConfigurationError(
"Gemini API key not found. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.",
provider="gemini"
)
super().__init__(model=model, api_key=api_key, **kwargs)
# Initialize Gemini client using new google.genai library
try:
from google import genai
from google.genai import types
self.genai = genai
self.types = types
self.client = genai.Client(api_key=api_key)
except ImportError:
raise ConfigurationError(
"google-genai package not installed. Install with: pip install google-genai",
provider="gemini"
)
def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
"""Prepare Gemini request parameters.
Args:
prompt: User prompt
**kwargs: Additional parameters
Returns:
Request parameters
"""
# Build config using new API structure
config_params = {}
if "temperature" in kwargs:
config_params["temperature"] = kwargs["temperature"]
elif hasattr(config, "default_temperature"):
config_params["temperature"] = config.default_temperature
if "max_tokens" in kwargs:
config_params["max_output_tokens"] = kwargs["max_tokens"]
elif "max_output_tokens" in kwargs:
config_params["max_output_tokens"] = kwargs["max_output_tokens"]
if "top_p" in kwargs:
config_params["top_p"] = kwargs["top_p"]
elif hasattr(config, "default_top_p"):
config_params["top_p"] = config.default_top_p
if "top_k" in kwargs:
config_params["top_k"] = kwargs["top_k"]
# Add safety settings if provided
if "safety_settings" in kwargs:
config_params["safety_settings"] = kwargs["safety_settings"]
return {
"model": kwargs.get("model", self.model),
"contents": prompt,
"config": self.types.GenerateContentConfig(**config_params) if config_params else None
}
def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
"""Parse Gemini response into RouterResponse.
Args:
raw_response: Raw Gemini response
latency: Request latency
Returns:
RouterResponse
"""
# Extract text content
content = raw_response.text if hasattr(raw_response, "text") else ""
# Try to get token counts from usage_metadata
input_tokens = None
output_tokens = None
total_tokens = None
if hasattr(raw_response, "usage_metadata"):
usage = raw_response.usage_metadata
# Handle both old and new attribute names
input_tokens = getattr(usage, "prompt_token_count", None) or getattr(usage, "cached_content_token_count", None)
output_tokens = getattr(usage, "candidates_token_count", None)
total_tokens = getattr(usage, "total_token_count", None)
# Calculate cost if available
cost = None
if input_tokens and output_tokens and config.track_costs:
cost = config.calculate_cost(self.model, input_tokens, output_tokens)
# Extract finish reason
finish_reason = "stop"
if hasattr(raw_response, "candidates") and raw_response.candidates:
candidate = raw_response.candidates[0]
finish_reason = getattr(candidate, "finish_reason", "stop")
return RouterResponse(
content=content,
model=raw_response.model_version if hasattr(raw_response, "model_version") else self.model,
provider="gemini",
latency=latency,
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens or ((input_tokens + output_tokens) if input_tokens and output_tokens else None),
cost=cost,
finish_reason=finish_reason,
metadata={
"prompt_feedback": getattr(raw_response, "prompt_feedback", None),
"safety_ratings": getattr(raw_response.candidates[0], "safety_ratings", None) if hasattr(raw_response, "candidates") and raw_response.candidates else None
},
raw_response=raw_response
)
def _make_request(self, request_params: Dict[str, Any]) -> Any:
"""Make synchronous request to Gemini.
Args:
request_params: Request parameters
Returns:
Raw Gemini response
"""
try:
response = self.client.models.generate_content(**request_params)
return response
except Exception as e:
raise map_provider_error("gemini", e)
async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
"""Make asynchronous request to Gemini.
Args:
request_params: Request parameters
Returns:
Raw Gemini response
"""
try:
response = await self.client.aio.models.generate_content(**request_params)
return response
except Exception as e:
raise map_provider_error("gemini", e)
def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
"""Stream responses from Gemini.
Args:
prompt: User prompt
**kwargs: Additional parameters
Yields:
RouterResponse chunks
"""
params = {**self.config, **kwargs}
request_params = self._prepare_request(prompt, **params)
try:
start_time = time.time()
stream = self.client.models.generate_content_stream(**request_params)
for chunk in stream:
if chunk.text:
yield RouterResponse(
content=chunk.text,
model=self.model,
provider="gemini",
latency=time.time() - start_time,
finish_reason=None,
metadata={},
raw_response=chunk
)
except Exception as e:
raise map_provider_error("gemini", e)
async def astream(
self, prompt: str, **kwargs: Any
) -> AsyncGenerator[RouterResponse, None]:
"""Asynchronously stream responses from Gemini.
Args:
prompt: User prompt
**kwargs: Additional parameters
Yields:
RouterResponse chunks
"""
params = {**self.config, **kwargs}
request_params = self._prepare_request(prompt, **params)
try:
start_time = time.time()
stream = await self.client.aio.models.generate_content_stream(**request_params)
async for chunk in stream:
if chunk.text:
yield RouterResponse(
content=chunk.text,
model=self.model,
provider="gemini",
latency=time.time() - start_time,
finish_reason=None,
metadata={},
raw_response=chunk
)
except Exception as e:
raise map_provider_error("gemini", e)

280
router/openai_compatible.py Normal file
View file

@ -0,0 +1,280 @@
"""OpenAI-compatible provider implementation."""
from typing import Any, Dict, Optional, AsyncGenerator, Generator
import time
from .base import AIRouter, RouterResponse
from .config import config
from .exceptions import (
ConfigurationError,
map_provider_error
)
class OpenAICompatible(AIRouter):
"""Router for OpenAI and compatible APIs (OpenAI, Azure OpenAI, etc.)."""
def __init__(
self,
model: str = "gpt-3.5-turbo",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
organization: Optional[str] = None,
**kwargs: Any
) -> None:
"""Initialize OpenAI-compatible router.
Args:
model: Model to use (gpt-4, gpt-3.5-turbo, etc.)
api_key: API key (optional if set in environment)
base_url: Base URL for API (for Azure or custom endpoints)
organization: Organization ID for OpenAI
**kwargs: Additional configuration
"""
# Get API key from config if not provided
if not api_key:
api_key = config.get_api_key("openai")
if not api_key:
raise ConfigurationError(
"OpenAI API key not found. Set OPENAI_API_KEY environment variable.",
provider="openai"
)
super().__init__(model=model, api_key=api_key, **kwargs)
# Store additional config
self.base_url = base_url
self.organization = organization
# Initialize OpenAI client
try:
from openai import OpenAI, AsyncOpenAI
# Build client kwargs with proper types
client_kwargs: Dict[str, Any] = {
"api_key": api_key,
}
if base_url:
client_kwargs["base_url"] = base_url
if organization:
client_kwargs["organization"] = organization
# Add any additional client configuration from kwargs
# Note: Only pass through valid OpenAI client parameters
valid_client_params = ["timeout", "max_retries", "default_headers", "default_query", "http_client"]
for param in valid_client_params:
if param in kwargs:
client_kwargs[param] = kwargs.pop(param)
self.client = OpenAI(**client_kwargs)
self.async_client = AsyncOpenAI(**client_kwargs)
except ImportError:
raise ConfigurationError(
"openai package not installed. Install with: pip install openai",
provider="openai"
)
def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
"""Prepare OpenAI request parameters.
Args:
prompt: User prompt
**kwargs: Additional parameters
Returns:
Request parameters
"""
# Build messages
messages = kwargs.get("messages", [
{"role": "user", "content": prompt}
])
# Build request parameters
params = {
"model": kwargs.get("model", self.model),
"messages": messages,
}
# Add optional parameters
if "temperature" in kwargs:
params["temperature"] = kwargs["temperature"]
elif hasattr(config, "default_temperature"):
params["temperature"] = config.default_temperature
if "max_tokens" in kwargs:
params["max_tokens"] = kwargs["max_tokens"]
elif hasattr(config, "default_max_tokens") and config.default_max_tokens:
params["max_tokens"] = config.default_max_tokens
if "top_p" in kwargs:
params["top_p"] = kwargs["top_p"]
elif hasattr(config, "default_top_p"):
params["top_p"] = config.default_top_p
# Other OpenAI-specific parameters
for key in ["n", "stop", "presence_penalty", "frequency_penalty", "logit_bias", "user", "seed", "tools", "tool_choice", "response_format", "logprobs", "top_logprobs", "parallel_tool_calls"]:
if key in kwargs:
params[key] = kwargs[key]
return params
def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
"""Parse OpenAI response into RouterResponse.
Args:
raw_response: Raw OpenAI response
latency: Request latency
Returns:
RouterResponse
"""
# Extract first choice
choice = raw_response.choices[0]
content = choice.message.content or ""
# Get token counts
usage = raw_response.usage
input_tokens = usage.prompt_tokens if usage else None
output_tokens = usage.completion_tokens if usage else None
total_tokens = usage.total_tokens if usage else None
# Calculate cost if available
cost = None
if input_tokens and output_tokens and config.track_costs:
cost = config.calculate_cost(raw_response.model, input_tokens, output_tokens)
return RouterResponse(
content=content,
model=raw_response.model,
provider="openai",
latency=latency,
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
cost=cost,
finish_reason=choice.finish_reason,
metadata={
"id": raw_response.id,
"created": raw_response.created,
"system_fingerprint": getattr(raw_response, "system_fingerprint", None),
"tool_calls": getattr(choice.message, "tool_calls", None),
"function_call": getattr(choice.message, "function_call", None),
"logprobs": getattr(choice, "logprobs", None),
},
raw_response=raw_response
)
def _make_request(self, request_params: Dict[str, Any]) -> Any:
"""Make synchronous request to OpenAI.
Args:
request_params: Request parameters
Returns:
Raw OpenAI response
"""
try:
response = self.client.chat.completions.create(**request_params)
return response
except Exception as e:
raise map_provider_error("openai", e)
async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
"""Make asynchronous request to OpenAI.
Args:
request_params: Request parameters
Returns:
Raw OpenAI response
"""
try:
response = await self.async_client.chat.completions.create(**request_params)
return response
except Exception as e:
raise map_provider_error("openai", e)
def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
"""Stream responses from OpenAI.
Args:
prompt: User prompt
**kwargs: Additional parameters
Yields:
RouterResponse chunks
"""
params = {**self.config, **kwargs}
request_params = self._prepare_request(prompt, **params)
request_params["stream"] = True
try:
start_time = time.time()
stream = self.client.chat.completions.create(**request_params)
for chunk in stream:
if chunk.choices and len(chunk.choices) > 0:
choice = chunk.choices[0]
content = getattr(choice.delta, "content", None)
if content:
yield RouterResponse(
content=content,
model=chunk.model,
provider="openai",
latency=time.time() - start_time,
finish_reason=getattr(choice, "finish_reason", None),
metadata={
"chunk_id": chunk.id,
"tool_calls": getattr(choice.delta, "tool_calls", None),
"function_call": getattr(choice.delta, "function_call", None),
},
raw_response=chunk
)
except Exception as e:
raise map_provider_error("openai", e)
async def astream(
self, prompt: str, **kwargs: Any
) -> AsyncGenerator[RouterResponse, None]:
"""Asynchronously stream responses from OpenAI.
Args:
prompt: User prompt
**kwargs: Additional parameters
Yields:
RouterResponse chunks
"""
params = {**self.config, **kwargs}
request_params = self._prepare_request(prompt, **params)
request_params["stream"] = True
try:
start_time = time.time()
stream = await self.async_client.chat.completions.create(**request_params)
async for chunk in stream:
if chunk.choices and len(chunk.choices) > 0:
choice = chunk.choices[0]
content = getattr(choice.delta, "content", None)
if content:
yield RouterResponse(
content=content,
model=chunk.model,
provider="openai",
latency=time.time() - start_time,
finish_reason=getattr(choice, "finish_reason", None),
metadata={
"chunk_id": chunk.id,
"tool_calls": getattr(choice.delta, "tool_calls", None),
"function_call": getattr(choice.delta, "function_call", None),
},
raw_response=chunk
)
except Exception as e:
raise map_provider_error("openai", e)
# Convenience alias
OpenAI = OpenAICompatible

245
router/rerank.py Normal file
View file

@ -0,0 +1,245 @@
"""Reranking router implementation for Cohere Rerank models."""
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, field
import time
from .config import config
from .exceptions import (
ConfigurationError,
map_provider_error
)
@dataclass
class RerankDocument:
"""Document to be reranked."""
text: str
id: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class RerankResult:
"""Single reranked result."""
index: int # Original index in the documents list
relevance_score: float
document: RerankDocument
@dataclass
class RerankResponse:
"""Response from reranking operation."""
results: List[RerankResult]
model: str
provider: str
latency: float
num_documents: int
cost: Optional[float] = None
metadata: Dict[str, Any] = field(default_factory=dict)
raw_response: Optional[Any] = None
class CohereRerank:
"""Router for Cohere Rerank models."""
def __init__(
self,
model: str = "rerank-english-v3.0",
api_key: Optional[str] = None,
**kwargs: Any
) -> None:
"""Initialize Cohere Rerank router.
Args:
model: Rerank model to use (rerank-3.5, rerank-english-v3.0, rerank-multilingual-v3.0)
api_key: Cohere API key (optional if set in environment)
**kwargs: Additional configuration
"""
# Get API key from config if not provided
if not api_key:
api_key = config.get_api_key("cohere")
if not api_key:
raise ConfigurationError(
"Cohere API key not found. Set COHERE_API_KEY environment variable.",
provider="cohere"
)
self.model = model
self.api_key = api_key
self.config = kwargs
# Initialize Cohere client
try:
import cohere
# For v5.15.0, use the standard Client
self.client = cohere.Client(api_key)
self.async_client = cohere.AsyncClient(api_key)
except ImportError:
raise ConfigurationError(
"cohere package not installed. Install with: pip install cohere",
provider="cohere"
)
def rerank(
self,
query: str,
documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
top_n: Optional[int] = None,
**kwargs: Any
) -> RerankResponse:
"""Rerank documents based on relevance to query.
Args:
query: The search query
documents: List of documents to rerank (strings, dicts, or RerankDocument objects)
top_n: Number of top results to return (None returns all)
**kwargs: Additional parameters (max_chunks_per_doc, return_documents, etc.)
Returns:
RerankResponse with reranked results
"""
params = self._prepare_request(query, documents, top_n, **kwargs)
try:
start_time = time.time()
response = self.client.rerank(**params)
latency = time.time() - start_time
return self._parse_response(response, latency, len(params["documents"]))
except Exception as e:
raise map_provider_error("cohere", e)
async def arerank(
self,
query: str,
documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
top_n: Optional[int] = None,
**kwargs: Any
) -> RerankResponse:
"""Asynchronously rerank documents based on relevance to query.
Args:
query: The search query
documents: List of documents to rerank
top_n: Number of top results to return (None returns all)
**kwargs: Additional parameters
Returns:
RerankResponse with reranked results
"""
params = self._prepare_request(query, documents, top_n, **kwargs)
try:
start_time = time.time()
response = await self.async_client.rerank(**params)
latency = time.time() - start_time
return self._parse_response(response, latency, len(params["documents"]))
except Exception as e:
raise map_provider_error("cohere", e)
def _prepare_request(
self,
query: str,
documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
top_n: Optional[int] = None,
**kwargs: Any
) -> Dict[str, Any]:
"""Prepare rerank request parameters.
Args:
query: The search query
documents: Documents to rerank
top_n: Number of results to return
**kwargs: Additional parameters
Returns:
Request parameters
"""
# Convert documents to the format expected by Cohere
formatted_docs = []
for i, doc in enumerate(documents):
if isinstance(doc, str):
formatted_docs.append({"text": doc})
elif isinstance(doc, RerankDocument):
formatted_docs.append({"text": doc.text})
elif isinstance(doc, dict):
# Assume dict has at least 'text' field
formatted_docs.append(doc)
else:
raise ValueError(f"Invalid document type at index {i}: {type(doc)}")
# Build request parameters
params = {
"model": kwargs.get("model", self.model),
"query": query,
"documents": formatted_docs,
}
if top_n is not None:
params["top_n"] = top_n
# Add optional parameters for v5.15.0
for key in ["max_chunks_per_doc", "return_documents", "rank_fields"]:
if key in kwargs:
params[key] = kwargs[key]
return params
def _parse_response(
self,
raw_response: Any,
latency: float,
num_documents: int
) -> RerankResponse:
"""Parse Cohere rerank response.
Args:
raw_response: Raw response from Cohere
latency: Request latency
num_documents: Total number of documents submitted
Returns:
RerankResponse
"""
results = []
# Parse results from v5.15.0 response format
if hasattr(raw_response, "results"):
for result in raw_response.results:
# Extract document info
doc_text = ""
if hasattr(result, "document") and hasattr(result.document, "text"):
doc_text = result.document.text
results.append(RerankResult(
index=result.index,
relevance_score=result.relevance_score,
document=RerankDocument(text=doc_text)
))
# Calculate cost
cost = None
if config.track_costs:
# Reranking is charged per search (1 search = 1 query across N documents)
cost = config.calculate_rerank_cost(self.model, 1)
return RerankResponse(
results=results,
model=self.model,
provider="cohere",
latency=latency,
num_documents=num_documents,
cost=cost,
metadata={
"id": getattr(raw_response, "id", None),
"meta": getattr(raw_response, "meta", None),
},
raw_response=raw_response
)
# Convenience alias
Rerank = CohereRerank