Commit 4a2ac55

refactor
1 parent 28374ac commit 4a2ac55

7 files changed, +34 -23 lines changed


cascade.py

Lines changed: 2 additions & 2 deletions
@@ -48,8 +48,8 @@ def load_config(config_file: str) -> tuple[Config, str]:
     )
 
     llm_wrappers = {
-        "llm1": LLMFactory.create(config.llm1.provider, config.llm1.model),
-        "llm2": LLMFactory.create(config.llm2.provider, config.llm2.model)
+        "llm1": LLMFactory.create(config.llm1),
+        "llm2": LLMFactory.create(config.llm2)
     }
 
     state_manager = StateManager(config)

cascade/llm/anthropic.py

Lines changed: 3 additions & 8 deletions
@@ -12,20 +12,15 @@
 class AnthropicWrapper(BaseLLMWrapper):
     """Wrapper for the Anthropic API."""
 
-    def __init__(self, model: str):
+    def __init__(self, model: str, api_key: str):
         """Initialize the Anthropic wrapper.
 
         Args:
             model: The model identifier to use
+            api_key: The Anthropic API key
         """
         super().__init__(model)
-        self.name = "anthropic"
-        if os.environ.get("ANTHROPIC_API_KEY"):
-            self.client = anthropic.Anthropic(
-                api_key=os.environ.get("ANTHROPIC_API_KEY")
-            )
-        else:
-            raise ValueError("Anthropic API key not found")
+        self.client = anthropic.Anthropic(api_key=api_key)
 
     def generate_stream(
         self, messages, system_prompt=None
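
With the key injected through the constructor, the wrapper no longer reads the environment itself. A minimal construction sketch; the model name and key value here are placeholders, not values from the commit:

from cascade.llm.anthropic import AnthropicWrapper

# Placeholder model and key; the wrapper no longer consults os.environ.
wrapper = AnthropicWrapper(model="claude-example", api_key="sk-ant-placeholder")
print(wrapper.name)   # "anthropic", now derived in BaseLLMWrapper (see cascade/llm/base.py below)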

cascade/llm/base.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ def __init__(self, model: str):
             model: The model identifier to use
         """
         self.model = model
-        self.name = None
+        self.name = self.__class__.__name__.replace("Wrapper", "").lower()
 
     @abstractmethod
     def generate_stream(
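
The base class now derives `name` from the subclass's class name, so concrete wrappers no longer set it by hand (see the ollama.py and openai.py changes below). The idiom in isolation, with illustrative class names:

class BaseLLMWrapper:
    def __init__(self, model: str):
        self.model = model
        # "AnthropicWrapper" -> "anthropic", "OpenAIWrapper" -> "openai", "OllamaWrapper" -> "ollama"
        self.name = self.__class__.__name__.replace("Wrapper", "").lower()


class OllamaWrapper(BaseLLMWrapper):   # illustrative subclass; the real class lives in cascade/llm/ollama.py
    pass


print(OllamaWrapper(model="llama-example").name)   # ollama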

cascade/llm/factory.py

Lines changed: 10 additions & 3 deletions
@@ -1,6 +1,7 @@
 """Create LLM wrappers."""
 
 from typing import Dict, Type
+import os
 from cascade.models import Provider
 from cascade.llm.base import BaseLLMWrapper
 from cascade.llm.anthropic import AnthropicWrapper
@@ -18,13 +19,19 @@ class LLMFactory:
     }
 
     @classmethod
-    def create(cls, provider: Provider, model: str) -> BaseLLMWrapper:
-        """Create an LLM wrapper instance."""
+    def create(cls, config) -> BaseLLMWrapper:
+        """Create an LLM wrapper instance from an LLMConfig."""
+        provider = config.provider
         if provider not in cls._registry:
             raise ValueError(f"Unknown provider: {provider}")
 
         wrapper_class = cls._registry[provider]
-        return wrapper_class(model=model)
+        if provider in [Provider.OPENAI, Provider.ANTHROPIC]:
+            config.require_api_key()
+            api_key = os.environ[provider.upper() + "_API_KEY"]
+            return wrapper_class(model=config.model, api_key=api_key)
+        else:
+            return wrapper_class(model=config.model)
 
     @classmethod
     def register(cls, provider: Provider, wrapper_class: Type[BaseLLMWrapper]) -> None:
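
Because `Provider` is a `str` enum, string methods operate on the member's value, which is what lets the factory build the environment-variable name with `provider.upper() + "_API_KEY"`. A self-contained sketch of that detail; member names mirror the diff, and OLLAMA is assumed from the validator's provider list:

from enum import Enum

class Provider(str, Enum):
    ANTHROPIC = "anthropic"
    OPENAI = "openai"
    OLLAMA = "ollama"   # assumed member name

provider = Provider.OPENAI
print(provider.upper() + "_API_KEY")                  # OPENAI_API_KEY
print(Provider("anthropic") is Provider.ANTHROPIC)    # True: how validate_connection checks membership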

cascade/llm/ollama.py

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@ def __init__(self, model: str):
             model: The model identifier to use
         """
         super().__init__(model)
-        self.name = "ollama"
 
     def generate_stream(
         self, messages, system_prompt=None

cascade/llm/openai.py

Lines changed: 3 additions & 6 deletions
@@ -18,18 +18,15 @@
 class OpenAIWrapper(BaseLLMWrapper):
     """Wrapper for OpenAI API"""
 
-    def __init__(self, model: str):
+    def __init__(self, model: str, api_key: str):
         """Initialize the OpenAI wrapper.
 
         Args:
             model: The model identifier to use
+            api_key: The OpenAI API key
         """
         super().__init__(model)
-        self.name = "openai"
-        if os.environ.get("OPENAI_API_KEY"):
-            self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
-        else:
-            raise ValueError("OpenAI API key not found in environment variables")
+        self.client = OpenAI(api_key=api_key)
 
     def generate_stream(
         self, messages, system_prompt=None

cascade/models.py

Lines changed: 15 additions & 2 deletions
@@ -3,6 +3,7 @@
 from typing import List, Dict, Optional
 from enum import Enum
 from pydantic import BaseModel, model_validator, field_validator
+import os
 
 
 class Provider(str, Enum):
@@ -40,8 +41,10 @@ def validate_connection(cls, v: str) -> str:
         provider, model = v.split(":", 1)
         if not provider or not model:
             raise ValueError("Both provider and model must be specified")
-        if provider not in ["anthropic", "openai", "ollama"]:
-            raise ValueError("Provider must be one of: anthropic, openai, ollama")
+        try:
+            Provider(provider)
+        except ValueError:
+            raise ValueError(f"Provider must be one of: {[p.value for p in Provider]}")
         return v
 
     @property
@@ -54,6 +57,16 @@ def model(self) -> str:
         """Get the model from the connection string."""
         return self.connection.split(":", 1)[1]
 
+    def require_api_key(self):
+        """Ensure the required API key is present in the environment for providers that need it."""
+        env_vars = {
+            "openai": "OPENAI_API_KEY",
+            "anthropic": "ANTHROPIC_API_KEY",
+        }
+        env_var = env_vars.get(self.provider)
+        if env_var and not os.environ.get(env_var):
+            raise ValueError(f"{env_var} must be set in the environment for provider '{self.provider}'")
+
 
 class Config(BaseModel):
     """YAML Config."""
