Initial commit
This commit is contained in:
commit
9961cb55a6
12 changed files with 3415 additions and 0 deletions
61
.gitignore
vendored
Normal file
61
.gitignore
vendored
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
.venv/
|
||||||
|
|
||||||
|
.env/
|
||||||
|
|
||||||
|
# Python cache files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
logs
|
||||||
|
*.log
|
||||||
|
npm-debug.log*
|
||||||
|
yarn-debug.log*
|
||||||
|
yarn-error.log*
|
||||||
|
dev-debug.log
|
||||||
|
|
||||||
|
|
||||||
|
.pytest_cache/
|
||||||
|
# Dependency directories
|
||||||
|
node_modules/
|
||||||
|
|
||||||
|
# Environment variables
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
|
||||||
|
# Editor directories and files
|
||||||
|
.idea
|
||||||
|
.vscode
|
||||||
|
*.suo
|
||||||
|
*.ntvs*
|
||||||
|
*.njsproj
|
||||||
|
*.sln
|
||||||
|
*.sw?
|
||||||
|
|
||||||
|
# OS specific
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
*.pyc
|
||||||
|
# Task files
|
||||||
|
# tasks.json
|
||||||
|
# tasks/
|
||||||
|
__pycache__/
|
||||||
|
*/__pycache__/
|
||||||
|
src/__pycache__/
|
||||||
|
tests/unit/__pycache__/
|
||||||
|
|
||||||
|
# Build artifacts
|
||||||
|
src/tinyrag.egg-info/
|
||||||
|
|
||||||
|
# Review files
|
||||||
|
review.md
|
||||||
|
chroma_db_poc/
|
||||||
|
tree-sitter-rescript/
|
259
docs/research/RAG.md
Normal file
259
docs/research/RAG.md
Normal file
|
@ -0,0 +1,259 @@
|
||||||
|
# 😎 Awesome Retrieval Augmented Generation (RAG) [](https://awesome.re)
|
||||||
|
|
||||||
|
This repository contains a curated [Awesome List](https://github.com/sindresorhus/awesome) and general information on Retrieval-Augmented Generation (RAG) applications in Generative AI.
|
||||||
|
|
||||||
|
Retrieval-Augmented Generation (RAG) is a technique in Generative AI where additional context is retrieved from external sources to enrich the generative process of Large Language Models (LLMs). This approach allows LLMs to incorporate up-to-date, specific, or sensitive information that they may lack from their pre-training data alone.
|
||||||
|
|
||||||
|
## Content
|
||||||
|
|
||||||
|
- [ℹ️ General Information on RAG](#ℹ%EF%B8%8F-general-information-on-rag)
|
||||||
|
- [🎯 Approaches](#-approaches)
|
||||||
|
- [🧰 Frameworks that Facilitate RAG](#-frameworks-that-facilitate-rag)
|
||||||
|
- [🛠️ Techniques](#-techniques)
|
||||||
|
- [📊 Metrics](#-metrics)
|
||||||
|
- [💾 Databases](#-databases)
|
||||||
|
|
||||||
|
## ℹ️ General Information on RAG
|
||||||
|
|
||||||
|
In traditional RAG approaches, a basic framework is employed to retrieve documents that enrich the context of an LLM prompt. For instance, when querying about materials for renovating a house, the LLM may possess general knowledge about renovation but lacks specific details about the particular house. Implementing an RAG architecture allows for quick searching and retrieval of relevant documents, such as blueprints, to offer more customized responses. This ensures that the LLM incorporates specific information to the renovation needs, thereby enhancing the accuracy of its responses.
|
||||||
|
|
||||||
|
**A typical RAG implementation follows these key steps:**
|
||||||
|
|
||||||
|
1. **Divide the knowledge base:** Break the document corpus into smaller, manageable chunks.
|
||||||
|
2. **Create embeddings:** Apply an embedding model to transform these text chunks into vector embeddings, capturing their semantic meaning.
|
||||||
|
3. **Store in a vector database:** Save the embeddings in a vector database, enabling fast retrieval based on semantic similarity.
|
||||||
|
4. **Handle user queries:** Convert the user's query into an embedding using the same model that was applied to the text chunks.
|
||||||
|
5. **Retrieve relevant data:** Search the vector database for embeddings that closely match the query’s embedding based on semantic similarity.
|
||||||
|
6. **Enhance the prompt:** Incorporate the most relevant text chunks into the LLM’s prompt to provide valuable context for generating a response.
|
||||||
|
7. **Generate a response:** The LLM leverages the augmented prompt to deliver a response that is accurate and tailored to the user’s query.
|
||||||
|
|
||||||
|
## 🎯 Approaches
|
||||||
|
|
||||||
|
RAG implementations vary in complexity, from simple document retrieval to advanced techniques integrating iterative feedback loops and domain-specific enhancements. Approaches may include:
|
||||||
|
|
||||||
|
- [Cache-Augmented Generation (CAG)](https://medium.com/@ronantech/cache-augmented-generation-cag-in-llms-a-step-by-step-tutorial-6ac35d415eec): Preloads relevant documents into a model’s context and stores the inference state (Key-Value (KV) cache).
|
||||||
|
- [Agentic RAG](https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag/): Also known as retrieval agents, can make decisions on retrieval processes.
|
||||||
|
- [Corrective RAG](https://arxiv.org/pdf/2401.15884.pdf) (CRAG): Methods to correct or refine the retrieved information before integration into LLM responses.
|
||||||
|
- [Retrieval-Augmented Fine-Tuning](https://techcommunity.microsoft.com/t5/ai-ai-platform-blog/raft-a-new-way-to-teach-llms-to-be-better-at-rag/ba-p/4084674) (RAFT): Techniques to fine-tune LLMs specifically for enhanced retrieval and generation tasks.
|
||||||
|
- [Self Reflective RAG](https://selfrag.github.io/): Models that dynamically adjust retrieval strategies based on model performance feedback.
|
||||||
|
- [RAG Fusion](https://arxiv.org/abs/2402.03367): Techniques combining multiple retrieval methods for improved context integration.
|
||||||
|
- [Temporal Augmented Retrieval](https://adam-rida.medium.com/temporal-augmented-retrieval-tar-dynamic-rag-ad737506dfcc) (TAR): Considering time-sensitive data in retrieval processes.
|
||||||
|
- [Plan-then-RAG](https://arxiv.org/abs/2406.12430) (PlanRAG): Strategies involving planning stages before executing RAG for complex tasks.
|
||||||
|
- [GraphRAG](https://github.com/microsoft/graphrag): A structured approach using knowledge graphs for enhanced context integration and reasoning.
|
||||||
|
- [FLARE](https://medium.com/etoai/better-rag-with-active-retrieval-augmented-generation-flare-3b66646e2a9f) - An approach that incorporates active retrieval-augmented generation to improve response quality.
|
||||||
|
- [Contextual Retrieval](https://www.anthropic.com/news/contextual-retrieval) - Improves retrieval by adding relevant context to document chunks before retrieval, enhancing the relevance of information retrieved from large knowledge bases.
|
||||||
|
- [GNN-RAG](https://github.com/cmavro/GNN-RAG): Graph neural retrieval for large language modeling reasoning.
|
||||||
|
|
||||||
|
## 🧰 Frameworks that Facilitate RAG
|
||||||
|
|
||||||
|
- [Haystack](https://github.com/deepset-ai/haystack): LLM orchestration framework to build customizable, production-ready LLM applications.
|
||||||
|
- [LangChain](https://python.langchain.com/docs/modules/data_connection/): An all-purpose framework for working with LLMs.
|
||||||
|
- [Semantic Kernel](https://github.com/microsoft/semantic-kernel): An SDK from Microsoft for developing Generative AI applications.
|
||||||
|
- [LlamaIndex](https://docs.llamaindex.ai/en/stable/optimizing/production_rag/): Framework for connecting custom data sources to LLMs.
|
||||||
|
- [Dify](https://github.com/langgenius/dify): An open-source LLM app development platform.
|
||||||
|
- [Cognita](https://github.com/truefoundry/cognita): Open-source RAG framework for building modular and production ready applications.
|
||||||
|
- [Verba](https://github.com/weaviate/Verba): Open-source application for RAG out of the box.
|
||||||
|
- [Mastra](https://github.com/mastra-ai/mastra): Typescript framework for building AI applications.
|
||||||
|
- [Letta](https://github.com/letta-ai/letta): Open source framework for building stateful LLM applications.
|
||||||
|
- [Flowise](https://github.com/FlowiseAI/Flowise): Drag & drop UI to build customized LLM flows.
|
||||||
|
- [Swiftide](https://github.com/bosun-ai/swiftide): Rust framework for building modular, streaming LLM applications.
|
||||||
|
- [CocoIndex](https://github.com/cocoindex-io/cocoindex): ETL framework to index data for AI, such as RAG; with realtime incremental updates.
|
||||||
|
|
||||||
|
## 🛠️ Techniques
|
||||||
|
|
||||||
|
### Data cleaning
|
||||||
|
|
||||||
|
- [Data cleaning techniques](https://medium.com/intel-tech/four-data-cleaning-techniques-to-improve-large-language-model-llm-performance-77bee9003625): Pre-processing steps to refine input data and improve model performance.
|
||||||
|
|
||||||
|
### Prompting
|
||||||
|
|
||||||
|
- **Strategies**
|
||||||
|
- [Tagging and Labeling](https://python.langchain.com/v0.1/docs/use_cases/tagging/): Adding semantic tags or labels to retrieved data to enhance relevance.
|
||||||
|
- [Chain of Thought (CoT)](https://www.promptingguide.ai/techniques/cot): Encouraging the model to think through problems step by step before providing an answer.
|
||||||
|
- [Chain of Verification (CoVe)](https://sourajit16-02-93.medium.com/chain-of-verification-cove-understanding-implementation-e7338c7f4cb5): Prompting the model to verify each step of its reasoning for accuracy.
|
||||||
|
- [Self-Consistency](https://www.promptingguide.ai/techniques/consistency): Generating multiple reasoning paths and selecting the most consistent answer.
|
||||||
|
- [Zero-Shot Prompting](https://www.promptingguide.ai/techniques/zeroshot): Designing prompts that guide the model without any examples.
|
||||||
|
- [Few-Shot Prompting](https://python.langchain.com/docs/how_to/few_shot_examples/): Providing a few examples in the prompt to demonstrate the desired response format.
|
||||||
|
- [Reason & Act (ReAct) prompting](https://www.promptingguide.ai/techniques/react): Combines reasoning (e.g. CoT) with acting (e.g. tool calling).
|
||||||
|
- **Caching**
|
||||||
|
- [Prompt Caching](https://medium.com/@1kg/prompt-cache-what-is-prompt-caching-a-comprehensive-guide-e6cbae48e6a3): Optimizes LLMs by storing and reusing precomputed attention states.
|
||||||
|
|
||||||
|
### Chunking
|
||||||
|
|
||||||
|
- **[Fixed-size chunking](https://medium.com/@anuragmishra_27746/five-levels-of-chunking-strategies-in-rag-notes-from-gregs-video-7b735895694d)**
|
||||||
|
- Dividing text into consistent-sized segments for efficient processing.
|
||||||
|
- Splits texts into chunks based on size and overlap.
|
||||||
|
- Example: [Split by character](https://python.langchain.com/v0.1/docs/modules/data_connection/document_transformers/character_text_splitter/) (LangChain).
|
||||||
|
- Example: [SentenceSplitter](https://docs.llamaindex.ai/en/stable/api_reference/node_parsers/sentence_splitter/) (LlamaIndex).
|
||||||
|
- **[Recursive chunking](https://medium.com/@AbhiramiVS/chunking-methods-all-to-know-about-it-65c10aa7b24e)**
|
||||||
|
- Hierarchical segmentation using recursive algorithms for complex document structures.
|
||||||
|
- Example: [Recursively split by character](https://python.langchain.com/v0.1/docs/modules/data_connection/document_transformers/recursive_text_splitter/) (LangChain).
|
||||||
|
- **[Document-based chunking](https://medium.com/@david.richards.tech/document-chunking-for-rag-ai-applications-04363d48fbf7)**
|
||||||
|
- Segmenting documents based on metadata or formatting cues for targeted analysis.
|
||||||
|
- Example: [MarkdownHeaderTextSplitter](https://python.langchain.com/v0.1/docs/modules/data_connection/document_transformers/markdown_header_metadata/) (LangChain).
|
||||||
|
- Example: Handle image and text embeddings with models like [OpenCLIP](https://github.com/mlfoundations/open_clip).
|
||||||
|
- **[Semantic chunking](https://www.youtube.com/watch?v=8OJC21T2SL4&t=1933s)**
|
||||||
|
- Extracting meaningful sections based on semantic relevance rather than arbitrary boundaries.
|
||||||
|
- **[Agentic chunking](https://youtu.be/8OJC21T2SL4?si=8VnYaGUaBmtZhCsg&t=2882)**
|
||||||
|
- Interactive chunking methods where LLMs guide segmentation.
|
||||||
|
|
||||||
|
### Embeddings
|
||||||
|
|
||||||
|
- **Select embedding model**
|
||||||
|
- **[MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard)**: Explore [Hugging Face's](https://github.com/huggingface) benchmark for evaluating model embeddings.
|
||||||
|
- **Custom Embeddings**: Develop tailored embeddings for specific domains or tasks to enhance model performance. Custom embeddings can capture domain-specific terminology and nuances. Techniques include fine-tuning pre-trained models on your own dataset or training embeddings from scratch using frameworks like TensorFlow or PyTorch.
|
||||||
|
|
||||||
|
### Retrieval
|
||||||
|
|
||||||
|
- **Search Methods**
|
||||||
|
- [Vector Store Flat Index](https://weaviate.io/developers/academy/py/vector_index/flat)
|
||||||
|
- Simple and efficient form of retrieval.
|
||||||
|
- Content is vectorized and stored as flat content vectors.
|
||||||
|
- [Hierarchical Index Retrieval](https://pixion.co/blog/rag-strategies-hierarchical-index-retrieval)
|
||||||
|
- Hierarchically narrow data to different levels.
|
||||||
|
- Executes retrievals by hierarchical order.
|
||||||
|
- [Hypothetical Questions](https://pixion.co/blog/rag-strategies-hypothetical-questions-hyde)
|
||||||
|
- Used to increase similarity between database chunks and queries (same with HyDE).
|
||||||
|
- LLM is used to generate specific questions for each text chunk.
|
||||||
|
- Converts these questions into vector embeddings.
|
||||||
|
- During search, matches queries against this index of question vectors.
|
||||||
|
- [Hypothetical Document Embeddings (HyDE)](https://pixion.co/blog/rag-strategies-hypothetical-questions-hyde)
|
||||||
|
- Used to increase similarity between database chunks and queries (same with Hypothetical Questions).
|
||||||
|
- LLM is used to generate a hypothetical response based on the query.
|
||||||
|
- Converts this response into a vector embedding.
|
||||||
|
- Compares the query vector with the hypothetical response vector.
|
||||||
|
- [Small to Big Retrieval](https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/retrieval-augmented-generation/small_to_big_rag/small_to_big_rag.ipynb)
|
||||||
|
- Improves retrieval by using smaller chunks for search and larger chunks for context.
|
||||||
|
- Smaller child chunks refers to bigger parent chunks
|
||||||
|
- **[Re-ranking](https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/)**: Enhances search results in RAG pipelines by reordering initially retrieved documents, prioritizing those most semantically relevant to the query.
|
||||||
|
|
||||||
|
### Response quality & safety
|
||||||
|
|
||||||
|
- **[Hallucination](https://machinelearningmastery.com/rag-hallucination-detection-techniques/):** When an AI model generates incorrect or fabricated information, which can be mitigated through grounding, refined retrieval, and verification techniques.
|
||||||
|
- **[Guardrails](https://developer.ibm.com/tutorials/awb-how-to-implement-llm-guardrails-for-rag-applications/):** Mechanisms to ensure accurate, ethical, and safe responses by applying content moderation, bias mitigation, and fact-checking.
|
||||||
|
- **[Prompt Injection Prevention](https://hiddenlayer.com/innovation-hub/prompt-injection-attacks-on-llms/):**
|
||||||
|
- **Input Validation:** Rigorously validate and sanitize all external inputs to ensure that only intended data is incorporated into the prompt.
|
||||||
|
- **Content Separation:** Clearly distinguish between trusted, static instructions and dynamic user data using templating or placeholders.
|
||||||
|
- **Output Monitoring:** Continuously monitor responses and logs for any anomalies that could indicate prompt manipulation, and adjust guardrails accordingly.
|
||||||
|
|
||||||
|
## 📊 Metrics
|
||||||
|
|
||||||
|
### Search metrics
|
||||||
|
|
||||||
|
These metrics are used to measure the similarity between embeddings, which is crucial for evaluating how effectively RAG systems retrieve and integrate external documents or data sources. By selecting appropriate similarity metrics, you can optimize the performance and accuracy of your RAG system. Alternatively, you may develop custom metrics tailored to your specific domain or niche to capture domain-specific nuances and improve relevance.
|
||||||
|
|
||||||
|
- **[Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity)**
|
||||||
|
|
||||||
|
- Measures the cosine of the angle between two vectors in a multi-dimensional space.
|
||||||
|
- Highly effective for comparing text embeddings where the direction of the vectors represents semantic information.
|
||||||
|
- Commonly used in RAG systems to measure semantic similarity between query embeddings and document embeddings.
|
||||||
|
|
||||||
|
- **[Dot Product](https://en.wikipedia.org/wiki/Dot_product)**
|
||||||
|
|
||||||
|
- Calculates the sum of the products of corresponding entries of two sequences of numbers.
|
||||||
|
- Equivalent to cosine similarity when vectors are normalized.
|
||||||
|
- Simple and efficient, often used with hardware acceleration for large-scale computations.
|
||||||
|
|
||||||
|
- **[Euclidean Distance](https://en.wikipedia.org/wiki/Euclidean_distance)**
|
||||||
|
|
||||||
|
- Computes the straight-line distance between two points in Euclidean space.
|
||||||
|
- Can be used with embeddings but may lose effectiveness in high-dimensional spaces due to the "[curse of dimensionality](https://stats.stackexchange.com/questions/99171/why-is-euclidean-distance-not-a-good-metric-in-high-dimensions)."
|
||||||
|
- Often used in clustering algorithms like K-means after dimensionality reduction.
|
||||||
|
|
||||||
|
- **[Jaccard Similarity](https://en.wikipedia.org/wiki/Jaccard_index)**
|
||||||
|
- Measures the similarity between two finite sets as the size of the intersection divided by the size of the union of the sets.
|
||||||
|
- Useful when comparing sets of tokens, such as in bag-of-words models or n-gram comparisons.
|
||||||
|
- Less applicable to continuous embeddings produced by LLMs.
|
||||||
|
|
||||||
|
> **Note:** Cosine Similarity and Dot Product are generally seen as the most effective metrics for measuring similarity between high-dimensional embeddings.
|
||||||
|
|
||||||
|
### Response Evaluation Metrics
|
||||||
|
|
||||||
|
Response evaluation in RAG solutions involves assessing the quality of language model outputs using diverse metrics. Here are structured approaches to evaluating these responses:
|
||||||
|
|
||||||
|
- **Automated Benchmarking**
|
||||||
|
|
||||||
|
- **[BLEU](https://en.wikipedia.org/wiki/BLEU):** Evaluates the overlap of n-grams between machine-generated and reference outputs, providing insight into precision.
|
||||||
|
- **[ROUGE](<https://en.wikipedia.org/wiki/ROUGE_(metric)>):** Measures recall by comparing n-grams, skip-bigrams, or longest common subsequence with reference outputs.
|
||||||
|
- **[METEOR](https://en.wikipedia.org/wiki/METEOR):** Focuses on exact matches, stemming, synonyms, and alignment for machine translation.
|
||||||
|
|
||||||
|
- **Human Evaluation**
|
||||||
|
Involves human judges assessing responses for:
|
||||||
|
|
||||||
|
- **Relevance:** Alignment with user queries.
|
||||||
|
- **Fluency:** Grammatical and stylistic quality.
|
||||||
|
- **Factual Accuracy:** Verifying claims against authoritative sources.
|
||||||
|
- **Coherence:** Logical consistency within responses.
|
||||||
|
|
||||||
|
- **Model Evaluation**
|
||||||
|
Leverages pre-trained evaluators to benchmark outputs against diverse criteria:
|
||||||
|
|
||||||
|
- **[TuringBench](https://turingbench.ist.psu.edu/):** Offers comprehensive evaluations across language benchmarks.
|
||||||
|
- **[Hugging Face Evaluate](https://huggingface.co/docs/evaluate/en/index):** Calculates alignment with human preferences.
|
||||||
|
|
||||||
|
- **Key Dimensions for Evaluation**
|
||||||
|
- **Groundedness:** Assesses if responses are based entirely on provided context. Low groundedness may indicate reliance on hallucinated or irrelevant information.
|
||||||
|
- **Completeness:** Measures if the response answers all aspects of a query.
|
||||||
|
- **Approaches:** AI-assisted retrieval scoring and prompt-based intent verification.
|
||||||
|
- **Utilization:** Evaluates the extent to which retrieved data contributes to the response.
|
||||||
|
- **Analysis:** Use LLMs to check the inclusion of retrieved chunks in responses.
|
||||||
|
|
||||||
|
#### Tools
|
||||||
|
|
||||||
|
These tools can assist in evaluating the performance of your RAG system, from tracking user feedback to logging query interactions and comparing multiple evaluation metrics over time.
|
||||||
|
|
||||||
|
- **[LangFuse](https://github.com/langfuse/langfuse)**: Open-source tool for tracking LLM metrics, observability, and prompt management.
|
||||||
|
- **[Ragas](https://docs.ragas.io/en/stable/)**: Framework that helps evaluate RAG pipelines.
|
||||||
|
- **[LangSmith](https://docs.smith.langchain.com/)**: A platform for building production-grade LLM applications, allows you to closely monitor and evaluate your application.
|
||||||
|
- **[Hugging Face Evaluate](https://github.com/huggingface/evaluate)**: Tool for computing metrics like BLEU and ROUGE to assess text quality.
|
||||||
|
- **[Weights & Biases](https://wandb.ai/wandb-japan/rag-hands-on/reports/Step-for-developing-and-evaluating-RAG-application-with-W-B--Vmlldzo1NzU4OTAx)**: Tracks experiments, logs metrics, and visualizes performance.
|
||||||
|
|
||||||
|
## 💾 Databases
|
||||||
|
|
||||||
|
The list below features several database systems suitable for Retrieval Augmented Generation (RAG) applications. They cover a range of RAG use cases, aiding in the efficient storage and retrieval of vectors to generate responses or recommendations.
|
||||||
|
|
||||||
|
### Benchmarks
|
||||||
|
|
||||||
|
- [Picking a vector database](https://benchmark.vectorview.ai/vectordbs.html)
|
||||||
|
|
||||||
|
### Distributed Data Processing and Serving Engines:
|
||||||
|
|
||||||
|
- [Apache Cassandra](https://cassandra.apache.org/doc/latest/cassandra/vector-search/concepts.html): Distributed NoSQL database management system.
|
||||||
|
- [MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search): Globally distributed, multi-model database service with integrated vector search.
|
||||||
|
- [Vespa](https://vespa.ai/): Open-source big data processing and serving engine designed for real-time applications.
|
||||||
|
|
||||||
|
### Search Engines with Vector Capabilities:
|
||||||
|
|
||||||
|
- [Elasticsearch](https://www.elastic.co/elasticsearch): Provides vector search capabilities along with traditional search functionalities.
|
||||||
|
- [OpenSearch](https://github.com/opensearch-project/OpenSearch): Distributed search and analytics engine, forked from Elasticsearch.
|
||||||
|
|
||||||
|
### Vector Databases:
|
||||||
|
|
||||||
|
- [Chroma DB](https://github.com/chroma-core/chroma): An AI-native open-source embedding database.
|
||||||
|
- [Milvus](https://github.com/milvus-io/milvus): An open-source vector database for AI-powered applications.
|
||||||
|
- [Pinecone](https://www.pinecone.io/): A serverless vector database, optimized for machine learning workflows.
|
||||||
|
- [Oracle AI Vector Search](https://www.oracle.com/database/ai-vector-search/#retrieval-augmented-generation): Integrates vector search capabilities within Oracle Database for semantic querying based on vector embeddings.
|
||||||
|
|
||||||
|
### Relational Database Extensions:
|
||||||
|
|
||||||
|
- [Pgvector](https://github.com/pgvector/pgvector): An open-source extension for vector similarity search in PostgreSQL.
|
||||||
|
|
||||||
|
### Other Database Systems:
|
||||||
|
|
||||||
|
- [Azure Cosmos DB](https://learn.microsoft.com/en-us/azure/cosmos-db/vector-database): Globally distributed, multi-model database service with integrated vector search.
|
||||||
|
- [Couchbase](https://www.couchbase.com/products/vector-search/): A distributed NoSQL cloud database.
|
||||||
|
- [Lantern](https://lantern.dev/): A privacy-aware personal search engine.
|
||||||
|
- [LlamaIndex](https://docs.llamaindex.ai/en/stable/module_guides/storing/vector_stores/): Employs a straightforward in-memory vector store for rapid experimentation.
|
||||||
|
- [Neo4j](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/vector-indexes/): Graph database management system.
|
||||||
|
- [Qdrant](https://github.com/neo4j/neo4j): An open-source vector database designed for similarity search.
|
||||||
|
- [Redis Stack](https://redis.io/docs/latest/develop/interact/search-and-query/): An in-memory data structure store used as a database, cache, and message broker.
|
||||||
|
- [SurrealDB](https://github.com/surrealdb/surrealdb): A scalable multi-model database optimized for time-series data.
|
||||||
|
- [Weaviate](https://github.com/weaviate/weaviate): A open-source cloud-native vector search engine.
|
||||||
|
|
||||||
|
### Vector Search Libraries and Tools:
|
||||||
|
|
||||||
|
- [FAISS](https://github.com/facebookresearch/faiss): A library for efficient similarity search and clustering of dense vectors, designed to handle large-scale datasets and optimized for fast retrieval of nearest neighbors.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
This list continues to evolve. Contributions are welcome to make this resource more comprehensive 🙌
|
645
docs/research/rag-research.md
Normal file
645
docs/research/rag-research.md
Normal file
|
@ -0,0 +1,645 @@
|
||||||
|
# RAG Research for Personal Knowledge Base
|
||||||
|
|
||||||
|
## Research Summary (Phase 1: Initial Repository Analysis)
|
||||||
|
|
||||||
|
This phase focuses on analyzing the information retrieved from the `https://github.com/NirDiamant/RAG_Techniques` repository.
|
||||||
|
|
||||||
|
### Search Strategy
|
||||||
|
- **Initial Query (Implicit):** User provided `RAG.md` which has general RAG info.
|
||||||
|
- **Targeted Retrieval:** Focused on `https://github.com/NirDiamant/RAG_Techniques` as a comprehensive source of advanced RAG techniques.
|
||||||
|
- **Focus Areas:** Identifying RAG techniques, frameworks, evaluation methods, and architectural patterns relevant to building a personal, evolving knowledge base for an AI agent.
|
||||||
|
|
||||||
|
### Key Resources
|
||||||
|
1. **`RAG.md` (User-provided):** General overview of RAG, basic steps, common frameworks, databases, and evaluation metrics.
|
||||||
|
2. **`https://github.com/NirDiamant/RAG_Techniques` (Retrieved via Puppeteer):** An extensive collection of advanced RAG techniques, tutorials, and links to further resources.
|
||||||
|
|
||||||
|
### Critical Findings from `NirDiamant/RAG_Techniques` (Relevant to Personal KB)
|
||||||
|
|
||||||
|
The repository lists numerous techniques. The following are initially highlighted for their potential relevance to a personal knowledge base that needs to be accurate, adaptable, and efficiently queried by an AI agent.
|
||||||
|
|
||||||
|
**I. Foundational Techniques:**
|
||||||
|
* **Basic RAG:** Essential starting point.
|
||||||
|
* **Optimizing Chunk Sizes:** Crucial for balancing context and retrieval efficiency. A personal KB might have diverse document lengths.
|
||||||
|
* **Proposition Chunking:** Breaking text into meaningful sentences could be highly beneficial for factual recall from personal notes or scraped articles.
|
||||||
|
|
||||||
|
**II. Query Enhancement:**
|
||||||
|
* **Query Transformations (Rewriting, Step-back, Sub-query):** Useful if the AI agent needs to formulate complex queries or if initial queries are too narrow.
|
||||||
|
* **HyDE (Hypothetical Document Embedding):** Could improve retrieval for nuanced personal queries by generating a hypothetical answer first.
|
||||||
|
* **HyPE (Hypothetical Prompt Embedding):** Precomputing hypothetical questions at indexing could speed up retrieval and improve alignment for a personal KB where query patterns might emerge.
|
||||||
|
|
||||||
|
**III. Context Enrichment:**
|
||||||
|
* **Contextual Chunk Headers:** Adding document/section context to chunks can improve retrieval accuracy, especially for diverse personal documents.
|
||||||
|
* **Relevant Segment Extraction:** Dynamically constructing multi-chunk segments could provide more complete context to the LLM from personal notes or longer articles.
|
||||||
|
* **Semantic Chunking:** More meaningful than fixed-size, good for understanding topics within personal data.
|
||||||
|
* **Contextual Compression:** Useful for fitting more relevant information into the LLM's context window, especially if personal documents are verbose.
|
||||||
|
* **Document Augmentation (Question Generation):** Could enhance retrieval for a personal KB by pre-generating potential questions the user/agent might ask.
|
||||||
|
|
||||||
|
**IV. Advanced Retrieval Methods:**
|
||||||
|
* **Fusion Retrieval (Keyword + Vector):** Could be powerful for a personal KB, allowing both semantic and exact-match searches (e.g., finding a specific term in notes). This is distinct from RAG-Fusion which focuses on multiple generated queries.
|
||||||
|
* **Reranking (LLM-based, Cross-Encoder):** Important for ensuring the most relevant personal information is prioritized.
|
||||||
|
* **Multi-faceted Filtering (Metadata, Similarity, Content, Diversity):** Essential for a personal KB to filter by date, source, type of note, or to avoid redundant information.
|
||||||
|
* **Hierarchical Indices:** Summaries + detailed chunks could be useful for navigating a large personal KB.
|
||||||
|
|
||||||
|
**V. Iterative and Adaptive Techniques:**
|
||||||
|
* **Retrieval with Feedback Loops:** If the agent can provide feedback, this could continuously improve the personal RAG's performance.
|
||||||
|
* **Adaptive Retrieval:** Tailoring strategies based on query types (e.g., "summarize my notes on X" vs. "find specific detail Y").
|
||||||
|
* **Iterative Retrieval:** Useful for complex questions requiring information from multiple personal documents.
|
||||||
|
|
||||||
|
**VI. Evaluation:**
|
||||||
|
* **DeepEval, GroUSE:** While potentially complex for a purely personal setup, understanding these metrics can guide manual evaluation or simpler automated checks. Key aspects for personal KB: Faithfulness (is it true to my notes?), Contextual Relevancy.
|
||||||
|
|
||||||
|
**VII. Advanced Architectures:**
|
||||||
|
* **Graph RAG (General & Microsoft's):** Potentially very powerful for connecting disparate pieces of information in a personal KB (e.g., linking project notes, contacts, and related articles). Might be complex to set up initially.
|
||||||
|
* **RAPTOR (Recursive Abstractive Processing):** Tree-organized retrieval could be excellent for hierarchical personal notes or structured scraped content.
|
||||||
|
* **Self-RAG:** Dynamically deciding to retrieve or not could make the agent more efficient with its personal KB.
|
||||||
|
* **Corrective RAG (CRAG):** Evaluating and correcting retrieval, possibly using web search if personal KB is insufficient, aligns well with an AI agent's needs.
|
||||||
|
* **Integrating Structured Personal Data (Conceptual):** While the primary focus of RAG is often unstructured text, a truly unified personal knowledge base (PUKB) should consider how to incorporate structured personal data like calendar events, to-do lists, and contacts. Conceptual approaches include:
|
||||||
|
1. **Conversion to Natural Language for Vector Search:**
|
||||||
|
* **Approach:** Transform structured items into descriptive sentences or paragraphs. For example, a calendar event "Project Alpha Sync, 2025-05-15 10:00 AM" could become "A meeting for Project Alpha Sync is scheduled on May 15, 2025, at 10:00 AM." This text is then embedded and indexed alongside other unstructured data.
|
||||||
|
* **Pros:** Allows for a single, unified search interface using the existing vector RAG pipeline; enables semantic querying across all data types.
|
||||||
|
* **Cons:** Potential loss of precise structured query capabilities (e.g., exact date range filtering might be less reliable than on structured fields); overhead of accurately converting structured data to natural language; may require careful template design or LLM prompting for conversion.
|
||||||
|
2. **Hybrid Search (Vector DB + Dedicated Structured Store):**
|
||||||
|
* **Approach:** Maintain structured personal data in a simple, local structured datastore (e.g., an SQLite database, or by directly parsing standard PIM files like iCalendar `.ics` or vCard `.vcf`). The AI agent, based on query analysis (e.g., detecting keywords like "calendar," "task," "contact," or specific date/time phrases), would decide to query the vector DB, the structured store, or both. Results would then be synthesized.
|
||||||
|
* **Pros:** Preserves the full fidelity and precise queryability of structured data; potentially more efficient for purely structured queries (e.g., "What are my appointments next Monday?").
|
||||||
|
* **Cons:** Increases agent complexity (query routing, results fusion); requires maintaining and querying two distinct types of data stores; may necessitate Text-to-SQL or Text-to-API-call capabilities for the agent to interact with the structured store based on natural language.
|
||||||
|
3. **Knowledge Graph Augmentation (Advanced):**
|
||||||
|
* **Approach:** Model structured entities (people, events, tasks, projects) and their interrelationships within a local knowledge graph. Link these graph nodes to relevant unstructured text chunks in the vector DB.
|
||||||
|
* **Pros:** Enables complex, multi-hop queries that traverse relationships (e.g., "Show me notes related to people I met with last week concerning Project X"). Provides a richer, interconnected view of personal knowledge.
|
||||||
|
* **Cons:** Significantly more complex to design, implement, and maintain. Steeper learning curve. Best considered as a future enhancement.
|
||||||
|
* **Initial PUKB Strategy:** For initial development, converting structured data to natural language for vector search offers the simplest path to unification. A hybrid approach can be a powerful V2 enhancement for more precise structured queries.
|
||||||
|
|
||||||
|
**VIII. Special Technique:**
|
||||||
|
* **Sophisticated Controllable Agent:** The ultimate goal, using a graph as a "brain" for complex RAG tasks on personal data.
|
||||||
|
|
||||||
|
### Technical Considerations for Personal KB
|
||||||
|
|
||||||
|
* **Data Ingestion:** How easily can new notes, scraped web content (`crawl4ai`), PDFs, etc., be added and indexed?
|
||||||
|
* **Update Mechanisms, Frequency, and Synchronization for Local Files:** Personal knowledge, especially from local files, evolves. The PUKB must efficiently detect changes and update its vector index to maintain relevance and accuracy. Strategies include:
|
||||||
|
* **Change Detection Methods for Local Files:**
|
||||||
|
1. **File System Watchers (Real-time/Near Real-time):**
|
||||||
|
* **Tool:** Python's `watchdog` library is a common cross-platform solution.
|
||||||
|
* **Mechanism:** Monitors specified directories for events like file creation, deletion, modification, and moves.
|
||||||
|
* **Pros:** Provides immediate or near-immediate updates to the PUKB.
|
||||||
|
* **Cons:** Can be resource-intensive if watching a vast number of files/folders; setup might have platform-specific nuances; requires a persistent process.
|
||||||
|
2. **Checksum/Hash Comparison (Periodic Scan):**
|
||||||
|
* **Tool:** Python's `hashlib` (e.g., for SHA256).
|
||||||
|
* **Mechanism:** During initial ingestion, store a hash of each local file's content in its metadata (e.g., `file_hash`). Periodically (e.g., on agent startup, scheduled background task), scan monitored directories. For each file, recalculate its hash and compare it to the stored hash. If different, or if the file is new, trigger re-ingestion. If a file path from the DB is no longer found, its chunks can be marked as stale or deleted.
|
||||||
|
* **Pros:** Robust against content changes; platform-independent logic.
|
||||||
|
* **Cons:** Not real-time; scans can be time-consuming for very large knowledge bases if not optimized (e.g., only hashing files whose `last_modified_date_os` has changed since last scan).
|
||||||
|
3. **Last Modified Timestamp (Periodic Scan):**
|
||||||
|
* **Tool:** Python's `os.path.getmtime()`.
|
||||||
|
* **Mechanism:** Store the `last_modified_date_os` in metadata. Periodically scan and compare the current timestamp with the stored one. If different, trigger re-ingestion.
|
||||||
|
* **Pros:** Simpler and faster to check than full hashing for an initial pass.
|
||||||
|
* **Cons:** Less robust than hashing, as timestamps can sometimes be misleading or not change even if content does (rare, but possible with some tools/operations).
|
||||||
|
4. **Manual Trigger:**
|
||||||
|
* **Mechanism:** The user explicitly commands the AI agent or PUKB system to re-index a specific file or folder.
|
||||||
|
* **Pros:** Simple, user-controlled, good for ad-hoc updates.
|
||||||
|
* **Cons:** Relies on the user to remember and initiate; not suitable for automatic synchronization.
|
||||||
|
* **Recommended Hybrid Approach:** Combine a periodic scan (e.g., daily or on startup) using last modified timestamps as a quick check, followed by hash comparison for files whose timestamps have changed, with an optional file system watcher for designated 'hot' or frequently updated directories if near real-time sync is critical for those.
|
||||||
|
* **Efficient Vector DB Update Strategy (Document-Level Re-indexing):**
|
||||||
|
* When a change in a local file is detected (or a file is deleted):
|
||||||
|
1. Identify the `parent_document_id` associated with the changed/deleted file (e.g., using `original_path` from metadata).
|
||||||
|
2. Delete all existing chunks from the vector database that share this `parent_document_id`.
|
||||||
|
3. If the file was modified (not deleted), re-ingest it: perform full text extraction, preprocessing, chunking, and embedding for the updated content.
|
||||||
|
4. Add the new chunks and their updated metadata (including new `file_hash` and `last_modified_date_os`) to the vector database.
|
||||||
|
* This document-level re-indexing is far more efficient than a full re-index of the entire knowledge base, especially for local file changes.
|
||||||
|
* **Frequency of Updates:**
|
||||||
|
* For watchers: Near real-time.
|
||||||
|
* For periodic scans: Configurable (e.g., on application start, hourly, daily). Balance between freshness and resource usage.
|
||||||
|
* Web-scraped content (`crawl4ai`): Updates depend on re-crawling schedules, which is a separate consideration from local file sync.
|
||||||
|
* **Scalability:** While "personal," the KB could grow significantly. The chosen techniques and database should handle this.
|
||||||
|
* **Privacy & Security in a Unified PUKB:** Ensuring the privacy and security of potentially sensitive personal or mixed (work/personal) data is paramount. A multi-layered approach is recommended:
|
||||||
|
* **Data Segregation within the Vector DB:**
|
||||||
|
* **ChromaDB:** Supports a hierarchy of Tenants -> Databases -> Collections. For PUKB, different personas or data categories (e.g., 'personal_notes', 'work_project_X', 'sensitive_health_info') can be stored in separate **Collections** or even separate **Databases** under a single tenant. Application logic is crucial for ensuring queries are directed to the correct collection(s) based on user context or active persona.
|
||||||
|
* **SQLite-VSS:** Data segregation relies on standard SQLite practices. This typically involves adding a `persona_id` column to vector tables and ensuring all application queries filter by this ID. Alternatively, separate SQLite database files per persona offer stronger isolation.
|
||||||
|
* **Qdrant (If considered for scalability):** Recommends using a single collection with **payload-based partitioning**. A `persona_id` field in each vector's payload would be used for filtering, ensuring each persona only accesses their data.
|
||||||
|
* **Weaviate (If considered for scalability):** Offers robust **built-in multi-tenancy**. Each persona could be a distinct tenant, providing strong data isolation at the database level, with data stored on separate shards.
|
||||||
|
* **Encryption:**
|
||||||
|
* **At Rest:**
|
||||||
|
* **OS-Level Full-Disk Encryption:** A baseline security measure for the machine hosting the PUKB.
|
||||||
|
* **Database-Level Encryption:** If the chosen vector DB supports it (e.g., encrypting database files or directories). For embedded DBs like ChromaDB (persisted) or SQLite, encrypting the parent directory or the files themselves using OS tools or libraries like `cryptography` can be an option.
|
||||||
|
* **Chunk-Level Encryption (Application-Side):** For highly sensitive data, consider encrypting individual chunks *before* storing them in the vector DB using libraries like Python's `cryptography`. Decryption would occur after retrieval and before sending to the LLM. Secure key management (e.g., OS keychain, environment variables for simpler setups, or dedicated secrets manager for advanced use) is critical.
|
||||||
|
* **In Transit:** Primarily relevant if any PUKB component (LLM, embedding model, or the DB itself if run as a separate server) is accessed over a network. Ensure all communications use TLS/HTTPS.
|
||||||
|
* **Role-Based Access Control (RBAC) / Persona-Based Filtering (Application Layer):**
|
||||||
|
* If the PUKB is used by an agent supporting multiple 'personas' (e.g., 'personal', 'work_project_A'), the application (agent) must enforce access control.
|
||||||
|
* The agent maintains an 'active persona' context.
|
||||||
|
* All queries to the PUKB (vector DB, structured stores) MUST be filtered based on this active persona, using the segregation mechanisms of the chosen DB (e.g., querying specific Chroma collections, filtering by `persona_id` in Qdrant/SQLite-VSS, or targeting a specific Weaviate tenant).
|
||||||
|
* **Implications of Local vs. Cloud LLMs/Embedding Models:**
|
||||||
|
* **Local Models (Ollama, Sentence-Transformers, etc.):** Offer the highest privacy as data (prompts, chunks for embedding) does not leave the user's machine. Consider performance and model capability trade-offs.
|
||||||
|
* **Cloud Models (OpenAI, Anthropic, etc.):** Data is sent to third-party servers. Users must review and accept the provider's data usage, retention, and privacy policies. There's an inherent risk of data exposure or use for model training (unless explicitly opted out and guaranteed by the provider).
|
||||||
|
* **Data Minimization:** Only ingest data that is necessary for the PUKB's purpose. For particularly sensitive information, evaluate if a summary, an anonymized version, or just metadata can be ingested instead of the full raw content.
|
||||||
|
* **Secure Deletion:** Implement a reliable process for deleting data. This involves removing chunks from the vector DB (and ensuring the DB reclaims space/updates indexes correctly) and deleting the original source file if requested. Some vector DBs might require specific commands or compaction processes for permanent deletion.
|
||||||
|
* **Input Sanitization:** If user queries are used to construct dynamic filters or other database interactions, ensure proper sanitization to prevent injection-style attacks, even in a local context.
|
||||||
|
* **Regular Backups:** Securely back up the PUKB (vector DB, configuration, local source files if not backed up elsewhere) to prevent data loss. Ensure backups are also encrypted.
|
||||||
|
* **Ease of Implementation/Maintenance:** For a personal system, overly complex solutions might be burdensome.
|
||||||
|
* **Query Types:** The system should handle diverse queries: factual recall, summarization, comparison, open-ended questions.
|
||||||
|
* **Integration with AI Agent & Query Handling for Unified KB:** The RAG output must be easily consumable, and the agent needs sophisticated logic to interact effectively with a unified PUKB containing diverse data types. This involves:
|
||||||
|
* **Query Intent Recognition:** The agent should employ Natural Language Processing (NLP) techniques (e.g., keyword analysis, intent classification models, or LLM-based analysis) to understand the user's query. For example, distinguishing between:
|
||||||
|
* `\"Summarize my notes on Project Titan.\"` (targets local notes, specific project)
|
||||||
|
* `\"What are the latest web articles on quantum computing?\"` (targets web scrapes, specific topic, recency)
|
||||||
|
* `\"Find emails from John Doe about the Q3 budget.\"` (targets emails, specific sender, topic)
|
||||||
|
* `\"What's on my calendar for next Monday?\"` (targets structured calendar data)
|
||||||
|
* **Metadata-Driven Query Routing & Filtering:** Based on the recognized intent, the agent dynamically constructs queries for the vector database, leveraging the Unified Metadata Schema:
|
||||||
|
* **Source Filtering:** If intent clearly points to a source (e.g., \"my notes\"), the agent adds a filter like `metadata.source_type IN ['manual_note', 'local_md']`.
|
||||||
|
* **Topic/Keyword Filtering:** Standard semantic search combined with keyword filters on `document_title` or `chunk_text` (if the vector DB supports hybrid search or if keywords are part of the metadata).
|
||||||
|
* **Date Filtering:** For time-sensitive queries (e.g., \"latest articles,\" \"notes from last week\"), use `ingestion_date`, `creation_date_os`, or `last_modified_date_os` fields.
|
||||||
|
* **Tag Filtering:** `user_tags` become powerful for personalized retrieval (e.g., `\"Find documents tagged 'urgent' and 'project_alpha'\"`).
|
||||||
|
* **Author/Sender Filtering:** For queries like `\"emails from Jane\"` or `\"documents by Smith\"`.
|
||||||
|
* **Handling Ambiguity:** If a query is ambiguous (e.g., `\"Tell me about Project X\"` could be notes, emails, or web data):
|
||||||
|
* The agent might initially query across a broader set of relevant `source_type`s and present categorized results.
|
||||||
|
* Alternatively, it could ask for clarification: `\"Are you looking for your notes, web articles, or emails regarding Project X?\"`
|
||||||
|
* **Structured Data Querying (Conceptual - see 'Integrating Structured Personal Data'):** If the query targets structured data (e.g., `\"Show my tasks due today\"`), the agent would need to route this to the appropriate structured data store/handler (e.g., an SQLite DB, iCalendar parser) instead of, or in addition to, the vector DB.
|
||||||
|
* **Consuming RAG Output:** The agent receives retrieved chunks (with their metadata) and passes them as context to an LLM for answer generation, summarization, or other tasks. The metadata itself can be valuable context for the LLM.
|
||||||
|
* **Unified Metadata Schema:** A flexible and comprehensive metadata schema is crucial for managing diverse data sources within the PUKB, enabling robust filtering, source tracking, and providing context to the LLM.
|
||||||
|
* **Core Metadata Fields (Applicable to all chunks):**
|
||||||
|
* `chunk_id`: (String) Unique identifier for this specific chunk (e.g., UUID).
|
||||||
|
* `parent_document_id`: (String) Unique identifier for the original source document/file/email/note this chunk belongs to (e.g., hash of file path, URL, or a UUID for the document).
|
||||||
|
* `source_type`: (String, Enum-like) Type of the original source. Examples: 'web_crawl4ai', 'local_txt', 'local_md', 'local_pdf_text', 'local_pdf_image_ocr', 'local_docx', 'local_image_ocr', 'local_email_eml', 'local_email_mbox', 'local_email_msg', 'local_code_snippet', 'manual_note', 'structured_calendar_event', 'structured_todo_item', 'structured_contact'.
|
||||||
|
* `document_title`: (String) Title of the source document (e.g., web page `<title>`, filename, email subject, first heading in a note).
|
||||||
|
* `original_path`: (String, Nullable) Absolute file system path for local files.
|
||||||
|
* `url`: (String, Nullable) Original URL for web-scraped content.
|
||||||
|
* `file_hash`: (String, Nullable) Hash (e.g., SHA256) of the original local file content, for change detection.
|
||||||
|
* `creation_date_os`: (ISO8601 Timestamp, Nullable) Original creation date of the file/document from the OS or source metadata.
|
||||||
|
* `last_modified_date_os`: (ISO8601 Timestamp, Nullable) Original last modified date of the file/document from the OS or source metadata.
|
||||||
|
* `ingestion_date`: (ISO8601 Timestamp) Date and time when the content was ingested into the PUKB.
|
||||||
|
* `author`: (String, Nullable) Author(s) of the document, if available/extractable.
|
||||||
|
* `user_tags`: (List of Strings, Nullable) User-defined tags or keywords associated with the document or chunk.
|
||||||
|
* `chunk_sequence_number`: (Integer, Nullable) If the document is split into multiple chunks, this indicates the order of the chunk within the document.
|
||||||
|
* `text_preview`: (String, Nullable) A short preview (e.g., first 200 chars) of the chunk's text content for quick inspection (optional, can increase storage).
|
||||||
|
* **Source-Specific Optional Fields (Examples):** These can be stored as a nested JSON object (e.g., `source_specific_details`) or flattened with prefixes if the vector DB prefers.
|
||||||
|
* For `source_type` starting with `'local_email_'`: `email_sender`, `email_recipients_to`, `email_recipients_cc`, `email_recipients_bcc`, `email_date_sent`, `email_message_id`.
|
||||||
|
* For `source_type: 'web_crawl4ai'`: `web_domain`, `crawl_depth` (if applicable).
|
||||||
|
* For `source_type: 'local_code_snippet'`: `code_language`.
|
||||||
|
* For `source_type: 'structured_calendar_event'`: `event_start_time`, `event_end_time`, `event_location`, `event_attendees`.
|
||||||
|
* **Population:** Metadata is populated during the ingestion pipeline. `chunk_id` and `parent_document_id` are generated. OS/file metadata is gathered for local files. Web metadata comes from crawl results. Email headers provide email-specific fields. `ingestion_date` is timestamped. `user_tags` can be added later.
|
||||||
|
|
||||||
|
### Best Practices Identified (Preliminary)
|
||||||
|
|
||||||
|
* **Modular Design:** Separate components for ingestion, chunking, embedding, storage, retrieval, and generation.
|
||||||
|
* **Experimentation:** Chunking strategy, embedding models, and retrieval methods often require experimentation.
|
||||||
|
* **Evaluation is Key:** Even for a personal system, periodically checking relevance and accuracy is important.
|
||||||
|
* **Start Simple:** Begin with a basic RAG and iteratively add advanced techniques.
|
||||||
|
|
||||||
|
### Common Pitfalls to Avoid for Personal KB
|
||||||
|
|
||||||
|
* **Over-chunking or Under-chunking:** Finding the right balance is critical.
|
||||||
|
* **Stale Index:** Not updating the RAG with new information regularly.
|
||||||
|
* **Ignoring Metadata:** Not using source, date, tags for filtering and context.
|
||||||
|
* **Choosing an Overly Complex System:** Starting with something too difficult to maintain for personal use.
|
||||||
|
* **Vendor Lock-in:** If using cloud services, consider portability.
|
||||||
|
|
||||||
|
### Impact on Approach for Personal KB
|
||||||
|
|
||||||
|
The `NirDiamant/RAG_Techniques` repository provides a rich set of options. For a personal knowledge base, the initial focus should be on:
|
||||||
|
1. **Solid Foundational RAG:** Good chunking (semantic or proposition), reliable embedding model.
|
||||||
|
2. **Effective Retrieval:** Fusion retrieval (keyword + semantic) and reranking seem highly valuable.
|
||||||
|
3. **Contextual Understanding:** Techniques like contextual chunk headers and relevant segment extraction.
|
||||||
|
4. **Manageable Complexity:** Prioritize techniques that can be implemented and maintained without excessive effort for a personal system. GraphRAG and RAPTOR are powerful but might be later-stage enhancements.
|
||||||
|
5. **Data Ingestion:** Needs to be seamless with `crawl4ai` outputs and manual note entry.
|
||||||
|
|
||||||
|
### Frameworks and Tools Mentioned (from NirDiamant/RAG_Techniques & RAG.md)
|
||||||
|
|
||||||
|
* **RAG Frameworks:**
|
||||||
|
* LangChain (has `crawl4ai` loader)
|
||||||
|
* LlamaIndex
|
||||||
|
* Haystack
|
||||||
|
* Semantic Kernel
|
||||||
|
* Dify, Cognita, Verba, Mastra, Letta, Flowise, Swiftide, CocoIndex
|
||||||
|
* **Evaluation Tools:**
|
||||||
|
* DeepEval
|
||||||
|
* GroUSE
|
||||||
|
* LangFuse, Ragas, LangSmith (more for production LLM apps but principles apply)
|
||||||
|
* **Vector Databases (also from `RAG.md`):**
|
||||||
|
* **Open Source / Local-friendly:** ChromaDB, Milvus (can be local), Qdrant, Weaviate (can be local), pgvector (PostgreSQL extension), FAISS (library), LlamaIndex (in-memory default), SQLite-VSS.
|
||||||
|
* **Cloud/Managed:** Pinecone, MongoDB Atlas, Vespa, Elasticsearch, OpenSearch, Oracle AI Vector Search, Azure Cosmos DB, Couchbase.
|
||||||
|
* **Graph-based:** Neo4j (can store vectors).
|
||||||
|
|
||||||
|
This initial analysis of the `NirDiamant/RAG_Techniques` repository provides a strong foundation. The next steps will involve deeper dives into the most promising techniques, vector databases, and frameworks suitable for a personal knowledge base.
|
||||||
|
|
||||||
|
## Phase 2: Unified Ingestion Strategy for Local Files
|
||||||
|
|
||||||
|
This section details the tools and methods for ingesting various local file types into the Personal Unified Knowledge Base (PUKB), prioritizing local-first, open-source solutions.
|
||||||
|
|
||||||
|
### 1. Plain Text (.txt)
|
||||||
|
* **Tool(s):** Python's built-in `open()` function.
|
||||||
|
* **Workflow:**
|
||||||
|
1. Use `with open('filepath.txt', 'r', encoding='utf-8') as f:` to open the file (specify encoding if known, UTF-8 is a good default).
|
||||||
|
2. Read content using `f.read()`.
|
||||||
|
3. Basic cleaning: Strip leading/trailing whitespace, normalize newlines if necessary.
|
||||||
|
|
||||||
|
### 2. Markdown (.md)
|
||||||
|
* **Tool(s):**
|
||||||
|
* Python's built-in `open()` for reading raw Markdown text.
|
||||||
|
* `markdown` library (e.g., `pip install Markdown`) if conversion to HTML is desired as an intermediate step (less common for direct RAG ingestion of Markdown).
|
||||||
|
* **Workflow (Raw Text):**
|
||||||
|
1. Use `with open('filepath.md', 'r', encoding='utf-8') as f:` to open.
|
||||||
|
2. Read content using `f.read()`. This raw Markdown is often suitable for direct chunking.
|
||||||
|
3. Basic cleaning: Similar to .txt files.
|
||||||
|
|
||||||
|
### 3. PDF (Text-Based)
|
||||||
|
* **Tool(s):** `PyMuPDF` (also known as `fitz`, `pip install PyMuPDF`). `pypdf2` (`pip install pypdf2`) is an alternative.
|
||||||
|
* **Workflow (`PyMuPDF`):**
|
||||||
|
1. Import `fitz`.
|
||||||
|
2. Open PDF: `doc = fitz.open('filepath.pdf')`.
|
||||||
|
3. Initialize an empty string for all text: `full_text = ""`.
|
||||||
|
4. Iterate through pages: `for page_num in range(len(doc)): page = doc.load_page(page_num); full_text += page.get_text()`.
|
||||||
|
5. Close document: `doc.close()`.
|
||||||
|
6. Clean extracted text: Remove excessive newlines, ligatures, or broken words if present.
|
||||||
|
* **Feeding to Advanced Chunkers:**
|
||||||
|
* **Semantic Chunking:** The preprocessed `raw_markdown` can be directly fed. Semantic chunkers often leverage Markdown structure (headings, paragraphs, lists) to identify meaningful boundaries for chunks.
|
||||||
|
* **Proposition Chunking:** The preprocessed `raw_markdown` is suitable. The proposition extractor (often LLM-based) will then parse this text to identify and extract atomic factual statements.
|
||||||
|
|
||||||
|
### 4. PDF (Scanned/Image-Based OCR)
|
||||||
|
* **Tool(s):**
|
||||||
|
* Option A: `OCRmyPDF` (`pip install ocrmypdf`, requires Tesseract installed system-wide).
|
||||||
|
* Option B: `PyMuPDF` (to extract images) + `pytesseract` (`pip install pytesseract`, requires Tesseract) + `Pillow` (`pip install Pillow`). `EasyOCR` (`pip install easyocr`) is an alternative to `pytesseract`.
|
||||||
|
* **Workflow (Option A - `OCRmyPDF`):**
|
||||||
|
1. Use command line: `ocrmypdf input.pdf output_ocr.pdf`.
|
||||||
|
2. Process `output_ocr.pdf` as a text-based PDF using `PyMuPDF` (see above) to extract the OCRed text layer.
|
||||||
|
* **Workflow (Option B - `PyMuPDF` + `pytesseract`):**
|
||||||
|
* **Feeding to Advanced Chunkers:**
|
||||||
|
* **Semantic Chunking:** The clean text extracted from `cleaned_html` (after HTML-to-text conversion and further cleaning) is fed to the semantic chunker. The quality of semantic segmentation will depend on how well the original HTML structure (paragraphs, sections) was translated into logical text blocks during the HTML-to-text conversion.
|
||||||
|
* **Proposition Chunking:** The clean text extracted from `cleaned_html` is suitable. The proposition extractor will parse this text for factual statements.
|
||||||
|
1. Open PDF with `PyMuPDF`: `doc = fitz.open('filepath.pdf')`.
|
||||||
|
2. Iterate through pages. For each page:
|
||||||
|
* Extract images: `pix = page.get_pixmap()`. Convert `pix` to a `Pillow` Image object.
|
||||||
|
* Perform OCR: `text_on_page = pytesseract.image_to_string(pil_image)`.
|
||||||
|
* Append `text_on_page` to `full_text`.
|
||||||
|
3. Clean OCRed text: This often requires more significant cleaning for OCR errors, layout artifacts, etc.
|
||||||
|
|
||||||
|
### 5. DOCX (Microsoft Word)
|
||||||
|
* **Tool(s):** `python-docx` (`pip install python-docx`).
|
||||||
|
* **Workflow:**
|
||||||
|
1. Import `Document` from `docx`.
|
||||||
|
2. Open document: `doc = Document('filepath.docx')`.
|
||||||
|
3. Initialize an empty string: `full_text = ""`.
|
||||||
|
4. Iterate through paragraphs: `for para in doc.paragraphs: full_text += para.text + '\n'`.
|
||||||
|
5. (Optional) Extract text from tables, headers, footers if needed, using respective `python-docx` APIs. `python-docx2txt` might simplify this.
|
||||||
|
|
||||||
|
### 6. Common Image Formats (PNG, JPG, etc. for OCR)
|
||||||
|
* **Tool(s):** `Pillow` (`pip install Pillow`) + `pytesseract` (`pip install pytesseract`). `EasyOCR` as an alternative.
|
||||||
|
* **Workflow:**
|
||||||
|
1. Import `Image` from `PIL` and `pytesseract`.
|
||||||
|
2. Open image: `img = Image.open('imagepath.png')`.
|
||||||
|
3. Extract text: `text_content = pytesseract.image_to_string(img)`.
|
||||||
|
4. Clean OCRed text.
|
||||||
|
|
||||||
|
### 7. Email (.eml)
|
||||||
|
* **Tool(s):** Python's built-in `email` module (specifically `email.parser.Parser` or `email.message_from_file`). `eml_parser` for a higher-level API.
|
||||||
|
* **Workflow (built-in `email`):**
|
||||||
|
1. Import `email.parser`.
|
||||||
|
2. Open file: `with open('filepath.eml', 'r') as f: msg = email.parser.Parser().parse(f)`.
|
||||||
|
3. Extract body: Iterate `msg.walk()`. For each part, check `part.get_content_type()`.
|
||||||
|
* If `text/plain`, get payload: `body = part.get_payload(decode=True).decode(part.get_content_charset() or 'utf-8')`.
|
||||||
|
* If `text/html`, get payload and strip HTML tags (e.g., using `BeautifulSoup`).
|
||||||
|
4. Extract other relevant fields: `msg['Subject']`, `msg['From']`, `msg['To']`, `msg['Date']`.
|
||||||
|
|
||||||
|
### 8. Email (.msg - Microsoft Outlook)
|
||||||
|
* **Tool(s):** `extract_msg` (`pip install extract-msg`).
|
||||||
|
* **Workflow:**
|
||||||
|
1. Import `Message` from `extract_msg`.
|
||||||
|
2. Open file: `msg = Message('filepath.msg')`.
|
||||||
|
3. Access properties: `body = msg.body`, `subject = msg.subject`, `sender = msg.sender`, `to = msg.to`, `date = msg.date`.
|
||||||
|
4. The `body` usually contains the main text content.
|
||||||
|
|
||||||
|
### 9. Email (mbox)
|
||||||
|
* **Tool(s):** Python's built-in `mailbox` module.
|
||||||
|
* **Workflow:**
|
||||||
|
1. Import `mailbox`.
|
||||||
|
2. Open mbox file: `mbox_archive = mailbox.mbox('filepath.mbox')`.
|
||||||
|
3. Iterate through messages: `for message in mbox_archive:`.
|
||||||
|
4. Each `message` is an `email.message.Message` instance. Process it similarly to .eml files (see above) to extract text from payloads.
|
||||||
|
5. Extract other relevant fields: `message['Subject']`, etc.
|
||||||
|
|
||||||
|
### 10. Code Snippets (e.g., .py, .js, .java)
|
||||||
|
* **Tool(s):** Python's built-in `open()`. Language-specific parsers like `ast` for Python for deeper analysis (optional).
|
||||||
|
* **Workflow (Raw Text):**
|
||||||
|
1. Treat as plain text: `with open('filepath.py', 'r', encoding='utf-8') as f: code_text = f.read()`.
|
||||||
|
2. This raw code text can be directly used for embedding and RAG.
|
||||||
|
3. Further preprocessing might involve stripping comments or formatting, depending on the desired RAG behavior.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Preprocessing and Advanced Chunking for Diverse Data
|
||||||
|
|
||||||
|
Effective RAG relies on high-quality chunks. This section details preprocessing steps for various data types before applying advanced chunking techniques.
|
||||||
|
|
||||||
|
### Part 1: Preprocessing Specifics by Data Type
|
||||||
|
|
||||||
|
(Assumes initial text extraction as per "Phase 2: Unified Ingestion Strategy for Local Files")
|
||||||
|
|
||||||
|
* **`crawl4ai` - `markdown.raw_markdown` Output:**
|
||||||
|
* **Goal:** Ensure clean, consistent Markdown.
|
||||||
|
* **Steps:**
|
||||||
|
1. Normalize newlines (e.g., ensure all are `\n`).
|
||||||
|
2. Trim leading/trailing whitespace from the entire document and from individual lines.
|
||||||
|
3. Collapse multiple consecutive blank lines into a maximum of one or two.
|
||||||
|
4. Optionally, strip or handle any rare HTML remnants if `crawl4ai`'s conversion wasn't perfect (though `raw_markdown` should be fairly clean).
|
||||||
|
|
||||||
|
* **`crawl4ai` - `cleaned_html` Output:**
|
||||||
|
* **Goal:** Convert to well-structured plain text.
|
||||||
|
* **Steps:**
|
||||||
|
1. Use a robust HTML-to-text converter (e.g., `BeautifulSoup(html_content, 'html.parser').get_text(separator='\n', strip=True)` or `html2text`).
|
||||||
|
2. Ensure paragraph breaks and list structures from HTML are preserved as newlines or appropriate text formatting.
|
||||||
|
3. Remove any residual non-content elements (e.g., rare JavaScript snippets, CSS, boilerplate not fully caught by `crawl4ai`).
|
||||||
|
4. Normalize whitespace (trim, collapse multiple spaces).
|
||||||
|
|
||||||
|
* **Local Plain Text (.txt) & Markdown (.md):**
|
||||||
|
* **Goal:** Clean and standardize plain text.
|
||||||
|
* **Steps:**
|
||||||
|
1. Normalize newlines.
|
||||||
|
2. Trim leading/trailing whitespace.
|
||||||
|
3. Collapse multiple blank lines.
|
||||||
|
4. For Markdown, ensure its syntax is preserved for chunkers that might leverage it.
|
||||||
|
|
||||||
|
* **PDF (Text-Based, output from `PyMuPDF` etc.):**
|
||||||
|
* **Goal:** Clean artifacts from PDF text extraction.
|
||||||
|
* **Steps:**
|
||||||
|
1. Replace common ligatures (e.g., "fi" to "fi", "fl" to "fl").
|
||||||
|
2. Attempt to rejoin hyphenated words broken across lines (can be challenging; may use heuristics or dictionary-based approaches).
|
||||||
|
3. Identify and remove repetitive headers/footers if not handled by the extraction library (e.g., using pattern matching or positional analysis if page layout is consistent).
|
||||||
|
4. Normalize whitespace and remove excessive blank lines.
|
||||||
|
|
||||||
|
* **PDF (Scanned/Image-Based, OCR output from `pytesseract`, `EasyOCR`):**
|
||||||
|
* **Goal:** Correct OCR errors and improve readability.
|
||||||
|
* **Steps:**
|
||||||
|
1. Apply spelling correction (e.g., using `pyspellchecker` or a similar library).
|
||||||
|
2. Filter out OCR noise or gibberish (e.g., sequences of non-alphanumeric characters unlikely to be valid text, very short isolated "words").
|
||||||
|
3. Attempt to reconstruct paragraph/section structure if lost during OCR (e.g., based on analyzing vertical spacing if available from OCR engine, or using NLP techniques to group related sentences).
|
||||||
|
4. Normalize all forms of whitespace.
|
||||||
|
|
||||||
|
* **DOCX (Output from `python-docx`):**
|
||||||
|
* **Goal:** Clean text extracted from Word documents.
|
||||||
|
* **Steps:**
|
||||||
|
1. Normalize whitespace (trim, collapse multiple spaces/newlines).
|
||||||
|
2. Remove artifacts from complex Word formatting if they translate poorly to plain text.
|
||||||
|
3. Handle lists and table text appropriately if extracted.
|
||||||
|
|
||||||
|
* **Image OCR Output (from standalone images):**
|
||||||
|
* **Goal:** Similar to scanned PDF OCR.
|
||||||
|
* **Steps:**
|
||||||
|
1. Spelling correction.
|
||||||
|
2. Noise removal.
|
||||||
|
3. Whitespace normalization.
|
||||||
|
|
||||||
|
* **Emails (Text extracted from .eml, .msg, .mbox):**
|
||||||
|
* **Goal:** Isolate main content and standardize.
|
||||||
|
* **Steps:**
|
||||||
|
1. Remove or clearly demarcate headers (From, To, Subject, Date), footers, and common disclaimers.
|
||||||
|
2. Normalize quoting styles for replies (e.g., convert `>` prefixes consistently, or attempt to strip quoted history if only the latest message is desired).
|
||||||
|
3. If original was HTML, ensure clean conversion to text, preserving paragraph structure.
|
||||||
|
4. Standardize signature blocks or remove them.
|
||||||
|
5. Normalize whitespace.
|
||||||
|
|
||||||
|
* **Code Snippets (Raw text):**
|
||||||
|
* **Goal:** Prepare code for embedding, preserving semantic structure.
|
||||||
|
* **Steps:**
|
||||||
|
1. Normalize newlines.
|
||||||
|
2. Ensure consistent indentation (e.g., convert tabs to spaces, or vice-versa, though often best left as-is if the chunker can use indentation for semantic grouping).
|
||||||
|
3. Decide on handling comments: strip them, or preserve them as they provide valuable context for RAG.
|
||||||
|
4. For some languages, normalizing case for keywords might be considered, but generally not required for modern embedding models.
|
||||||
|
|
||||||
|
### Part 2: Application of Advanced Chunking Techniques to Preprocessed Data
|
||||||
|
|
||||||
|
Once the data from various sources has been preprocessed into clean text, the following advanced chunking strategies can be applied:
|
||||||
|
|
||||||
|
* **Semantic Chunking:**
|
||||||
|
* **Principle:** Divides text into chunks based on semantic similarity or topic coherence, rather than fixed sizes. Often uses embedding models to measure sentence/paragraph similarity.
|
||||||
|
* **Application to Diverse Data:**
|
||||||
|
* **`crawl4ai` (Web Articles, Docs):** Effectively groups paragraphs or sections discussing the same sub-topic. Can leverage existing HTML structure (like `<section>`, `<h2>`) if translated well into the text.
|
||||||
|
* **Local Notes (.txt, .md):** Identifies coherent thoughts or topics within longer notes. Markdown headings can provide strong hints.
|
||||||
|
* **PDF/DOCX (Reports, Chapters):** Groups related paragraphs within sections, even if original formatting cues are subtle in the extracted text.
|
||||||
|
* **Emails:** Can separate distinct topics if an email covers multiple subjects, or keep a single coherent discussion together.
|
||||||
|
* **Code Snippets:** Can group entire functions, classes, or logically related blocks of code, especially if comments and docstrings are included in the preprocessed text.
|
||||||
|
* **Conceptual Example (Python-like pseudo-code):**
|
||||||
|
```python
|
||||||
|
# from semantic_text_splitter import CharacterTextSplitter # Example library
|
||||||
|
# text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
|
||||||
|
# "sentence-transformers/all-MiniLM-L6-v2", # Example embedding model
|
||||||
|
# chunk_size=512, # Target, but will split semantically
|
||||||
|
# chunk_overlap=50
|
||||||
|
# )
|
||||||
|
# chunks = text_splitter.split_text(preprocessed_text_from_any_source)
|
||||||
|
```
|
||||||
|
|
||||||
|
* **Proposition Chunking:**
|
||||||
|
* **Principle:** Breaks down text into atomic, factual statements or propositions. This often involves using an LLM to rephrase or extract these propositions.
|
||||||
|
* **Application to Diverse Data:**
|
||||||
|
* **`crawl4ai` (Factual Articles):** Ideal for extracting key facts, figures, and claims from news or informational web pages.
|
||||||
|
* **Local Notes (.txt, .md):** Converts personal notes or meeting minutes into a list of distinct facts, ideas, or action items.
|
||||||
|
* **PDF/DOCX (Dense Documents):** Extracts core assertions from academic papers, technical manuals, or reports.
|
||||||
|
* **Emails:** Isolates key information, decisions, or requests from email conversations.
|
||||||
|
* **Code Snippets:** Less about the code logic itself, but can extract factual statements from docstrings or high-level comments (e.g., "Function `calculate_sum` returns the total of a list.").
|
||||||
|
* **Conceptual Example (Python-like pseudo-code):**
|
||||||
|
```python
|
||||||
|
# llm_client = YourLLMClient() # Interface to an LLM
|
||||||
|
# prompt = f"Extract all distinct factual propositions from the following text. Each proposition should be a complete sentence and stand alone:\n\n{preprocessed_text_from_any_source}"
|
||||||
|
# response = llm_client.generate(prompt)
|
||||||
|
# propositions = response.splitlines() # Assuming LLM returns one prop per line
|
||||||
|
# # Each proposition in 'propositions' is a chunk
|
||||||
|
```
|
||||||
|
|
||||||
|
* **Contextual Chunk Headers:**
|
||||||
|
* **Principle:** Adds a small piece of contextual metadata (e.g., source, title, section) as a prefix to each chunk's text before embedding. This helps the retrieval system and LLM understand the chunk's origin and context.
|
||||||
|
* **Application to Diverse Data (Header Examples):**
|
||||||
|
* **`crawl4ai` Output:** `"[Source: Web Article | URL: {url} | Title: {page_title} | Section: {nearest_heading_text}]\n{chunk_text}"`
|
||||||
|
* **Local Markdown Note:** `"[Source: Local Note | File: {filename} | Path: {relative_path} | Title: {document_title_from_h1_or_filename}]\n{chunk_text}"`
|
||||||
|
* **PDF Document:** `"[Source: PDF Document | File: {filename} | Title: {pdf_title_metadata} | Page: {page_number}]\n{chunk_text}"`
|
||||||
|
* **Email:** `"[Source: Email | From: {sender} | Subject: {subject} | Date: {date}]\n{chunk_text}"`
|
||||||
|
* **Code Snippet:** `"[Source: Code File | File: {filename} | Language: {language} | Context: {function/class_name}]\n{chunk_text}"`
|
||||||
|
* **Implementation:** This is typically done by constructing the header string from the document's metadata (see Unified Metadata Schema) and prepending it to the chunk content before it's passed to the embedding model.
|
||||||
|
## Phase 3: Deep Dive into Specific RAG Techniques
|
||||||
|
|
||||||
|
### 1. Semantic Chunking
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
### 2. Fusion Retrieval (Hybrid Search) & Reranking (Initial Summary)
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
### 3. RAG-Fusion (Query Generation & Reciprocal Rank Fusion)
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
### 4. Reranking with Cross-Encoders and LLMs
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
### 5. RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval)
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
### 6. Corrective RAG (CRAG)
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
---
|
||||||
|
## Phase 4: Deep Dive into Vector Databases (for Local RAG)
|
||||||
|
|
||||||
|
This section will explore various vector database options suitable for a local personal knowledge base, focusing on ease of setup, Python integration, performance for typical personal KB sizes, and maintenance.
|
||||||
|
|
||||||
|
### 1. ChromaDB
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
### 2. FAISS (Facebook AI Similarity Search)
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
### 3. Qdrant
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
### 4. Weaviate
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
### 5. SQLite-VSS (SQLite Vector Similarity Search)
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
---
|
||||||
|
## Phase 5: Deep Dive into RAG Frameworks
|
||||||
|
|
||||||
|
This section will explore RAG orchestration frameworks, focusing on their suitability for building a custom RAG pipeline for a personal knowledge base.
|
||||||
|
|
||||||
|
### 1. LangChain vs. LlamaIndex
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
### 2. Haystack (by deepset)
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
### 3. Semantic Kernel (by Microsoft)
|
||||||
|
... (content as before) ...
|
||||||
|
|
||||||
|
---
|
||||||
|
## Phase 6: Synthesis and Recommendations for Personal Knowledge Base RAG
|
||||||
|
|
||||||
|
This section synthesizes the research on RAG techniques, vector databases, and frameworks to provide recommendations for building a personal knowledge base for an AI agent, with `crawl4ai` as a primary data ingestion tool.
|
||||||
|
|
||||||
|
### I. Core Goal Recap & Decision on Unified vs. Separate RAG
|
||||||
|
|
||||||
|
* **Primary Goal:** Create a robust, adaptable, and efficiently queryable personal knowledge base for an AI agent. This KB will store diverse information, including scraped web content, project notes, and potentially code.
|
||||||
|
* **Secondary Goal:** Consider if this personal KB can be unified with a work-related KB or if they should remain separate.
|
||||||
|
* **Recommendation:** For initial development and given the "personal" focus, **start with a separate RAG for the personal knowledge base.**
|
||||||
|
* **Reasoning:**
|
||||||
|
* **Simplicity:** Reduces initial complexity in terms of data siloing, access control, and differing update cadences.
|
||||||
|
* **Privacy:** Keeps personal data distinctly separate, which is crucial.
|
||||||
|
* **Focus:** Allows tailoring the RAG system (chunking, embedding, retrieval strategies) specifically to the nature of personal data and queries.
|
||||||
|
* **Future Unification:** A well-designed personal RAG can potentially be integrated or federated with a work RAG later if designed with modularity in mind. The core challenge would be managing context and preventing data leakage between the two.
|
||||||
|
|
||||||
|
### II. Recommended RAG Framework
|
||||||
|
|
||||||
|
Based on the research, here's a comparison and recommendation:
|
||||||
|
|
||||||
|
| Feature/Aspect | LangChain | LlamaIndex | Haystack | Semantic Kernel |
|
||||||
|
| :---------------------- | :------------------------------------------ | :--------------------------------------------- | :--------------------------------------------- | :------------------------------------------ |
|
||||||
|
| **Primary Focus** | General LLM App Dev, Agents, Chains | Data Framework for LLM Apps, RAG focus | LLM Orchestration, Production RAG, Pipelines | Agentic AI, Planning, Plugins |
|
||||||
|
| **Ease of Basic RAG** | Moderate | High (RAG-centric abstractions) | Moderate (explicit pipeline setup) | Moderate (RAG is a pattern to build) |
|
||||||
|
| **Advanced RAG** | High (many components, flexible) | High (many advanced RAG modules) | High (flexible pipelines, diverse components) | Moderate (via plugins, memory) |
|
||||||
|
| **`crawl4ai` Integration** | Native `Crawl4aiLoader` | Custom loader needed (easy to adapt) | Custom integration needed (easy to adapt) | Custom plugin/memory integration needed |
|
||||||
|
| **Local Deployment** | Excellent (many local model/DB integrations) | Excellent (many local model/DB integrations) | Excellent (many local model/DB integrations) | Good (supports local models, memory stores) |
|
||||||
|
| **Modularity** | High (Chains, LCEL) | High (Engines, Indices, Retrievers) | Very High (Pipelines, Components) | Very High (Kernel, Plugins) |
|
||||||
|
| **Community/Docs** | Very Large | Large, RAG-focused | Good, Growing | Good, Microsoft-backed |
|
||||||
|
| **Personal KB Fit** | Good, flexible | **Excellent**, RAG-first design simplifies setup | Very Good, robust for evolving needs | Good, esp. if agent needs more than RAG |
|
||||||
|
|
||||||
|
**Primary Recommendation: LlamaIndex**
|
||||||
|
|
||||||
|
* **Reasons:**
|
||||||
|
* **RAG-Centric Design:** LlamaIndex is built from the ground up for connecting custom data to LLMs, making RAG its core competency. This often leads to more intuitive and quicker setup for RAG-specific tasks.
|
||||||
|
* **Ease of Use for Core RAG:** High-level abstractions for indexing, retrieval, and query engines simplify building a basic-to-intermediate personal RAG.
|
||||||
|
* **Advanced RAG Features:** Offers a rich set of modules for advanced RAG techniques (e.g., various retrievers, node postprocessors/rerankers, query transformations) that can be incrementally added.
|
||||||
|
* **Strong Local Ecosystem:** Excellent support for local embedding models, LLMs (via Ollama, HuggingFace), and local vector stores.
|
||||||
|
* **Data Ingestion Flexibility:** While no direct `crawl4ai` loader exists *yet*, creating one or simply passing `crawl4ai`'s text output to LlamaIndex's `Document` objects for indexing is straightforward.
|
||||||
|
* **Python Native:** Aligns well with `crawl4ai` and common data science/AI workflows.
|
||||||
|
|
||||||
|
**Runner-Up Recommendation: LangChain**
|
||||||
|
|
||||||
|
* **Reasons:**
|
||||||
|
* **Mature and Flexible:** A very versatile framework with a vast ecosystem of integrations.
|
||||||
|
* **Native `Crawl4aiLoader`:** Simplifies the initial data ingestion step from `crawl4ai`.
|
||||||
|
* **Strong for Complex Chains/Agents:** If the AI agent's capabilities extend significantly beyond RAG, LangChain's strengths in building complex chains and agents become more prominent.
|
||||||
|
* **Large Community:** Extensive documentation, tutorials, and community support.
|
||||||
|
* Can achieve everything LlamaIndex can for RAG, but sometimes with more boilerplate or a less RAG-specific API.
|
||||||
|
|
||||||
|
**Why not Haystack or Semantic Kernel as primary for *this specific* goal?**
|
||||||
|
* **Haystack:** Very powerful and modular, excellent for production and complex pipelines. However, for a *personal* KB, its explicitness in pipeline definition might be slightly more setup than LlamaIndex for common RAG patterns. It's a strong contender if the personal RAG is expected to become very complex or integrate deeply with other Haystack-specific tooling.
|
||||||
|
* **Semantic Kernel:** Its primary strength lies in agentic AI, planning, and function calling. While RAG is achievable via its "Memory" and plugin system, it's more of a capability you build *into* a Semantic Kernel agent rather than the framework's central focus. If the AI agent's core is complex task execution and RAG is just one tool, SK is excellent. If RAG *is* the core, LlamaIndex or LangChain might be more direct.
|
||||||
|
|
||||||
|
### III. Recommended Local Vector Database
|
||||||
|
|
||||||
|
**Primary Recommendation: ChromaDB**
|
||||||
|
|
||||||
|
* **Reasons:**
|
||||||
|
* **Ease of Use & Setup:** `pip install chromadb`, runs in-memory by default or can persist to disk easily. Very developer-friendly.
|
||||||
|
* **Python-Native:** Designed with Python applications in mind.
|
||||||
|
* **Good Integration:** Well-supported by LlamaIndex and LangChain.
|
||||||
|
* **Sufficient for Personal KB:** Scales well enough for typical personal knowledge base sizes.
|
||||||
|
* **Metadata Filtering:** Supports filtering by metadata, which is crucial for a personal KB (e.g., by source, date, tags).
|
||||||
|
|
||||||
|
**Runner-Up Recommendation: SQLite-VSS**
|
||||||
|
|
||||||
|
* **Reasons:**
|
||||||
|
* **Simplicity of SQLite:** If already using SQLite for other application data, adding vector search via an extension is very convenient. No separate database server to manage.
|
||||||
|
* **Good Enough Performance:** For many personal KB use cases, performance will be adequate.
|
||||||
|
* **Growing Ecosystem:** Gaining traction and support in frameworks.
|
||||||
|
|
||||||
|
**Why not others for initial setup?**
|
||||||
|
* **FAISS:** A library, not a full database. Requires more manual setup for persistence and serving, though powerful for raw similarity search. Often used *under the hood* by other vector DBs or frameworks.
|
||||||
|
* **Qdrant/Weaviate:** More feature-rich and scalable, potentially overkill for a basic personal KB's initial setup. They are excellent choices if the KB grows very large or requires more advanced features not easily met by ChromaDB. They can be considered for a "version 2" of the personal RAG.
|
||||||
|
* **Scalability Considerations for a Unified PUKB:** Given that a "Personal Unified Knowledge Base" might grow significantly with diverse local files (text, PDFs, images, emails, code) and web scrapes, the scalability of the chosen vector database becomes more pertinent.
|
||||||
|
* **ChromaDB & SQLite-VSS:** While excellent for starting and for moderately sized KBs due to their ease of setup, their performance might degrade with many millions of diverse vectors or very complex metadata filtering if the PUKB becomes extremely large. SQLite-VSS, being embedded, also shares resources with the main application.
|
||||||
|
* **Qdrant & Weaviate:** These are designed for larger scale and offer more advanced features like optimized filtering, quantization, and potentially better performance under heavy load or with massive datasets. They typically require running as separate services (often via Docker), which adds a small layer of setup complexity compared to embedded solutions.
|
||||||
|
* **Recommendation Adjustment:** For a PUKB envisioned to be *very large and diverse from the outset*, or if initial prototypes with ChromaDB/SQLite-VSS show performance bottlenecks with representative data volumes, considering Qdrant or Weaviate *earlier* in the development lifecycle (perhaps as an alternative for Phase 1 or a direct step into Phase 2 for database setup) would be prudent. The trade-off is initial simplicity versus future-proofing for scale and advanced features. Migration later is possible but involves effort.
|
||||||
|
|
||||||
|
### IV. Recommended Initial RAG Techniques
|
||||||
|
|
||||||
|
Start with a solid foundation and iteratively add complexity:
|
||||||
|
|
||||||
|
1. **Data Ingestion (`crawl4ai` + Manual):**
|
||||||
|
* Use `crawl4ai` to scrape web content.
|
||||||
|
* Develop a simple way to add manual notes (e.g., Markdown files in a directory).
|
||||||
|
2. **Chunking Strategy:**
|
||||||
|
* **Start with:** Recursive Character Text Splitting (available in LlamaIndex/LangChain) with sensible chunk size (e.g., 500-1000 tokens) and overlap (e.g., 50-100 tokens).
|
||||||
|
* **Consider for V2:** Semantic Chunking or Proposition-based chunking for more meaningful segments, especially for diverse personal data.
|
||||||
|
3. **Embedding Model:**
|
||||||
|
* **Start with:** A good open-source sentence-transformer model (e.g., `all-MiniLM-L6-v2` for a balance of speed and quality, or a model from the MTEB leaderboard). LlamaIndex/LangChain make these easy to use.
|
||||||
|
* **Alternative:** If API costs are not a concern, OpenAI's `text-embedding-3-small` is a strong performer.
|
||||||
|
4. **Vector Storage:**
|
||||||
|
* ChromaDB (persisted to disk).
|
||||||
|
5. **Retrieval Strategy:**
|
||||||
|
* **Start with:** Basic semantic similarity search (top-k retrieval).
|
||||||
|
* **Add Early:** Metadata filtering (e.g., filter by source URL, date added, custom tags).
|
||||||
|
6. **Reranking:**
|
||||||
|
* **Consider adding soon after basic retrieval:** A simple reranker like a Cross-Encoder (e.g., `ms-marco-MiniLM-L-6-v2`) to improve the relevance of top-k results before sending to the LLM. LlamaIndex has `SentenceTransformerRerank`.
|
||||||
|
7. **LLM for Generation:**
|
||||||
|
* A local LLM via Ollama (e.g., Llama 3, Mistral) or a smaller, efficient model.
|
||||||
|
* Or an API-based model if preferred (OpenAI, Anthropic).
|
||||||
|
8. **Prompting:**
|
||||||
|
* Standard RAG prompt: "Use the following context to answer the question. Context: {context_str} Question: {query_str} Answer:"
|
||||||
|
|
||||||
|
### V. Phased Implementation Approach
|
||||||
|
|
||||||
|
1. **Phase 1: Core RAG Pipeline Setup**
|
||||||
|
* Choose framework (LlamaIndex recommended).
|
||||||
|
* Set up `crawl4ai` for data ingestion from a few sample websites.
|
||||||
|
* Implement basic chunking, embedding (local model), and storage in ChromaDB.
|
||||||
|
* Build a simple query engine for semantic search.
|
||||||
|
* Integrate a local LLM for generation.
|
||||||
|
* Test with basic queries.
|
||||||
|
2. **Phase 2: Enhancing Retrieval Quality**
|
||||||
|
* Implement metadata storage and filtering.
|
||||||
|
* Add a reranking step (e.g., Cross-Encoder).
|
||||||
|
* Experiment with different chunking strategies and embedding models.
|
||||||
|
* Develop a simple way to add/update manual notes.
|
||||||
|
3. **Phase 3: Advanced Features & Agent Integration**
|
||||||
|
* Explore more advanced retrieval techniques if needed (e.g., HyDE, fusion retrieval if keyword search is important).
|
||||||
|
* Consider query transformations.
|
||||||
|
* Integrate the RAG system as a tool for the AI agent.
|
||||||
|
* Develop a basic UI or CLI for interaction.
|
||||||
|
* Start thinking about evaluation (even if manual).
|
||||||
|
4. **Phase 4: Long-Term Enhancements**
|
||||||
|
* Explore techniques like CRAG or RAPTOR if complexity is justified.
|
||||||
|
* Implement more robust update/synchronization mechanisms for the knowledge base.
|
||||||
|
* Consider GraphRAG if relationships between personal data points become important.
|
||||||
|
|
||||||
|
### VI. Addressing Work vs. Personal Knowledge Base
|
||||||
|
|
||||||
|
* The recommendation to start with a separate personal KB allows focused development.
|
||||||
|
* If a unified system is desired later:
|
||||||
|
* **Data Separation:** Use distinct metadata tags (e.g., `source_type: "work"` vs. `source_type: "personal"`) within a single vector store.
|
||||||
|
* **Query-Time Filtering:** Ensure queries to the "work" aspect only retrieve from work-tagged documents, and vice-versa. This is critical.
|
||||||
|
* **Agent Context:** The AI agent must be aware of which "mode" it's in (work or personal) to apply the correct filters.
|
||||||
|
* This adds complexity but is feasible with careful design in frameworks like LlamaIndex or LangChain.
|
||||||
|
|
||||||
|
This synthesis provides a roadmap. The next practical step would be to start implementing Phase 1.
|
73
router/__init__.py
Normal file
73
router/__init__.py
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
"""AI Router - Clean abstraction for AI providers."""
|
||||||
|
|
||||||
|
from .base import AIRouter, RouterResponse
|
||||||
|
from .config import RouterConfig, config
|
||||||
|
from .exceptions import (
|
||||||
|
RouterError,
|
||||||
|
AuthenticationError,
|
||||||
|
RateLimitError,
|
||||||
|
ModelNotFoundError,
|
||||||
|
InvalidRequestError,
|
||||||
|
TimeoutError,
|
||||||
|
ContentFilterError,
|
||||||
|
QuotaExceededError,
|
||||||
|
ProviderError,
|
||||||
|
ConfigurationError,
|
||||||
|
)
|
||||||
|
from .gemini import Gemini
|
||||||
|
from .openai_compatible import OpenAI, OpenAICompatible
|
||||||
|
from .cohere import Cohere
|
||||||
|
from .rerank import Rerank, CohereRerank, RerankDocument, RerankResult, RerankResponse
|
||||||
|
from .embed import (
|
||||||
|
Embed,
|
||||||
|
CohereEmbed,
|
||||||
|
GeminiEmbed,
|
||||||
|
OllamaEmbed,
|
||||||
|
CohereEmbedding,
|
||||||
|
GeminiEmbedding,
|
||||||
|
OllamaEmbedding,
|
||||||
|
EmbeddingResponse,
|
||||||
|
create_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Base classes
|
||||||
|
"AIRouter",
|
||||||
|
"RouterResponse",
|
||||||
|
"RouterConfig",
|
||||||
|
"config",
|
||||||
|
# Generation Providers
|
||||||
|
"Gemini",
|
||||||
|
"OpenAI",
|
||||||
|
"OpenAICompatible",
|
||||||
|
"Cohere",
|
||||||
|
# Reranking
|
||||||
|
"Rerank",
|
||||||
|
"CohereRerank",
|
||||||
|
"RerankDocument",
|
||||||
|
"RerankResult",
|
||||||
|
"RerankResponse",
|
||||||
|
# Embeddings
|
||||||
|
"Embed",
|
||||||
|
"CohereEmbed",
|
||||||
|
"GeminiEmbed",
|
||||||
|
"OllamaEmbed",
|
||||||
|
"CohereEmbedding",
|
||||||
|
"GeminiEmbedding",
|
||||||
|
"OllamaEmbedding",
|
||||||
|
"EmbeddingResponse",
|
||||||
|
"create_embedding",
|
||||||
|
# Exceptions
|
||||||
|
"RouterError",
|
||||||
|
"AuthenticationError",
|
||||||
|
"RateLimitError",
|
||||||
|
"ModelNotFoundError",
|
||||||
|
"InvalidRequestError",
|
||||||
|
"TimeoutError",
|
||||||
|
"ContentFilterError",
|
||||||
|
"QuotaExceededError",
|
||||||
|
"ProviderError",
|
||||||
|
"ConfigurationError",
|
||||||
|
]
|
178
router/base.py
Normal file
178
router/base.py
Normal file
|
@ -0,0 +1,178 @@
|
||||||
|
"""Base abstract class for AI router implementations."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RouterResponse:
|
||||||
|
"""Standardized response object for all AI providers."""
|
||||||
|
content: str
|
||||||
|
model: str
|
||||||
|
provider: str
|
||||||
|
latency: float = 0.0
|
||||||
|
input_tokens: Optional[int] = None
|
||||||
|
output_tokens: Optional[int] = None
|
||||||
|
total_tokens: Optional[int] = None
|
||||||
|
cost: Optional[float] = None
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
raw_response: Optional[Any] = None
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.content
|
||||||
|
|
||||||
|
|
||||||
|
class AIRouter(ABC):
|
||||||
|
"""Abstract base class for AI provider routers."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the router with model and configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model identifier for the provider
|
||||||
|
api_key: API key for authentication (optional if set in environment)
|
||||||
|
**kwargs: Additional provider-specific configuration
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
self.api_key = api_key
|
||||||
|
self.config = kwargs
|
||||||
|
self.provider = self.__class__.__name__.lower()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
|
||||||
|
"""Prepare provider-specific request parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of request parameters
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
|
||||||
|
"""Parse provider response into RouterResponse.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_response: Raw response from provider
|
||||||
|
latency: Request latency in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Standardized RouterResponse
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _make_request(self, request_params: Dict[str, Any]) -> Any:
|
||||||
|
"""Make synchronous request to provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_params: Provider-specific request parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Raw provider response
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
|
||||||
|
"""Make asynchronous request to provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_params: Provider-specific request parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Raw provider response
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def call(self, prompt: str, **kwargs: Any) -> RouterResponse:
|
||||||
|
"""Make a synchronous call to the AI provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters (model, temperature, max_tokens, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RouterResponse with the result
|
||||||
|
"""
|
||||||
|
# Merge default config with call-specific parameters
|
||||||
|
params = {**self.config, **kwargs}
|
||||||
|
|
||||||
|
# Prepare request
|
||||||
|
request_params = self._prepare_request(prompt, **params)
|
||||||
|
|
||||||
|
# Make request and measure latency
|
||||||
|
start_time = time.time()
|
||||||
|
raw_response = self._make_request(request_params)
|
||||||
|
latency = time.time() - start_time
|
||||||
|
|
||||||
|
# Parse and return response
|
||||||
|
return self._parse_response(raw_response, latency)
|
||||||
|
|
||||||
|
async def acall(self, prompt: str, **kwargs: Any) -> RouterResponse:
|
||||||
|
"""Make an asynchronous call to the AI provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters (model, temperature, max_tokens, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RouterResponse with the result
|
||||||
|
"""
|
||||||
|
# Merge default config with call-specific parameters
|
||||||
|
params = {**self.config, **kwargs}
|
||||||
|
|
||||||
|
# Prepare request
|
||||||
|
request_params = self._prepare_request(prompt, **params)
|
||||||
|
|
||||||
|
# Make request and measure latency
|
||||||
|
start_time = time.time()
|
||||||
|
raw_response = await self._make_async_request(request_params)
|
||||||
|
latency = time.time() - start_time
|
||||||
|
|
||||||
|
# Parse and return response
|
||||||
|
return self._parse_response(raw_response, latency)
|
||||||
|
|
||||||
|
def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
|
||||||
|
"""Stream responses from the AI provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
RouterResponse chunks
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: If streaming is not supported
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{self.provider} does not support streaming yet")
|
||||||
|
|
||||||
|
async def astream(
|
||||||
|
self, prompt: str, **kwargs: Any
|
||||||
|
) -> AsyncGenerator[RouterResponse, None]:
|
||||||
|
"""Asynchronously stream responses from the AI provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
RouterResponse chunks
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: If async streaming is not supported
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{self.provider} does not support async streaming yet")
|
||||||
|
yield # Required for async generator type hint
|
259
router/cohere.py
Normal file
259
router/cohere.py
Normal file
|
@ -0,0 +1,259 @@
|
||||||
|
"""Cohere provider implementation."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional, AsyncGenerator, Generator
|
||||||
|
import time
|
||||||
|
|
||||||
|
from .base import AIRouter, RouterResponse
|
||||||
|
from .config import config
|
||||||
|
from .exceptions import (
|
||||||
|
ConfigurationError,
|
||||||
|
map_provider_error
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Cohere(AIRouter):
|
||||||
|
"""Router for Cohere models."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "command-r-plus",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize Cohere router.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Cohere model to use (command-r-plus, command-r, etc.)
|
||||||
|
api_key: Cohere API key (optional if set in environment)
|
||||||
|
**kwargs: Additional configuration
|
||||||
|
"""
|
||||||
|
# Get API key from config if not provided
|
||||||
|
if not api_key:
|
||||||
|
api_key = config.get_api_key("cohere")
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"Cohere API key not found. Set COHERE_API_KEY environment variable.",
|
||||||
|
provider="cohere"
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(model=model, api_key=api_key, **kwargs)
|
||||||
|
|
||||||
|
# Initialize Cohere client
|
||||||
|
try:
|
||||||
|
import cohere
|
||||||
|
# For v5.15.0, use the standard Client
|
||||||
|
self.client = cohere.Client(api_key)
|
||||||
|
self.async_client = cohere.AsyncClient(api_key)
|
||||||
|
except ImportError:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"cohere package not installed. Install with: pip install cohere",
|
||||||
|
provider="cohere"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
|
||||||
|
"""Prepare Cohere request parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Request parameters
|
||||||
|
"""
|
||||||
|
# Build request parameters for v5.15.0 API
|
||||||
|
params = {
|
||||||
|
"model": kwargs.get("model", self.model),
|
||||||
|
"message": prompt, # v5.15.0 uses 'message' not 'messages'
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional parameters
|
||||||
|
if "temperature" in kwargs:
|
||||||
|
params["temperature"] = kwargs["temperature"]
|
||||||
|
elif hasattr(config, "default_temperature"):
|
||||||
|
params["temperature"] = config.default_temperature
|
||||||
|
|
||||||
|
if "max_tokens" in kwargs:
|
||||||
|
params["max_tokens"] = kwargs["max_tokens"]
|
||||||
|
elif hasattr(config, "default_max_tokens") and config.default_max_tokens:
|
||||||
|
params["max_tokens"] = config.default_max_tokens
|
||||||
|
|
||||||
|
# Cohere uses 'p' instead of 'top_p'
|
||||||
|
if "p" in kwargs:
|
||||||
|
params["p"] = kwargs["p"]
|
||||||
|
elif "top_p" in kwargs:
|
||||||
|
params["p"] = kwargs["top_p"]
|
||||||
|
elif hasattr(config, "default_top_p"):
|
||||||
|
params["p"] = config.default_top_p
|
||||||
|
|
||||||
|
# Cohere uses 'k' instead of 'top_k'
|
||||||
|
if "k" in kwargs:
|
||||||
|
params["k"] = kwargs["k"]
|
||||||
|
elif "top_k" in kwargs:
|
||||||
|
params["k"] = kwargs["top_k"]
|
||||||
|
|
||||||
|
# Other Cohere-specific parameters for v5.15.0
|
||||||
|
for key in ["chat_history", "preamble", "conversation_id", "prompt_truncation", "connectors", "search_queries_only", "documents", "tools", "tool_results", "stop_sequences", "seed"]:
|
||||||
|
if key in kwargs:
|
||||||
|
params[key] = kwargs[key]
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
|
||||||
|
"""Parse Cohere response into RouterResponse.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_response: Raw Cohere response
|
||||||
|
latency: Request latency
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RouterResponse
|
||||||
|
"""
|
||||||
|
# Extract text content - v5.15.0 has 'text' attribute directly
|
||||||
|
content = getattr(raw_response, "text", "")
|
||||||
|
|
||||||
|
# Get token counts from meta/usage
|
||||||
|
input_tokens = None
|
||||||
|
output_tokens = None
|
||||||
|
|
||||||
|
# Try different token count locations based on v5.15.0 structure
|
||||||
|
if hasattr(raw_response, "meta") and hasattr(raw_response.meta, "billed_units"):
|
||||||
|
billed = raw_response.meta.billed_units
|
||||||
|
input_tokens = getattr(billed, "input_tokens", None)
|
||||||
|
output_tokens = getattr(billed, "output_tokens", None)
|
||||||
|
elif hasattr(raw_response, "usage") and hasattr(raw_response.usage, "billed_units"):
|
||||||
|
billed = raw_response.usage.billed_units
|
||||||
|
input_tokens = getattr(billed, "input_tokens", None)
|
||||||
|
output_tokens = getattr(billed, "output_tokens", None)
|
||||||
|
|
||||||
|
# Calculate cost if available
|
||||||
|
cost = None
|
||||||
|
if input_tokens and output_tokens and config.track_costs:
|
||||||
|
cost = config.calculate_cost(self.model, input_tokens, output_tokens)
|
||||||
|
|
||||||
|
# Extract finish reason
|
||||||
|
finish_reason = "stop"
|
||||||
|
if hasattr(raw_response, "finish_reason"):
|
||||||
|
finish_reason = raw_response.finish_reason
|
||||||
|
|
||||||
|
return RouterResponse(
|
||||||
|
content=content,
|
||||||
|
model=self.model,
|
||||||
|
provider="cohere",
|
||||||
|
latency=latency,
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
output_tokens=output_tokens,
|
||||||
|
total_tokens=(input_tokens + output_tokens) if input_tokens and output_tokens else None,
|
||||||
|
cost=cost,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
metadata={
|
||||||
|
"id": getattr(raw_response, "id", None),
|
||||||
|
"generation_id": getattr(raw_response, "generation_id", None),
|
||||||
|
"citations": getattr(raw_response, "citations", None),
|
||||||
|
"documents": getattr(raw_response, "documents", None),
|
||||||
|
"search_results": getattr(raw_response, "search_results", None),
|
||||||
|
"search_queries": getattr(raw_response, "search_queries", None),
|
||||||
|
"tool_calls": getattr(raw_response, "tool_calls", None),
|
||||||
|
},
|
||||||
|
raw_response=raw_response
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_request(self, request_params: Dict[str, Any]) -> Any:
|
||||||
|
"""Make synchronous request to Cohere.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_params: Request parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Raw Cohere response
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = self.client.chat(**request_params)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("cohere", e)
|
||||||
|
|
||||||
|
async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
|
||||||
|
"""Make asynchronous request to Cohere.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_params: Request parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Raw Cohere response
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await self.async_client.chat(**request_params)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("cohere", e)
|
||||||
|
|
||||||
|
def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
|
||||||
|
"""Stream responses from Cohere.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
RouterResponse chunks
|
||||||
|
"""
|
||||||
|
params = {**self.config, **kwargs}
|
||||||
|
request_params = self._prepare_request(prompt, **params)
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
stream = self.client.chat_stream(**request_params)
|
||||||
|
|
||||||
|
for event in stream:
|
||||||
|
# v5.15.0 uses event_type and text directly
|
||||||
|
if hasattr(event, "event_type") and event.event_type == "text-generation":
|
||||||
|
content = getattr(event, "text", "")
|
||||||
|
if content:
|
||||||
|
yield RouterResponse(
|
||||||
|
content=content,
|
||||||
|
model=self.model,
|
||||||
|
provider="cohere",
|
||||||
|
latency=time.time() - start_time,
|
||||||
|
finish_reason=None,
|
||||||
|
metadata={"event_type": event.event_type},
|
||||||
|
raw_response=event
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("cohere", e)
|
||||||
|
|
||||||
|
async def astream(
|
||||||
|
self, prompt: str, **kwargs: Any
|
||||||
|
) -> AsyncGenerator[RouterResponse, None]:
|
||||||
|
"""Asynchronously stream responses from Cohere.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
RouterResponse chunks
|
||||||
|
"""
|
||||||
|
params = {**self.config, **kwargs}
|
||||||
|
request_params = self._prepare_request(prompt, **params)
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
stream = self.async_client.chat_stream(**request_params)
|
||||||
|
|
||||||
|
async for event in stream:
|
||||||
|
# v5.15.0 uses event_type and text directly
|
||||||
|
if hasattr(event, "event_type") and event.event_type == "text-generation":
|
||||||
|
content = getattr(event, "text", "")
|
||||||
|
if content:
|
||||||
|
yield RouterResponse(
|
||||||
|
content=content,
|
||||||
|
model=self.model,
|
||||||
|
provider="cohere",
|
||||||
|
latency=time.time() - start_time,
|
||||||
|
finish_reason=None,
|
||||||
|
metadata={"event_type": event.event_type},
|
||||||
|
raw_response=event
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("cohere", e)
|
259
router/config.py
Normal file
259
router/config.py
Normal file
|
@ -0,0 +1,259 @@
|
||||||
|
"""Configuration management for AI router."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RouterConfig:
|
||||||
|
"""Configuration for AI router."""
|
||||||
|
|
||||||
|
# API Keys
|
||||||
|
openai_api_key: Optional[str] = field(default_factory=lambda: os.getenv("OPENAI_API_KEY"))
|
||||||
|
gemini_api_key: Optional[str] = field(default_factory=lambda: os.getenv("GEMINI_API_KEY"))
|
||||||
|
google_api_key: Optional[str] = field(default_factory=lambda: os.getenv("GOOGLE_API_KEY"))
|
||||||
|
cohere_api_key: Optional[str] = field(default_factory=lambda: os.getenv("COHERE_API_KEY"))
|
||||||
|
|
||||||
|
# Ollama configuration
|
||||||
|
ollama_base_url: str = field(default_factory=lambda: os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"))
|
||||||
|
|
||||||
|
# Available Ollama embedding models
|
||||||
|
ollama_embedding_models: List[str] = field(default_factory=lambda: [
|
||||||
|
"mxbai-embed-large:latest",
|
||||||
|
"nomic-embed-text:latest",
|
||||||
|
])
|
||||||
|
|
||||||
|
# Default model settings
|
||||||
|
default_temperature: float = 0.7
|
||||||
|
default_max_tokens: Optional[int] = None
|
||||||
|
default_top_p: float = 1.0
|
||||||
|
default_timeout: int = 160 # seconds
|
||||||
|
|
||||||
|
# Retry configuration
|
||||||
|
max_retries: int = 3
|
||||||
|
retry_delay: float = 1.0
|
||||||
|
retry_backoff: float = 2.0
|
||||||
|
|
||||||
|
# Cost tracking
|
||||||
|
track_costs: bool = False
|
||||||
|
|
||||||
|
# Cost per million tokens (prices in USD)
|
||||||
|
cost_per_1m_input_tokens: Dict[str, float] = field(default_factory=lambda: {
|
||||||
|
# OpenAI Compatible Models
|
||||||
|
"gpt-4o": 2.5,
|
||||||
|
"azure/gpt-4.1-new": 2.0,
|
||||||
|
"o3": 2.0,
|
||||||
|
"o4-mini": 1.1,
|
||||||
|
"claude-opus-4": 15.0,
|
||||||
|
"claude-sonnet-4": 3.0,
|
||||||
|
|
||||||
|
# Gemini Models
|
||||||
|
"gemini-2.5-pro": 1.25,
|
||||||
|
"gemini-2.5-flash": 0.3,
|
||||||
|
"gemini-2.0-flash-001": 0.3, # Alias for current default
|
||||||
|
|
||||||
|
# Cohere Models
|
||||||
|
"command-a": 2.5,
|
||||||
|
"command-r-plus": 2.5,
|
||||||
|
})
|
||||||
|
|
||||||
|
cost_per_1m_output_tokens: Dict[str, float] = field(default_factory=lambda: {
|
||||||
|
# OpenAI Compatible Models
|
||||||
|
"gpt-4o": 10.0,
|
||||||
|
"azure/gpt-4.1-new": 8.0,
|
||||||
|
"o3": 8.0,
|
||||||
|
"o4-mini": 4.4,
|
||||||
|
"claude-opus-4": 75.0,
|
||||||
|
"claude-sonnet-4": 15.0,
|
||||||
|
|
||||||
|
# Gemini Models
|
||||||
|
"gemini-2.5-pro": 10.0,
|
||||||
|
"gemini-2.5-flash": 2.5,
|
||||||
|
"gemini-2.0-flash-001": 2.5, # Alias for current default
|
||||||
|
|
||||||
|
# Cohere Models
|
||||||
|
"command-a": 10.0,
|
||||||
|
"command-r-plus": 10.0,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Cached input pricing for specific models (per million tokens)
|
||||||
|
cost_per_1m_cached_input_tokens: Dict[str, float] = field(default_factory=lambda: {
|
||||||
|
"azure/gpt-4.1-new": 0.5,
|
||||||
|
"o3": 0.5,
|
||||||
|
"o4-mini": 0.28,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Reranking costs (per 1k searches)
|
||||||
|
cost_per_1k_rerank_searches: Dict[str, float] = field(default_factory=lambda: {
|
||||||
|
"rerank-3.5": 2.0,
|
||||||
|
"rerank-english-v3.0": 2.0,
|
||||||
|
"rerank-multilingual-v3.0": 2.0,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Embedding costs (per million tokens)
|
||||||
|
cost_per_1m_embed_tokens: Dict[str, float] = field(default_factory=lambda: {
|
||||||
|
"embed-english-v3.0": 0.12,
|
||||||
|
"embed-multilingual-v3.0": 0.12,
|
||||||
|
"embed-english-light-v3.0": 0.12,
|
||||||
|
"text-embedding-004": 0.12, # Google's embedding model
|
||||||
|
# Ollama models are free (local)
|
||||||
|
"mxbai-embed-large:latest": 0.0,
|
||||||
|
"nomic-embed-text:latest": 0.0,
|
||||||
|
"nomic-embed-text:137m-v1.5-fp16": 0.0,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Image embedding costs (per million image tokens)
|
||||||
|
cost_per_1m_embed_image_tokens: Dict[str, float] = field(default_factory=lambda: {
|
||||||
|
"embed-english-v3.0": 0.47,
|
||||||
|
"embed-multilingual-v3.0": 0.47,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
log_requests: bool = False
|
||||||
|
log_responses: bool = False
|
||||||
|
log_errors: bool = True
|
||||||
|
|
||||||
|
def get_api_key(self, provider: str) -> Optional[str]:
|
||||||
|
"""Get API key for a specific provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Provider name (openai, gemini, cohere, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API key if available
|
||||||
|
"""
|
||||||
|
provider = provider.lower()
|
||||||
|
if provider == "openai":
|
||||||
|
return self.openai_api_key
|
||||||
|
elif provider in ["gemini", "google"]:
|
||||||
|
return self.gemini_api_key or self.google_api_key
|
||||||
|
elif provider == "cohere":
|
||||||
|
return self.cohere_api_key
|
||||||
|
return None
|
||||||
|
|
||||||
|
def calculate_cost(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
input_tokens: int,
|
||||||
|
output_tokens: int,
|
||||||
|
cached_input: bool = False
|
||||||
|
) -> Optional[float]:
|
||||||
|
"""Calculate cost for a request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model identifier
|
||||||
|
input_tokens: Number of input tokens
|
||||||
|
output_tokens: Number of output tokens
|
||||||
|
cached_input: Whether input tokens are cached (for compatible models)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total cost in USD, or None if cost data not available
|
||||||
|
"""
|
||||||
|
if not self.track_costs:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if using cached pricing
|
||||||
|
if cached_input and model in self.cost_per_1m_cached_input_tokens:
|
||||||
|
input_cost_per_1m = self.cost_per_1m_cached_input_tokens.get(model)
|
||||||
|
else:
|
||||||
|
input_cost_per_1m = self.cost_per_1m_input_tokens.get(model)
|
||||||
|
|
||||||
|
output_cost_per_1m = self.cost_per_1m_output_tokens.get(model)
|
||||||
|
|
||||||
|
if input_cost_per_1m is None or output_cost_per_1m is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
input_cost = (input_tokens / 1_000_000) * input_cost_per_1m
|
||||||
|
output_cost = (output_tokens / 1_000_000) * output_cost_per_1m
|
||||||
|
|
||||||
|
return round(input_cost + output_cost, 6)
|
||||||
|
|
||||||
|
def calculate_rerank_cost(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
num_searches: int
|
||||||
|
) -> Optional[float]:
|
||||||
|
"""Calculate cost for reranking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Rerank model identifier
|
||||||
|
num_searches: Number of searches performed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total cost in USD, or None if cost data not available
|
||||||
|
"""
|
||||||
|
if not self.track_costs:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cost_per_1k = self.cost_per_1k_rerank_searches.get(model)
|
||||||
|
if cost_per_1k is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return round((num_searches / 1000) * cost_per_1k, 6)
|
||||||
|
|
||||||
|
def calculate_embed_cost(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
num_tokens: int,
|
||||||
|
is_image: bool = False
|
||||||
|
) -> Optional[float]:
|
||||||
|
"""Calculate cost for embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Embedding model identifier
|
||||||
|
num_tokens: Number of tokens to embed
|
||||||
|
is_image: Whether these are image tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total cost in USD, or None if cost data not available
|
||||||
|
"""
|
||||||
|
if not self.track_costs:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if is_image:
|
||||||
|
cost_per_1m = self.cost_per_1m_embed_image_tokens.get(model)
|
||||||
|
else:
|
||||||
|
cost_per_1m = self.cost_per_1m_embed_tokens.get(model)
|
||||||
|
|
||||||
|
if cost_per_1m is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return round((num_tokens / 1_000_000) * cost_per_1m, 6)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(cls) -> "RouterConfig":
|
||||||
|
"""Create config from environment variables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RouterConfig instance
|
||||||
|
"""
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert config to dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary representation
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"openai_api_key": "***" if self.openai_api_key else None,
|
||||||
|
"gemini_api_key": "***" if self.gemini_api_key else None,
|
||||||
|
"google_api_key": "***" if self.google_api_key else None,
|
||||||
|
"cohere_api_key": "***" if self.cohere_api_key else None,
|
||||||
|
"default_temperature": self.default_temperature,
|
||||||
|
"default_max_tokens": self.default_max_tokens,
|
||||||
|
"default_top_p": self.default_top_p,
|
||||||
|
"default_timeout": self.default_timeout,
|
||||||
|
"max_retries": self.max_retries,
|
||||||
|
"retry_delay": self.retry_delay,
|
||||||
|
"retry_backoff": self.retry_backoff,
|
||||||
|
"track_costs": self.track_costs,
|
||||||
|
"log_requests": self.log_requests,
|
||||||
|
"log_responses": self.log_responses,
|
||||||
|
"log_errors": self.log_errors,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Global config instance
|
||||||
|
config = RouterConfig.from_env()
|
742
router/embed.py
Normal file
742
router/embed.py
Normal file
|
@ -0,0 +1,742 @@
|
||||||
|
"""Embedding router implementation for multiple providers."""
|
||||||
|
|
||||||
|
from typing import List, Dict, Any, Optional, Union, Literal
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from .config import config
|
||||||
|
from .exceptions import (
|
||||||
|
ConfigurationError,
|
||||||
|
map_provider_error
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmbeddingResponse:
|
||||||
|
"""Response from embedding operation."""
|
||||||
|
embeddings: List[List[float]] # List of embedding vectors
|
||||||
|
model: str
|
||||||
|
provider: str
|
||||||
|
latency: float
|
||||||
|
dimension: int # Dimension of embeddings
|
||||||
|
num_inputs: int # Number of inputs embedded
|
||||||
|
total_tokens: Optional[int] = None
|
||||||
|
cost: Optional[float] = None
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
raw_response: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEmbedding(ABC):
|
||||||
|
"""Base class for embedding routers."""
|
||||||
|
|
||||||
|
def __init__(self, model: str, api_key: str, **kwargs: Any) -> None:
|
||||||
|
"""Initialize embedding router.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model identifier
|
||||||
|
api_key: API key
|
||||||
|
**kwargs: Additional configuration
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
self.api_key = api_key
|
||||||
|
self.config = kwargs
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def embed(
|
||||||
|
self,
|
||||||
|
texts: Union[str, List[str]],
|
||||||
|
**kwargs: Any
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
"""Generate embeddings for texts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or list of texts to embed
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingResponse
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def aembed(
|
||||||
|
self,
|
||||||
|
texts: Union[str, List[str]],
|
||||||
|
**kwargs: Any
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
"""Asynchronously generate embeddings for texts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or list of texts to embed
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingResponse
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CohereEmbedding(BaseEmbedding):
|
||||||
|
"""Router for Cohere embedding models."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "embed-english-v3.0",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize Cohere embedding router.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Embedding model (embed-english-v3.0, embed-multilingual-v3.0, etc.)
|
||||||
|
api_key: Cohere API key (optional if set in environment)
|
||||||
|
**kwargs: Additional configuration
|
||||||
|
"""
|
||||||
|
# Get API key from config if not provided
|
||||||
|
if not api_key:
|
||||||
|
api_key = config.get_api_key("cohere")
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"Cohere API key not found. Set COHERE_API_KEY environment variable.",
|
||||||
|
provider="cohere"
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(model, api_key, **kwargs)
|
||||||
|
|
||||||
|
# Initialize Cohere client
|
||||||
|
try:
|
||||||
|
import cohere
|
||||||
|
# For v5.15.0, use the standard Client
|
||||||
|
self.client = cohere.Client(api_key)
|
||||||
|
self.async_client = cohere.AsyncClient(api_key)
|
||||||
|
except ImportError:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"cohere package not installed. Install with: pip install cohere",
|
||||||
|
provider="cohere"
|
||||||
|
)
|
||||||
|
|
||||||
|
def embed(
|
||||||
|
self,
|
||||||
|
texts: Union[str, List[str]],
|
||||||
|
input_type: Optional[Literal["search_document", "search_query", "classification", "clustering"]] = None,
|
||||||
|
truncate: Optional[Literal["NONE", "START", "END"]] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
"""Generate embeddings for texts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or list of texts to embed
|
||||||
|
input_type: Purpose of embeddings (affects vector space)
|
||||||
|
truncate: How to handle inputs longer than max tokens
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingResponse
|
||||||
|
"""
|
||||||
|
# Ensure texts is a list
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
|
||||||
|
params = self._prepare_request(texts, input_type, truncate, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
response = self.client.embed(**params)
|
||||||
|
latency = time.time() - start_time
|
||||||
|
|
||||||
|
return self._parse_response(response, latency, len(texts))
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("cohere", e)
|
||||||
|
|
||||||
|
async def aembed(
|
||||||
|
self,
|
||||||
|
texts: Union[str, List[str]],
|
||||||
|
input_type: Optional[Literal["search_document", "search_query", "classification", "clustering"]] = None,
|
||||||
|
truncate: Optional[Literal["NONE", "START", "END"]] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
"""Asynchronously generate embeddings for texts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or list of texts to embed
|
||||||
|
input_type: Purpose of embeddings
|
||||||
|
truncate: How to handle long inputs
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingResponse
|
||||||
|
"""
|
||||||
|
# Ensure texts is a list
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
|
||||||
|
params = self._prepare_request(texts, input_type, truncate, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
response = await self.async_client.embed(**params)
|
||||||
|
latency = time.time() - start_time
|
||||||
|
|
||||||
|
return self._parse_response(response, latency, len(texts))
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("cohere", e)
|
||||||
|
|
||||||
|
def _prepare_request(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
input_type: Optional[str],
|
||||||
|
truncate: Optional[str],
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Prepare embed request parameters."""
|
||||||
|
params = {
|
||||||
|
"model": kwargs.get("model", self.model),
|
||||||
|
"texts": texts,
|
||||||
|
}
|
||||||
|
|
||||||
|
if input_type:
|
||||||
|
params["input_type"] = input_type
|
||||||
|
|
||||||
|
if truncate:
|
||||||
|
params["truncate"] = truncate
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
def _parse_response(
|
||||||
|
self,
|
||||||
|
raw_response: Any,
|
||||||
|
latency: float,
|
||||||
|
num_inputs: int
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
"""Parse Cohere embed response."""
|
||||||
|
# Extract embeddings
|
||||||
|
embeddings = []
|
||||||
|
if hasattr(raw_response, "embeddings"):
|
||||||
|
embeddings = raw_response.embeddings
|
||||||
|
|
||||||
|
# Get dimension from first embedding
|
||||||
|
dimension = len(embeddings[0]) if embeddings else 0
|
||||||
|
|
||||||
|
# Get token count from response metadata
|
||||||
|
total_tokens = None
|
||||||
|
if hasattr(raw_response, "meta") and hasattr(raw_response.meta, "billed_units"):
|
||||||
|
total_tokens = getattr(raw_response.meta.billed_units, "input_tokens", None)
|
||||||
|
|
||||||
|
# Calculate cost
|
||||||
|
cost = None
|
||||||
|
if config.track_costs and total_tokens:
|
||||||
|
cost = config.calculate_embed_cost(self.model, total_tokens)
|
||||||
|
|
||||||
|
return EmbeddingResponse(
|
||||||
|
embeddings=embeddings,
|
||||||
|
model=self.model,
|
||||||
|
provider="cohere",
|
||||||
|
latency=latency,
|
||||||
|
dimension=dimension,
|
||||||
|
num_inputs=num_inputs,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
cost=cost,
|
||||||
|
metadata={
|
||||||
|
"id": getattr(raw_response, "id", None),
|
||||||
|
"response_type": getattr(raw_response, "response_type", None),
|
||||||
|
},
|
||||||
|
raw_response=raw_response
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiEmbedding(BaseEmbedding):
|
||||||
|
"""Router for Google Gemini embedding models."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "text-embedding-004",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize Gemini embedding router.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Embedding model (text-embedding-004, etc.)
|
||||||
|
api_key: Google API key (optional if set in environment)
|
||||||
|
**kwargs: Additional configuration
|
||||||
|
"""
|
||||||
|
# Get API key from config if not provided
|
||||||
|
if not api_key:
|
||||||
|
api_key = config.get_api_key("gemini")
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"Gemini API key not found. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.",
|
||||||
|
provider="gemini"
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(model, api_key, **kwargs)
|
||||||
|
|
||||||
|
# Initialize Gemini client using new google.genai library
|
||||||
|
try:
|
||||||
|
from google import genai
|
||||||
|
from google.genai import types
|
||||||
|
self.genai = genai
|
||||||
|
self.types = types
|
||||||
|
self.client = genai.Client(api_key=api_key)
|
||||||
|
except ImportError:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"google-genai package not installed. Install with: pip install google-genai",
|
||||||
|
provider="gemini"
|
||||||
|
)
|
||||||
|
|
||||||
|
def embed(
|
||||||
|
self,
|
||||||
|
texts: Union[str, List[str]],
|
||||||
|
task_type: Optional[str] = None,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
"""Generate embeddings for texts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or list of texts to embed
|
||||||
|
task_type: Task type for embeddings
|
||||||
|
title: Optional title for context
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingResponse
|
||||||
|
"""
|
||||||
|
# Ensure texts is a list
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# The Google genai SDK supports batch embedding
|
||||||
|
# Pass all texts at once for better performance
|
||||||
|
params = {
|
||||||
|
"model": kwargs.get("model", self.model),
|
||||||
|
"contents": texts, # Can pass multiple texts
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional config
|
||||||
|
config_params = {}
|
||||||
|
if task_type:
|
||||||
|
config_params["task_type"] = task_type
|
||||||
|
if title:
|
||||||
|
config_params["title"] = title
|
||||||
|
|
||||||
|
if config_params:
|
||||||
|
params["config"] = self.types.EmbedContentConfig(**config_params)
|
||||||
|
|
||||||
|
response = self.client.models.embed_content(**params)
|
||||||
|
|
||||||
|
# Extract embeddings from response
|
||||||
|
# The response always has an 'embeddings' attribute (even for single text)
|
||||||
|
embeddings = []
|
||||||
|
|
||||||
|
# Check if embeddings exist
|
||||||
|
if response.embeddings is None:
|
||||||
|
raise ValueError("No embeddings returned in response")
|
||||||
|
|
||||||
|
# response.embeddings is always present
|
||||||
|
for emb in response.embeddings:
|
||||||
|
# ContentEmbedding objects have a 'values' attribute containing the float list
|
||||||
|
if hasattr(emb, "values") and emb.values is not None:
|
||||||
|
embeddings.append(list(emb.values))
|
||||||
|
elif hasattr(emb, "__iter__"):
|
||||||
|
# If the embedding is directly iterable (list-like)
|
||||||
|
try:
|
||||||
|
embeddings.append(list(emb))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Could not extract embedding values from {type(emb)}: {e}")
|
||||||
|
else:
|
||||||
|
print(f"Warning: Unknown embedding format: {type(emb)}, attributes: {dir(emb)}")
|
||||||
|
|
||||||
|
latency = time.time() - start_time
|
||||||
|
|
||||||
|
# Token counting is not directly available in the response
|
||||||
|
# Set to None for now
|
||||||
|
total_tokens = None
|
||||||
|
|
||||||
|
return self._create_response(
|
||||||
|
embeddings=embeddings,
|
||||||
|
latency=latency,
|
||||||
|
num_inputs=len(texts),
|
||||||
|
total_tokens=total_tokens
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("gemini", e)
|
||||||
|
|
||||||
|
async def aembed(
|
||||||
|
self,
|
||||||
|
texts: Union[str, List[str]],
|
||||||
|
task_type: Optional[str] = None,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
"""Asynchronously generate embeddings for texts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or list of texts to embed
|
||||||
|
task_type: Task type for embeddings
|
||||||
|
title: Optional title for context
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingResponse
|
||||||
|
"""
|
||||||
|
# Ensure texts is a list
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# The Google genai SDK supports batch embedding
|
||||||
|
# Pass all texts at once for better performance
|
||||||
|
params = {
|
||||||
|
"model": kwargs.get("model", self.model),
|
||||||
|
"contents": texts, # Can pass multiple texts
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional config
|
||||||
|
config_params = {}
|
||||||
|
if task_type:
|
||||||
|
config_params["task_type"] = task_type
|
||||||
|
if title:
|
||||||
|
config_params["title"] = title
|
||||||
|
|
||||||
|
if config_params:
|
||||||
|
params["config"] = self.types.EmbedContentConfig(**config_params)
|
||||||
|
|
||||||
|
response = await self.client.aio.models.embed_content(**params)
|
||||||
|
|
||||||
|
# Extract embeddings from response
|
||||||
|
# The response always has an 'embeddings' attribute (even for single text)
|
||||||
|
embeddings = []
|
||||||
|
|
||||||
|
# Check if embeddings exist
|
||||||
|
if response.embeddings is None:
|
||||||
|
raise ValueError("No embeddings returned in response")
|
||||||
|
|
||||||
|
# response.embeddings is always present
|
||||||
|
for emb in response.embeddings:
|
||||||
|
# ContentEmbedding objects have a 'values' attribute containing the float list
|
||||||
|
if hasattr(emb, "values") and emb.values is not None:
|
||||||
|
embeddings.append(list(emb.values))
|
||||||
|
elif hasattr(emb, "__iter__"):
|
||||||
|
# If the embedding is directly iterable (list-like)
|
||||||
|
try:
|
||||||
|
embeddings.append(list(emb))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Could not extract embedding values from {type(emb)}: {e}")
|
||||||
|
else:
|
||||||
|
print(f"Warning: Unknown embedding format: {type(emb)}, attributes: {dir(emb)}")
|
||||||
|
|
||||||
|
latency = time.time() - start_time
|
||||||
|
|
||||||
|
# Token counting is not directly available in the response
|
||||||
|
# Set to None for now
|
||||||
|
total_tokens = None
|
||||||
|
|
||||||
|
return self._create_response(
|
||||||
|
embeddings=embeddings,
|
||||||
|
latency=latency,
|
||||||
|
num_inputs=len(texts),
|
||||||
|
total_tokens=total_tokens
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("gemini", e)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_response(
|
||||||
|
self,
|
||||||
|
embeddings: List[List[float]],
|
||||||
|
latency: float,
|
||||||
|
num_inputs: int,
|
||||||
|
total_tokens: Optional[int] = None
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
"""Create embedding response."""
|
||||||
|
# Get dimension from first embedding
|
||||||
|
dimension = len(embeddings[0]) if embeddings else 0
|
||||||
|
|
||||||
|
# Calculate cost
|
||||||
|
cost = None
|
||||||
|
if config.track_costs and total_tokens:
|
||||||
|
cost = config.calculate_embed_cost(self.model, total_tokens)
|
||||||
|
|
||||||
|
return EmbeddingResponse(
|
||||||
|
embeddings=embeddings,
|
||||||
|
model=self.model,
|
||||||
|
provider="gemini",
|
||||||
|
latency=latency,
|
||||||
|
dimension=dimension,
|
||||||
|
num_inputs=num_inputs,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
cost=cost,
|
||||||
|
metadata={},
|
||||||
|
raw_response=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaEmbedding(BaseEmbedding):
|
||||||
|
"""Router for Ollama local embedding models."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "nomic-embed-text:latest",
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize Ollama embedding router.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Ollama embedding model (mxbai-embed-large, nomic-embed-text, etc.)
|
||||||
|
base_url: Ollama API base URL (default: http://localhost:11434)
|
||||||
|
**kwargs: Additional configuration
|
||||||
|
"""
|
||||||
|
# No API key needed for local Ollama
|
||||||
|
super().__init__(model, api_key="local", **kwargs)
|
||||||
|
|
||||||
|
# Get base URL from config or parameter
|
||||||
|
self.base_url = base_url or config.ollama_base_url
|
||||||
|
self.embeddings_url = f"{self.base_url}/api/embeddings"
|
||||||
|
|
||||||
|
# Check if Ollama is available
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
self.requests = requests
|
||||||
|
except ImportError:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"requests package not installed. Install with: pip install requests",
|
||||||
|
provider="ollama"
|
||||||
|
)
|
||||||
|
|
||||||
|
# For async support
|
||||||
|
try:
|
||||||
|
import httpx
|
||||||
|
self.httpx = httpx
|
||||||
|
except ImportError:
|
||||||
|
self.httpx = None # Async support is optional
|
||||||
|
|
||||||
|
def embed(
|
||||||
|
self,
|
||||||
|
texts: Union[str, List[str]],
|
||||||
|
**kwargs: Any
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
"""Generate embeddings for texts using Ollama.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or list of texts to embed
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingResponse
|
||||||
|
"""
|
||||||
|
# Ensure texts is a list
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
embeddings = []
|
||||||
|
|
||||||
|
# Ollama API currently handles one text at a time
|
||||||
|
for text in texts:
|
||||||
|
payload = {
|
||||||
|
"model": kwargs.get("model", self.model),
|
||||||
|
"prompt": text
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.requests.post(
|
||||||
|
self.embeddings_url,
|
||||||
|
json=payload,
|
||||||
|
timeout=kwargs.get("timeout", 30)
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
if "embedding" in data:
|
||||||
|
embeddings.append(data["embedding"])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"No embedding in response: {data}")
|
||||||
|
|
||||||
|
except self.requests.exceptions.ConnectionError:
|
||||||
|
raise ConfigurationError(
|
||||||
|
f"Cannot connect to Ollama at {self.base_url}. "
|
||||||
|
"Make sure Ollama is running (ollama serve).",
|
||||||
|
provider="ollama"
|
||||||
|
)
|
||||||
|
except self.requests.exceptions.HTTPError as e:
|
||||||
|
if e.response.status_code == 404:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{self.model}' not found. "
|
||||||
|
f"Pull it first with: ollama pull {self.model}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
latency = time.time() - start_time
|
||||||
|
|
||||||
|
return self._create_response(
|
||||||
|
embeddings=embeddings,
|
||||||
|
latency=latency,
|
||||||
|
num_inputs=len(texts)
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("ollama", e)
|
||||||
|
|
||||||
|
async def aembed(
|
||||||
|
self,
|
||||||
|
texts: Union[str, List[str]],
|
||||||
|
**kwargs: Any
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
"""Asynchronously generate embeddings for texts using Ollama.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or list of texts to embed
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingResponse
|
||||||
|
"""
|
||||||
|
if self.httpx is None:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"httpx package not installed for async support. Install with: pip install httpx",
|
||||||
|
provider="ollama"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure texts is a list
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
embeddings = []
|
||||||
|
|
||||||
|
async with self.httpx.AsyncClient() as client:
|
||||||
|
# Process texts one at a time (Ollama limitation)
|
||||||
|
for text in texts:
|
||||||
|
payload = {
|
||||||
|
"model": kwargs.get("model", self.model),
|
||||||
|
"prompt": text
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
self.embeddings_url,
|
||||||
|
json=payload,
|
||||||
|
timeout=kwargs.get("timeout", 30)
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
if "embedding" in data:
|
||||||
|
embeddings.append(data["embedding"])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"No embedding in response: {data}")
|
||||||
|
|
||||||
|
except self.httpx.ConnectError:
|
||||||
|
raise ConfigurationError(
|
||||||
|
f"Cannot connect to Ollama at {self.base_url}. "
|
||||||
|
"Make sure Ollama is running (ollama serve).",
|
||||||
|
provider="ollama"
|
||||||
|
)
|
||||||
|
except self.httpx.HTTPStatusError as e:
|
||||||
|
if e.response.status_code == 404:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{self.model}' not found. "
|
||||||
|
f"Pull it first with: ollama pull {self.model}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
latency = time.time() - start_time
|
||||||
|
|
||||||
|
return self._create_response(
|
||||||
|
embeddings=embeddings,
|
||||||
|
latency=latency,
|
||||||
|
num_inputs=len(texts)
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("ollama", e)
|
||||||
|
|
||||||
|
def _create_response(
|
||||||
|
self,
|
||||||
|
embeddings: List[List[float]],
|
||||||
|
latency: float,
|
||||||
|
num_inputs: int
|
||||||
|
) -> EmbeddingResponse:
|
||||||
|
"""Create embedding response."""
|
||||||
|
# Get dimension from first embedding
|
||||||
|
dimension = len(embeddings[0]) if embeddings else 0
|
||||||
|
|
||||||
|
# No cost for local models
|
||||||
|
cost = 0.0 if config.track_costs else None
|
||||||
|
|
||||||
|
return EmbeddingResponse(
|
||||||
|
embeddings=embeddings,
|
||||||
|
model=self.model,
|
||||||
|
provider="ollama",
|
||||||
|
latency=latency,
|
||||||
|
dimension=dimension,
|
||||||
|
num_inputs=num_inputs,
|
||||||
|
total_tokens=None, # Ollama doesn't provide token counts
|
||||||
|
cost=cost,
|
||||||
|
metadata={
|
||||||
|
"base_url": self.base_url,
|
||||||
|
"local": True
|
||||||
|
},
|
||||||
|
raw_response=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Factory function to create embedding instances
|
||||||
|
def create_embedding(
|
||||||
|
provider: str = "cohere",
|
||||||
|
model: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> BaseEmbedding:
|
||||||
|
"""Create an embedding instance based on provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Embedding provider (cohere, gemini)
|
||||||
|
model: Model to use (provider-specific)
|
||||||
|
api_key: API key (optional if set in environment)
|
||||||
|
**kwargs: Additional configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseEmbedding instance
|
||||||
|
"""
|
||||||
|
provider = provider.lower()
|
||||||
|
|
||||||
|
if provider == "cohere":
|
||||||
|
return CohereEmbedding(
|
||||||
|
model=model or "embed-english-v3.0",
|
||||||
|
api_key=api_key,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
elif provider in ["gemini", "google"]:
|
||||||
|
return GeminiEmbedding(
|
||||||
|
model=model or "text-embedding-004",
|
||||||
|
api_key=api_key,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
elif provider == "ollama":
|
||||||
|
return OllamaEmbedding(
|
||||||
|
model=model or "nomic-embed-text:latest",
|
||||||
|
base_url=kwargs.pop("base_url", None), # Extract base_url from kwargs
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown embedding provider: {provider}")
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience aliases
|
||||||
|
Embed = create_embedding
|
||||||
|
CohereEmbed = CohereEmbedding
|
||||||
|
GeminiEmbed = GeminiEmbedding
|
||||||
|
OllamaEmbed = OllamaEmbedding
|
172
router/exceptions.py
Normal file
172
router/exceptions.py
Normal file
|
@ -0,0 +1,172 @@
|
||||||
|
"""Custom exceptions for AI router."""
|
||||||
|
|
||||||
|
from typing import Optional, Any
|
||||||
|
|
||||||
|
|
||||||
|
class RouterError(Exception):
|
||||||
|
"""Base exception for all router errors."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
provider: Optional[str] = None,
|
||||||
|
original_error: Optional[Exception] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize router error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
provider: Provider that raised the error
|
||||||
|
original_error: Original exception from provider
|
||||||
|
**kwargs: Additional error details
|
||||||
|
"""
|
||||||
|
super().__init__(message)
|
||||||
|
self.provider = provider
|
||||||
|
self.original_error = original_error
|
||||||
|
self.details = kwargs
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""String representation of error."""
|
||||||
|
base_msg = super().__str__()
|
||||||
|
if self.provider:
|
||||||
|
base_msg = f"[{self.provider}] {base_msg}"
|
||||||
|
if self.original_error:
|
||||||
|
base_msg += f" (Original: {self.original_error})"
|
||||||
|
return base_msg
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(RouterError):
|
||||||
|
"""Raised when authentication fails."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitError(RouterError):
|
||||||
|
"""Raised when rate limit is exceeded."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
retry_after: Optional[float] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize rate limit error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
retry_after: Seconds to wait before retry
|
||||||
|
**kwargs: Additional error details
|
||||||
|
"""
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.retry_after = retry_after
|
||||||
|
|
||||||
|
|
||||||
|
class ModelNotFoundError(RouterError):
|
||||||
|
"""Raised when requested model is not available."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
model: str,
|
||||||
|
available_models: Optional[list] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize model not found error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
model: Requested model
|
||||||
|
available_models: List of available models
|
||||||
|
**kwargs: Additional error details
|
||||||
|
"""
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.model = model
|
||||||
|
self.available_models = available_models or []
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidRequestError(RouterError):
|
||||||
|
"""Raised when request parameters are invalid."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TimeoutError(RouterError):
|
||||||
|
"""Raised when request times out."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ContentFilterError(RouterError):
|
||||||
|
"""Raised when content is blocked by safety filters."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaExceededError(RouterError):
|
||||||
|
"""Raised when API quota is exceeded."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderError(RouterError):
|
||||||
|
"""Raised for provider-specific errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigurationError(RouterError):
|
||||||
|
"""Raised for configuration errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def map_provider_error(provider: str, error: Exception) -> RouterError:
|
||||||
|
"""Map provider-specific errors to router errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Provider name
|
||||||
|
error: Original exception
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Appropriate RouterError subclass
|
||||||
|
"""
|
||||||
|
error_message = str(error)
|
||||||
|
error_type = type(error).__name__
|
||||||
|
|
||||||
|
# Common patterns across providers
|
||||||
|
if any(x in error_message.lower() for x in ["unauthorized", "invalid api key", "authentication"]):
|
||||||
|
return AuthenticationError(
|
||||||
|
f"Authentication failed: {error_message}",
|
||||||
|
provider=provider,
|
||||||
|
original_error=error
|
||||||
|
)
|
||||||
|
|
||||||
|
if any(x in error_message.lower() for x in ["rate limit", "too many requests", "quota exceeded"]):
|
||||||
|
return RateLimitError(
|
||||||
|
f"Rate limit exceeded: {error_message}",
|
||||||
|
provider=provider,
|
||||||
|
original_error=error
|
||||||
|
)
|
||||||
|
|
||||||
|
if any(x in error_message.lower() for x in ["model not found", "invalid model", "unknown model"]):
|
||||||
|
return ModelNotFoundError(
|
||||||
|
f"Model not found: {error_message}",
|
||||||
|
model="", # Extract from error if possible
|
||||||
|
provider=provider,
|
||||||
|
original_error=error
|
||||||
|
)
|
||||||
|
|
||||||
|
if any(x in error_message.lower() for x in ["timeout", "timed out"]):
|
||||||
|
return TimeoutError(
|
||||||
|
f"Request timed out: {error_message}",
|
||||||
|
provider=provider,
|
||||||
|
original_error=error
|
||||||
|
)
|
||||||
|
|
||||||
|
if any(x in error_message.lower() for x in ["content filter", "safety", "blocked"]):
|
||||||
|
return ContentFilterError(
|
||||||
|
f"Content blocked by safety filters: {error_message}",
|
||||||
|
provider=provider,
|
||||||
|
original_error=error
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default to provider error
|
||||||
|
return ProviderError(
|
||||||
|
f"{error_type}: {error_message}",
|
||||||
|
provider=provider,
|
||||||
|
original_error=error
|
||||||
|
)
|
242
router/gemini.py
Normal file
242
router/gemini.py
Normal file
|
@ -0,0 +1,242 @@
|
||||||
|
"""Google Gemini provider implementation."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional, AsyncGenerator, Generator
|
||||||
|
import time
|
||||||
|
|
||||||
|
from .base import AIRouter, RouterResponse
|
||||||
|
from .config import config
|
||||||
|
from .exceptions import (
|
||||||
|
AuthenticationError,
|
||||||
|
ConfigurationError,
|
||||||
|
map_provider_error
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemini(AIRouter):
|
||||||
|
"""Router for Google Gemini models."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "gemini-2.0-flash-001",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize Gemini router.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Gemini model to use (gemini-2.0-flash-001, gemini-1.5-pro, etc.)
|
||||||
|
api_key: Google API key (optional if set in environment)
|
||||||
|
**kwargs: Additional configuration
|
||||||
|
"""
|
||||||
|
# Get API key from config if not provided
|
||||||
|
if not api_key:
|
||||||
|
api_key = config.get_api_key("gemini")
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"Gemini API key not found. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.",
|
||||||
|
provider="gemini"
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(model=model, api_key=api_key, **kwargs)
|
||||||
|
|
||||||
|
# Initialize Gemini client using new google.genai library
|
||||||
|
try:
|
||||||
|
from google import genai
|
||||||
|
from google.genai import types
|
||||||
|
self.genai = genai
|
||||||
|
self.types = types
|
||||||
|
self.client = genai.Client(api_key=api_key)
|
||||||
|
except ImportError:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"google-genai package not installed. Install with: pip install google-genai",
|
||||||
|
provider="gemini"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
|
||||||
|
"""Prepare Gemini request parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Request parameters
|
||||||
|
"""
|
||||||
|
# Build config using new API structure
|
||||||
|
config_params = {}
|
||||||
|
|
||||||
|
if "temperature" in kwargs:
|
||||||
|
config_params["temperature"] = kwargs["temperature"]
|
||||||
|
elif hasattr(config, "default_temperature"):
|
||||||
|
config_params["temperature"] = config.default_temperature
|
||||||
|
|
||||||
|
if "max_tokens" in kwargs:
|
||||||
|
config_params["max_output_tokens"] = kwargs["max_tokens"]
|
||||||
|
elif "max_output_tokens" in kwargs:
|
||||||
|
config_params["max_output_tokens"] = kwargs["max_output_tokens"]
|
||||||
|
|
||||||
|
if "top_p" in kwargs:
|
||||||
|
config_params["top_p"] = kwargs["top_p"]
|
||||||
|
elif hasattr(config, "default_top_p"):
|
||||||
|
config_params["top_p"] = config.default_top_p
|
||||||
|
|
||||||
|
if "top_k" in kwargs:
|
||||||
|
config_params["top_k"] = kwargs["top_k"]
|
||||||
|
|
||||||
|
# Add safety settings if provided
|
||||||
|
if "safety_settings" in kwargs:
|
||||||
|
config_params["safety_settings"] = kwargs["safety_settings"]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"model": kwargs.get("model", self.model),
|
||||||
|
"contents": prompt,
|
||||||
|
"config": self.types.GenerateContentConfig(**config_params) if config_params else None
|
||||||
|
}
|
||||||
|
|
||||||
|
def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
|
||||||
|
"""Parse Gemini response into RouterResponse.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_response: Raw Gemini response
|
||||||
|
latency: Request latency
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RouterResponse
|
||||||
|
"""
|
||||||
|
# Extract text content
|
||||||
|
content = raw_response.text if hasattr(raw_response, "text") else ""
|
||||||
|
|
||||||
|
# Try to get token counts from usage_metadata
|
||||||
|
input_tokens = None
|
||||||
|
output_tokens = None
|
||||||
|
total_tokens = None
|
||||||
|
|
||||||
|
if hasattr(raw_response, "usage_metadata"):
|
||||||
|
usage = raw_response.usage_metadata
|
||||||
|
# Handle both old and new attribute names
|
||||||
|
input_tokens = getattr(usage, "prompt_token_count", None) or getattr(usage, "cached_content_token_count", None)
|
||||||
|
output_tokens = getattr(usage, "candidates_token_count", None)
|
||||||
|
total_tokens = getattr(usage, "total_token_count", None)
|
||||||
|
|
||||||
|
# Calculate cost if available
|
||||||
|
cost = None
|
||||||
|
if input_tokens and output_tokens and config.track_costs:
|
||||||
|
cost = config.calculate_cost(self.model, input_tokens, output_tokens)
|
||||||
|
|
||||||
|
# Extract finish reason
|
||||||
|
finish_reason = "stop"
|
||||||
|
if hasattr(raw_response, "candidates") and raw_response.candidates:
|
||||||
|
candidate = raw_response.candidates[0]
|
||||||
|
finish_reason = getattr(candidate, "finish_reason", "stop")
|
||||||
|
|
||||||
|
return RouterResponse(
|
||||||
|
content=content,
|
||||||
|
model=raw_response.model_version if hasattr(raw_response, "model_version") else self.model,
|
||||||
|
provider="gemini",
|
||||||
|
latency=latency,
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
output_tokens=output_tokens,
|
||||||
|
total_tokens=total_tokens or ((input_tokens + output_tokens) if input_tokens and output_tokens else None),
|
||||||
|
cost=cost,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
metadata={
|
||||||
|
"prompt_feedback": getattr(raw_response, "prompt_feedback", None),
|
||||||
|
"safety_ratings": getattr(raw_response.candidates[0], "safety_ratings", None) if hasattr(raw_response, "candidates") and raw_response.candidates else None
|
||||||
|
},
|
||||||
|
raw_response=raw_response
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_request(self, request_params: Dict[str, Any]) -> Any:
|
||||||
|
"""Make synchronous request to Gemini.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_params: Request parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Raw Gemini response
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = self.client.models.generate_content(**request_params)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("gemini", e)
|
||||||
|
|
||||||
|
async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
|
||||||
|
"""Make asynchronous request to Gemini.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_params: Request parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Raw Gemini response
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await self.client.aio.models.generate_content(**request_params)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("gemini", e)
|
||||||
|
|
||||||
|
def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
|
||||||
|
"""Stream responses from Gemini.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
RouterResponse chunks
|
||||||
|
"""
|
||||||
|
params = {**self.config, **kwargs}
|
||||||
|
request_params = self._prepare_request(prompt, **params)
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
stream = self.client.models.generate_content_stream(**request_params)
|
||||||
|
|
||||||
|
for chunk in stream:
|
||||||
|
if chunk.text:
|
||||||
|
yield RouterResponse(
|
||||||
|
content=chunk.text,
|
||||||
|
model=self.model,
|
||||||
|
provider="gemini",
|
||||||
|
latency=time.time() - start_time,
|
||||||
|
finish_reason=None,
|
||||||
|
metadata={},
|
||||||
|
raw_response=chunk
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("gemini", e)
|
||||||
|
|
||||||
|
async def astream(
|
||||||
|
self, prompt: str, **kwargs: Any
|
||||||
|
) -> AsyncGenerator[RouterResponse, None]:
|
||||||
|
"""Asynchronously stream responses from Gemini.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
RouterResponse chunks
|
||||||
|
"""
|
||||||
|
params = {**self.config, **kwargs}
|
||||||
|
request_params = self._prepare_request(prompt, **params)
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
stream = await self.client.aio.models.generate_content_stream(**request_params)
|
||||||
|
|
||||||
|
async for chunk in stream:
|
||||||
|
if chunk.text:
|
||||||
|
yield RouterResponse(
|
||||||
|
content=chunk.text,
|
||||||
|
model=self.model,
|
||||||
|
provider="gemini",
|
||||||
|
latency=time.time() - start_time,
|
||||||
|
finish_reason=None,
|
||||||
|
metadata={},
|
||||||
|
raw_response=chunk
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("gemini", e)
|
280
router/openai_compatible.py
Normal file
280
router/openai_compatible.py
Normal file
|
@ -0,0 +1,280 @@
|
||||||
|
"""OpenAI-compatible provider implementation."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional, AsyncGenerator, Generator
|
||||||
|
import time
|
||||||
|
|
||||||
|
from .base import AIRouter, RouterResponse
|
||||||
|
from .config import config
|
||||||
|
from .exceptions import (
|
||||||
|
ConfigurationError,
|
||||||
|
map_provider_error
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAICompatible(AIRouter):
|
||||||
|
"""Router for OpenAI and compatible APIs (OpenAI, Azure OpenAI, etc.)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "gpt-3.5-turbo",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
organization: Optional[str] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize OpenAI-compatible router.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model to use (gpt-4, gpt-3.5-turbo, etc.)
|
||||||
|
api_key: API key (optional if set in environment)
|
||||||
|
base_url: Base URL for API (for Azure or custom endpoints)
|
||||||
|
organization: Organization ID for OpenAI
|
||||||
|
**kwargs: Additional configuration
|
||||||
|
"""
|
||||||
|
# Get API key from config if not provided
|
||||||
|
if not api_key:
|
||||||
|
api_key = config.get_api_key("openai")
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"OpenAI API key not found. Set OPENAI_API_KEY environment variable.",
|
||||||
|
provider="openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(model=model, api_key=api_key, **kwargs)
|
||||||
|
|
||||||
|
# Store additional config
|
||||||
|
self.base_url = base_url
|
||||||
|
self.organization = organization
|
||||||
|
|
||||||
|
# Initialize OpenAI client
|
||||||
|
try:
|
||||||
|
from openai import OpenAI, AsyncOpenAI
|
||||||
|
|
||||||
|
# Build client kwargs with proper types
|
||||||
|
client_kwargs: Dict[str, Any] = {
|
||||||
|
"api_key": api_key,
|
||||||
|
}
|
||||||
|
if base_url:
|
||||||
|
client_kwargs["base_url"] = base_url
|
||||||
|
if organization:
|
||||||
|
client_kwargs["organization"] = organization
|
||||||
|
|
||||||
|
# Add any additional client configuration from kwargs
|
||||||
|
# Note: Only pass through valid OpenAI client parameters
|
||||||
|
valid_client_params = ["timeout", "max_retries", "default_headers", "default_query", "http_client"]
|
||||||
|
for param in valid_client_params:
|
||||||
|
if param in kwargs:
|
||||||
|
client_kwargs[param] = kwargs.pop(param)
|
||||||
|
|
||||||
|
self.client = OpenAI(**client_kwargs)
|
||||||
|
self.async_client = AsyncOpenAI(**client_kwargs)
|
||||||
|
except ImportError:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"openai package not installed. Install with: pip install openai",
|
||||||
|
provider="openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_request(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
|
||||||
|
"""Prepare OpenAI request parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Request parameters
|
||||||
|
"""
|
||||||
|
# Build messages
|
||||||
|
messages = kwargs.get("messages", [
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
])
|
||||||
|
|
||||||
|
# Build request parameters
|
||||||
|
params = {
|
||||||
|
"model": kwargs.get("model", self.model),
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional parameters
|
||||||
|
if "temperature" in kwargs:
|
||||||
|
params["temperature"] = kwargs["temperature"]
|
||||||
|
elif hasattr(config, "default_temperature"):
|
||||||
|
params["temperature"] = config.default_temperature
|
||||||
|
|
||||||
|
if "max_tokens" in kwargs:
|
||||||
|
params["max_tokens"] = kwargs["max_tokens"]
|
||||||
|
elif hasattr(config, "default_max_tokens") and config.default_max_tokens:
|
||||||
|
params["max_tokens"] = config.default_max_tokens
|
||||||
|
|
||||||
|
if "top_p" in kwargs:
|
||||||
|
params["top_p"] = kwargs["top_p"]
|
||||||
|
elif hasattr(config, "default_top_p"):
|
||||||
|
params["top_p"] = config.default_top_p
|
||||||
|
|
||||||
|
# Other OpenAI-specific parameters
|
||||||
|
for key in ["n", "stop", "presence_penalty", "frequency_penalty", "logit_bias", "user", "seed", "tools", "tool_choice", "response_format", "logprobs", "top_logprobs", "parallel_tool_calls"]:
|
||||||
|
if key in kwargs:
|
||||||
|
params[key] = kwargs[key]
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
def _parse_response(self, raw_response: Any, latency: float) -> RouterResponse:
|
||||||
|
"""Parse OpenAI response into RouterResponse.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_response: Raw OpenAI response
|
||||||
|
latency: Request latency
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RouterResponse
|
||||||
|
"""
|
||||||
|
# Extract first choice
|
||||||
|
choice = raw_response.choices[0]
|
||||||
|
content = choice.message.content or ""
|
||||||
|
|
||||||
|
# Get token counts
|
||||||
|
usage = raw_response.usage
|
||||||
|
input_tokens = usage.prompt_tokens if usage else None
|
||||||
|
output_tokens = usage.completion_tokens if usage else None
|
||||||
|
total_tokens = usage.total_tokens if usage else None
|
||||||
|
|
||||||
|
# Calculate cost if available
|
||||||
|
cost = None
|
||||||
|
if input_tokens and output_tokens and config.track_costs:
|
||||||
|
cost = config.calculate_cost(raw_response.model, input_tokens, output_tokens)
|
||||||
|
|
||||||
|
return RouterResponse(
|
||||||
|
content=content,
|
||||||
|
model=raw_response.model,
|
||||||
|
provider="openai",
|
||||||
|
latency=latency,
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
output_tokens=output_tokens,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
cost=cost,
|
||||||
|
finish_reason=choice.finish_reason,
|
||||||
|
metadata={
|
||||||
|
"id": raw_response.id,
|
||||||
|
"created": raw_response.created,
|
||||||
|
"system_fingerprint": getattr(raw_response, "system_fingerprint", None),
|
||||||
|
"tool_calls": getattr(choice.message, "tool_calls", None),
|
||||||
|
"function_call": getattr(choice.message, "function_call", None),
|
||||||
|
"logprobs": getattr(choice, "logprobs", None),
|
||||||
|
},
|
||||||
|
raw_response=raw_response
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_request(self, request_params: Dict[str, Any]) -> Any:
|
||||||
|
"""Make synchronous request to OpenAI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_params: Request parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Raw OpenAI response
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = self.client.chat.completions.create(**request_params)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("openai", e)
|
||||||
|
|
||||||
|
async def _make_async_request(self, request_params: Dict[str, Any]) -> Any:
|
||||||
|
"""Make asynchronous request to OpenAI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_params: Request parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Raw OpenAI response
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await self.async_client.chat.completions.create(**request_params)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("openai", e)
|
||||||
|
|
||||||
|
def stream(self, prompt: str, **kwargs: Any) -> Generator[RouterResponse, None, None]:
|
||||||
|
"""Stream responses from OpenAI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
RouterResponse chunks
|
||||||
|
"""
|
||||||
|
params = {**self.config, **kwargs}
|
||||||
|
request_params = self._prepare_request(prompt, **params)
|
||||||
|
request_params["stream"] = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
stream = self.client.chat.completions.create(**request_params)
|
||||||
|
|
||||||
|
for chunk in stream:
|
||||||
|
if chunk.choices and len(chunk.choices) > 0:
|
||||||
|
choice = chunk.choices[0]
|
||||||
|
content = getattr(choice.delta, "content", None)
|
||||||
|
if content:
|
||||||
|
yield RouterResponse(
|
||||||
|
content=content,
|
||||||
|
model=chunk.model,
|
||||||
|
provider="openai",
|
||||||
|
latency=time.time() - start_time,
|
||||||
|
finish_reason=getattr(choice, "finish_reason", None),
|
||||||
|
metadata={
|
||||||
|
"chunk_id": chunk.id,
|
||||||
|
"tool_calls": getattr(choice.delta, "tool_calls", None),
|
||||||
|
"function_call": getattr(choice.delta, "function_call", None),
|
||||||
|
},
|
||||||
|
raw_response=chunk
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("openai", e)
|
||||||
|
|
||||||
|
async def astream(
|
||||||
|
self, prompt: str, **kwargs: Any
|
||||||
|
) -> AsyncGenerator[RouterResponse, None]:
|
||||||
|
"""Asynchronously stream responses from OpenAI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
RouterResponse chunks
|
||||||
|
"""
|
||||||
|
params = {**self.config, **kwargs}
|
||||||
|
request_params = self._prepare_request(prompt, **params)
|
||||||
|
request_params["stream"] = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
stream = await self.async_client.chat.completions.create(**request_params)
|
||||||
|
|
||||||
|
async for chunk in stream:
|
||||||
|
if chunk.choices and len(chunk.choices) > 0:
|
||||||
|
choice = chunk.choices[0]
|
||||||
|
content = getattr(choice.delta, "content", None)
|
||||||
|
if content:
|
||||||
|
yield RouterResponse(
|
||||||
|
content=content,
|
||||||
|
model=chunk.model,
|
||||||
|
provider="openai",
|
||||||
|
latency=time.time() - start_time,
|
||||||
|
finish_reason=getattr(choice, "finish_reason", None),
|
||||||
|
metadata={
|
||||||
|
"chunk_id": chunk.id,
|
||||||
|
"tool_calls": getattr(choice.delta, "tool_calls", None),
|
||||||
|
"function_call": getattr(choice.delta, "function_call", None),
|
||||||
|
},
|
||||||
|
raw_response=chunk
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("openai", e)
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience alias
|
||||||
|
OpenAI = OpenAICompatible
|
245
router/rerank.py
Normal file
245
router/rerank.py
Normal file
|
@ -0,0 +1,245 @@
|
||||||
|
"""Reranking router implementation for Cohere Rerank models."""
|
||||||
|
|
||||||
|
from typing import List, Dict, Any, Optional, Union
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
import time
|
||||||
|
|
||||||
|
from .config import config
|
||||||
|
from .exceptions import (
|
||||||
|
ConfigurationError,
|
||||||
|
map_provider_error
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RerankDocument:
|
||||||
|
"""Document to be reranked."""
|
||||||
|
text: str
|
||||||
|
id: Optional[str] = None
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RerankResult:
|
||||||
|
"""Single reranked result."""
|
||||||
|
index: int # Original index in the documents list
|
||||||
|
relevance_score: float
|
||||||
|
document: RerankDocument
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RerankResponse:
|
||||||
|
"""Response from reranking operation."""
|
||||||
|
results: List[RerankResult]
|
||||||
|
model: str
|
||||||
|
provider: str
|
||||||
|
latency: float
|
||||||
|
num_documents: int
|
||||||
|
cost: Optional[float] = None
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
raw_response: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CohereRerank:
|
||||||
|
"""Router for Cohere Rerank models."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "rerank-english-v3.0",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize Cohere Rerank router.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Rerank model to use (rerank-3.5, rerank-english-v3.0, rerank-multilingual-v3.0)
|
||||||
|
api_key: Cohere API key (optional if set in environment)
|
||||||
|
**kwargs: Additional configuration
|
||||||
|
"""
|
||||||
|
# Get API key from config if not provided
|
||||||
|
if not api_key:
|
||||||
|
api_key = config.get_api_key("cohere")
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"Cohere API key not found. Set COHERE_API_KEY environment variable.",
|
||||||
|
provider="cohere"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.api_key = api_key
|
||||||
|
self.config = kwargs
|
||||||
|
|
||||||
|
# Initialize Cohere client
|
||||||
|
try:
|
||||||
|
import cohere
|
||||||
|
# For v5.15.0, use the standard Client
|
||||||
|
self.client = cohere.Client(api_key)
|
||||||
|
self.async_client = cohere.AsyncClient(api_key)
|
||||||
|
except ImportError:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"cohere package not installed. Install with: pip install cohere",
|
||||||
|
provider="cohere"
|
||||||
|
)
|
||||||
|
|
||||||
|
def rerank(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
|
||||||
|
top_n: Optional[int] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> RerankResponse:
|
||||||
|
"""Rerank documents based on relevance to query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query
|
||||||
|
documents: List of documents to rerank (strings, dicts, or RerankDocument objects)
|
||||||
|
top_n: Number of top results to return (None returns all)
|
||||||
|
**kwargs: Additional parameters (max_chunks_per_doc, return_documents, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RerankResponse with reranked results
|
||||||
|
"""
|
||||||
|
params = self._prepare_request(query, documents, top_n, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
response = self.client.rerank(**params)
|
||||||
|
latency = time.time() - start_time
|
||||||
|
|
||||||
|
return self._parse_response(response, latency, len(params["documents"]))
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("cohere", e)
|
||||||
|
|
||||||
|
async def arerank(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
|
||||||
|
top_n: Optional[int] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> RerankResponse:
|
||||||
|
"""Asynchronously rerank documents based on relevance to query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query
|
||||||
|
documents: List of documents to rerank
|
||||||
|
top_n: Number of top results to return (None returns all)
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RerankResponse with reranked results
|
||||||
|
"""
|
||||||
|
params = self._prepare_request(query, documents, top_n, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
response = await self.async_client.rerank(**params)
|
||||||
|
latency = time.time() - start_time
|
||||||
|
|
||||||
|
return self._parse_response(response, latency, len(params["documents"]))
|
||||||
|
except Exception as e:
|
||||||
|
raise map_provider_error("cohere", e)
|
||||||
|
|
||||||
|
def _prepare_request(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
documents: Union[List[str], List[Dict[str, Any]], List[RerankDocument]],
|
||||||
|
top_n: Optional[int] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Prepare rerank request parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query
|
||||||
|
documents: Documents to rerank
|
||||||
|
top_n: Number of results to return
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Request parameters
|
||||||
|
"""
|
||||||
|
# Convert documents to the format expected by Cohere
|
||||||
|
formatted_docs = []
|
||||||
|
for i, doc in enumerate(documents):
|
||||||
|
if isinstance(doc, str):
|
||||||
|
formatted_docs.append({"text": doc})
|
||||||
|
elif isinstance(doc, RerankDocument):
|
||||||
|
formatted_docs.append({"text": doc.text})
|
||||||
|
elif isinstance(doc, dict):
|
||||||
|
# Assume dict has at least 'text' field
|
||||||
|
formatted_docs.append(doc)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid document type at index {i}: {type(doc)}")
|
||||||
|
|
||||||
|
# Build request parameters
|
||||||
|
params = {
|
||||||
|
"model": kwargs.get("model", self.model),
|
||||||
|
"query": query,
|
||||||
|
"documents": formatted_docs,
|
||||||
|
}
|
||||||
|
|
||||||
|
if top_n is not None:
|
||||||
|
params["top_n"] = top_n
|
||||||
|
|
||||||
|
# Add optional parameters for v5.15.0
|
||||||
|
for key in ["max_chunks_per_doc", "return_documents", "rank_fields"]:
|
||||||
|
if key in kwargs:
|
||||||
|
params[key] = kwargs[key]
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
def _parse_response(
|
||||||
|
self,
|
||||||
|
raw_response: Any,
|
||||||
|
latency: float,
|
||||||
|
num_documents: int
|
||||||
|
) -> RerankResponse:
|
||||||
|
"""Parse Cohere rerank response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_response: Raw response from Cohere
|
||||||
|
latency: Request latency
|
||||||
|
num_documents: Total number of documents submitted
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RerankResponse
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
|
||||||
|
# Parse results from v5.15.0 response format
|
||||||
|
if hasattr(raw_response, "results"):
|
||||||
|
for result in raw_response.results:
|
||||||
|
# Extract document info
|
||||||
|
doc_text = ""
|
||||||
|
if hasattr(result, "document") and hasattr(result.document, "text"):
|
||||||
|
doc_text = result.document.text
|
||||||
|
|
||||||
|
results.append(RerankResult(
|
||||||
|
index=result.index,
|
||||||
|
relevance_score=result.relevance_score,
|
||||||
|
document=RerankDocument(text=doc_text)
|
||||||
|
))
|
||||||
|
|
||||||
|
# Calculate cost
|
||||||
|
cost = None
|
||||||
|
if config.track_costs:
|
||||||
|
# Reranking is charged per search (1 search = 1 query across N documents)
|
||||||
|
cost = config.calculate_rerank_cost(self.model, 1)
|
||||||
|
|
||||||
|
return RerankResponse(
|
||||||
|
results=results,
|
||||||
|
model=self.model,
|
||||||
|
provider="cohere",
|
||||||
|
latency=latency,
|
||||||
|
num_documents=num_documents,
|
||||||
|
cost=cost,
|
||||||
|
metadata={
|
||||||
|
"id": getattr(raw_response, "id", None),
|
||||||
|
"meta": getattr(raw_response, "meta", None),
|
||||||
|
},
|
||||||
|
raw_response=raw_response
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience alias
|
||||||
|
Rerank = CohereRerank
|
Loading…
Reference in a new issue