diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index a2ecc6b7c..3efb5c34f 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -236,7 +236,7 @@ class CohereChatGenerator: from haystack.utils import Secret from haystack_integrations.components.generators.cohere import CohereChatGenerator - client = CohereChatGenerator(model="command-r", api_key=Secret.from_env_var("COHERE_API_KEY")) + client = CohereChatGenerator(model="command-r-08-2024", api_key=Secret.from_env_var("COHERE_API_KEY")) messages = [ChatMessage.from_user("What's Natural Language Processing?")] client.run(messages) @@ -278,7 +278,7 @@ def weather(city: str) -> str: # Create and set up the pipeline pipeline = Pipeline() - pipeline.add_component("generator", CohereChatGenerator(model="command-r", tools=[weather_tool])) + pipeline.add_component("generator", CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool])) pipeline.add_component("tool_invoker", ToolInvoker(tools=[weather_tool])) pipeline.connect("generator", "tool_invoker") @@ -296,7 +296,7 @@ def weather(city: str) -> str: def __init__( self, api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), - model: str = "command-r", + model: str = "command-r-08-2024", streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, api_base_url: Optional[str] = None, generation_kwargs: Optional[Dict[str, Any]] = None, diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 0dc3d5f14..e9d8131de 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -46,7 +46,7 @@ def test_init_default(self, monkeypatch): component = CohereChatGenerator() assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) - assert component.model == "command-r" + assert component.model == "command-r-08-2024" assert component.streaming_callback is None assert component.api_base_url == "https://api.cohere.com" assert not component.generation_kwargs @@ -78,7 +78,7 @@ def test_to_dict_default(self, monkeypatch): assert data == { "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { - "model": "command-r", + "model": "command-r-08-2024", "streaming_callback": None, "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "api_base_url": "https://api.cohere.com", @@ -116,7 +116,7 @@ def test_from_dict(self, monkeypatch): data = { "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { - "model": "command-r", + "model": "command-r-08-2024", "api_base_url": "test-base-url", "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", @@ -124,7 +124,7 @@ def test_from_dict(self, monkeypatch): }, } component = CohereChatGenerator.from_dict(data) - assert component.model == "command-r" + assert component.model == "command-r-08-2024" assert component.streaming_callback is print_streaming_chunk assert component.api_base_url == "test-base-url" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} @@ -135,7 +135,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): data = { "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { - "model": "command-r", + "model": "command-r-08-2024", "api_base_url": "test-base-url", "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", @@ -226,7 +226,7 @@ def test_tools_use_old_way(self): }, } ] - client = CohereChatGenerator(model="command-r") + client = CohereChatGenerator(model="command-r-08-2024") response = client.run( messages=[ChatMessage.from_user("What is the current price of AAPL?")], generation_kwargs={"tools": tools_schema}, @@ -267,7 +267,7 @@ def test_tools_use_with_tools(self): function=stock_price, ) initial_messages = [ChatMessage.from_user("What is the current price of AAPL?")] - client = CohereChatGenerator(model="command-r") + client = CohereChatGenerator(model="command-r-08-2024") response = client.run( messages=initial_messages, tools=[stock_price_tool], @@ -327,7 +327,7 @@ def test_live_run_with_tools_streaming(self): initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] component = CohereChatGenerator( - model="command-r", # Cohere's model that supports tools + model="command-r-08-2024", # Cohere's model that supports tools tools=[weather_tool], streaming_callback=print_streaming_chunk, ) @@ -384,7 +384,7 @@ def test_pipeline_with_cohere_chat_generator(self): ) pipeline = Pipeline() - pipeline.add_component("generator", CohereChatGenerator(model="command-r", tools=[weather_tool])) + pipeline.add_component("generator", CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool])) pipeline.add_component("tool_invoker", ToolInvoker(tools=[weather_tool])) pipeline.connect("generator", "tool_invoker") @@ -416,7 +416,7 @@ def test_serde_in_pipeline(self, monkeypatch): # Create generator with specific configuration generator = CohereChatGenerator( - model="command-r", + model="command-r-08-2024", generation_kwargs={"temperature": 0.7}, streaming_callback=print_streaming_chunk, tools=[tool], @@ -437,7 +437,7 @@ def test_serde_in_pipeline(self, monkeypatch): "generator": { "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", # noqa: E501 "init_parameters": { - "model": "command-r", + "model": "command-r-08-2024", "api_key": {"type": "env_var", "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "api_base_url": "https://api.cohere.com",