template-with-ai/router/factory.py
2025-07-01 17:07:02 +05:30

57 lines
No EOL
1.6 KiB
Python

"""Factory function for creating AI router instances."""
from typing import Any, Optional
from .base import AIRouter
from .gemini import Gemini
from .openai_compatible import OpenAI
from .cohere import Cohere
def create_router(
    provider: str,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    **kwargs: Any
) -> AIRouter:
    """Create an AI router instance based on provider.

    Args:
        provider: Provider name, matched case-insensitively and ignoring
            surrounding whitespace: "openai", "gemini" (alias "google"),
            or "cohere".
        model: Model to use. If omitted or empty, a provider-specific
            default is used: "gpt-4o" (openai), "gemini-2.0-flash-001"
            (gemini/google), "command-r-plus" (cohere).
        api_key: API key (optional if set in environment)
        **kwargs: Additional provider-specific configuration, forwarded
            unchanged to the router constructor.

    Returns:
        AIRouter instance for the requested provider.

    Raises:
        ValueError: If provider is unknown.
    """
    # Accepted provider name -> (router class, default model).
    # "google" is an alias that resolves to the Gemini router.
    registry = {
        "openai": (OpenAI, "gpt-4o"),
        "gemini": (Gemini, "gemini-2.0-flash-001"),
        "google": (Gemini, "gemini-2.0-flash-001"),
        "cohere": (Cohere, "command-r-plus"),
    }

    key = provider.strip().lower()
    try:
        router_cls, default_model = registry[key]
    except KeyError:
        supported = ", ".join(sorted(registry))
        # `from None` hides the internal KeyError; the ValueError is the contract.
        raise ValueError(
            f"Unknown provider: {key}. Supported providers: {supported}"
        ) from None

    # `or` (not `is None`) matches the original: an empty model string
    # also falls back to the provider default.
    return router_cls(
        model=model or default_model,
        api_key=api_key,
        **kwargs
    )
# Convenience alias that matches the test expectations.
# NOTE: this rebinds the module-level name `AIRouter`, shadowing the base
# class imported from `.base`. The `-> AIRouter` annotation below is
# evaluated before the rebinding, so it still refers to the base class.
# Import the base class from `.base` directly when you need it (e.g. for
# isinstance checks).
def AIRouter(
    provider: str,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    **kwargs: Any
) -> AIRouter:
    """Factory function alias for backward compatibility.

    Accepts the same parameters as ``create_router`` (including positional
    ``model`` and ``api_key``) and delegates to it unchanged.

    Args:
        provider: Provider name (see ``create_router``).
        model: Model to use; provider default if omitted.
        api_key: API key (optional if set in environment).
        **kwargs: Additional provider-specific configuration.

    Returns:
        The router instance built by ``create_router``.

    Raises:
        ValueError: If provider is unknown (raised by ``create_router``).
    """
    return create_router(provider, model=model, api_key=api_key, **kwargs)