57 lines
No EOL
1.6 KiB
Python
57 lines
No EOL
1.6 KiB
Python
"""Factory function for creating AI router instances."""
|
|
|
|
from typing import Any, Optional
|
|
from .base import AIRouter
|
|
from .gemini import Gemini
|
|
from .openai_compatible import OpenAI
|
|
from .cohere import Cohere
|
|
|
|
|
|
def create_router(
    provider: str,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    **kwargs: Any
) -> AIRouter:
    """Create an AI router instance based on provider.

    Args:
        provider: Provider name. Matching ignores case and surrounding
            whitespace. Accepted names: ``"openai"``, ``"gemini"`` (alias
            ``"google"``), ``"cohere"``.
        model: Model to use. If falsy (``None`` or ``""``), a
            provider-specific default is used: ``"gpt-4o"`` for OpenAI,
            ``"gemini-2.0-flash-001"`` for Gemini, ``"command-r-plus"`` for
            Cohere.
        api_key: API key passed straight to the router class. When ``None``,
            resolving a key is left to the router class (presumably from the
            environment; depends on the provider implementation).
        **kwargs: Additional provider-specific configuration, forwarded
            unchanged to the router constructor.

    Returns:
        An ``OpenAI``, ``Gemini`` or ``Cohere`` router instance.

    Raises:
        ValueError: If provider is unknown. The message lists the supported
            provider names.
    """
    # Normalized provider name -> (router class, default model). Built per
    # call so the router classes are only referenced when the factory runs.
    registry = {
        "openai": (OpenAI, "gpt-4o"),
        "gemini": (Gemini, "gemini-2.0-flash-001"),
        "google": (Gemini, "gemini-2.0-flash-001"),
        "cohere": (Cohere, "command-r-plus"),
    }

    # Strip as well as lowercase: names read from env vars / config files
    # frequently carry stray whitespace or a trailing newline.
    key = provider.strip().lower()

    try:
        router_cls, default_model = registry[key]
    except KeyError:
        supported = ", ".join(sorted(registry))
        # Keep the original "Unknown provider: <name>" prefix for callers that
        # match on it; `from None` hides the internal KeyError.
        raise ValueError(
            f"Unknown provider: {key}. Supported providers: {supported}"
        ) from None

    return router_cls(
        model=model or default_model,
        api_key=api_key,
        **kwargs
    )
|
|
|
|
|
|
# Backward-compatible alias: existing tests call ``AIRouter(...)`` as a
# function, so this name must stay callable with (provider, **kwargs).
# NOTE(review): this ``def`` rebinds the module-level name ``AIRouter``, which
# was imported from ``.base`` as the router base class. The ``-> AIRouter``
# return annotation is evaluated at definition time, so it still points at the
# base class, but afterwards ``AIRouter`` in this module (and anything that
# imports it from here) is this function: ``isinstance(obj, AIRouter)`` raises
# TypeError and it cannot be subclassed. Static checkers also report a
# redefinition. Consider renaming once the tests allow it.
def AIRouter(provider: str, **kwargs: Any) -> AIRouter:
    """Build a router for *provider* by delegating to ``create_router``.

    All keyword arguments (``model``, ``api_key`` and any provider-specific
    options) are forwarded unchanged.
    """
    router = create_router(provider, **kwargs)
    return router