From ab242fa38f9fa9840e1c03bea968e9dba9ac8a33 Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Tue, 18 Mar 2025 23:22:53 +0100 Subject: [PATCH 1/4] feat: Add NvidiaChatGenerator --- .../components/generators/nvidia/__init__.py | 3 +- .../generators/nvidia/chat/chat_generator.py | 247 ++++++++++++++ .../utils/nvidia/models.py | 68 +--- .../utils/nvidia/nim_backend.py | 55 +++- .../utils/nvidia/utils.py | 11 +- .../nvidia/tests/test_chat_generator.py | 303 ++++++++++++++++++ integrations/nvidia/tests/test_utils.py | 4 +- 7 files changed, 622 insertions(+), 69 deletions(-) create mode 100644 integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/chat_generator.py create mode 100644 integrations/nvidia/tests/test_chat_generator.py diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py index b809d83b9..c53880870 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from .chat.chat_generator import NvidiaChatGenerator from .generator import NvidiaGenerator -__all__ = ["NvidiaGenerator"] +__all__ = ["NvidiaChatGenerator", "NvidiaGenerator"] diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/chat_generator.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/chat_generator.py new file mode 100644 index 000000000..cf1acdafb --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/chat_generator.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import warnings +from typing import Any, Dict, List, Optional + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import ChatMessage, ChatRole +from haystack.utils.auth import Secret, deserialize_secrets_inplace + +from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Model, NimBackend, is_hosted, url_validation + + +@component +class NvidiaChatGenerator: + """ + Generates responses using generative chat models hosted with + [NVIDIA NIM](https://ai.nvidia.com) on the [NVIDIA API Catalog](https://build.nvidia.com/explore/discover). + + This component uses the ChatMessage format to communicate with NVIDIA NIM models that support chat completion. + + ### Usage example + + ```python + from haystack_integrations.components.generators.nvidia import NvidiaChatGenerator + from haystack.dataclasses import ChatMessage + + generator = NvidiaChatGenerator( + model="meta/llama3-70b-instruct", + model_arguments={ + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + }, + ) + generator.warm_up() + + messages = [ + ChatMessage.from_system("You are a helpful assistant."), + ChatMessage.from_user("What is the answer to life, the universe, and everything?") + ] + result = generator.run(messages=messages) + print(result["replies"]) + ``` + + You need an NVIDIA API key for this component to work. 
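    Set it as the `NVIDIA_API_KEY` environment variable, or pass it explicitly as a
    `Secret` (a sketch; the token below is a placeholder for your own key):

    ```python
    from haystack.utils.auth import Secret

    generator = NvidiaChatGenerator(
        model="meta/llama3-70b-instruct",
        api_key=Secret.from_token("nvapi-..."),  # placeholder, use your real key
    )
    ```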
+ """ + + def __init__( + self, + model: Optional[str] = None, + api_url: str = os.getenv("NVIDIA_API_URL", DEFAULT_API_URL), + api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), + model_arguments: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, + ): + """ + Create a NvidiaChatGenerator component. + + :param model: + Name of the model to use for chat generation. + See the [NVIDIA NIMs](https://ai.nvidia.com) + for more information on the supported models. + `Note`: If no specific model along with locally hosted API URL is provided, + the system defaults to the available model found using /models API. + Check supported models at [NVIDIA NIM](https://ai.nvidia.com). + :param api_key: + API key for the NVIDIA NIM. Set it as the `NVIDIA_API_KEY` environment + variable or pass it here. + :param api_url: + Custom API URL for the NVIDIA NIM. + :param model_arguments: + Additional arguments to pass to the model provider. These arguments are + specific to a model. + Search your model in the [NVIDIA NIM](https://ai.nvidia.com) + to find the arguments it accepts. + :param timeout: + Timeout for request calls, if not set it is inferred from the `NVIDIA_TIMEOUT` environment variable + or set to 60 by default. + """ + self._model = model + self.api_url = url_validation(api_url) + self._api_key = api_key + self._model_arguments = model_arguments or {} + + self.backend: Optional[Any] = None + + self.is_hosted = is_hosted(api_url) + if timeout is None: + timeout = float(os.environ.get("NVIDIA_TIMEOUT", "60.0")) + self.timeout = timeout + + def default_model(self): + """Set default model in local NIM mode.""" + valid_models = [ + model.id for model in self.available_models if not model.base_model or model.base_model == model.id + ] + name = next(iter(valid_models), None) + if name: + warnings.warn( + f"Default model is set as: {name}. \n" + "Set model using model parameter. \n" + "To get available models use available_models property.", + UserWarning, + stacklevel=2, + ) + self._model = self.backend.model = name + else: + error_message = "No locally hosted model was found." + raise ValueError(error_message) + + def warm_up(self): + """ + Initializes the component. + """ + if self.backend is not None: + return + + self.backend = NimBackend( + model=self._model, + model_type="chat", + api_url=self.api_url, + api_key=self._api_key, + model_kwargs=self._model_arguments, + timeout=self.timeout, + client=self.__class__.__name__, + ) + + if not self.is_hosted and not self._model: + if self.backend.model: + self.model = self.backend.model + else: + self.default_model() + + @classmethod + def class_name(cls) -> str: + return "NvidiaChatGenerator" + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + model=self._model, + api_url=self.api_url, + api_key=self._api_key.to_dict() if self._api_key else None, + model_arguments=self._model_arguments, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "NvidiaChatGenerator": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + init_params = data.get("init_parameters", {}) + deserialize_secrets_inplace(init_params, ["api_key"]) + return default_from_dict(cls, data) + + @property + def available_models(self) -> List[Model]: + """ + Get a list of available models that work with NvidiaChatGenerator. 
+ """ + return self.backend.models() if self.backend else [] + + def _convert_messages_to_nvidia_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]: + """ + Convert a list of messages to the format expected by NVIDIA NIM API. + + :param messages: The list of ChatMessages to convert. + :returns: A list of dictionaries in the format expected by NVIDIA NIM API. + """ + nvidia_messages = [] + + for message in messages: + if message.is_from(ChatRole.SYSTEM): + nvidia_messages.append({"role": "system", "content": message.text}) + elif message.is_from(ChatRole.USER): + nvidia_messages.append({"role": "user", "content": message.text}) + elif message.is_from(ChatRole.ASSISTANT): + nvidia_messages.append({"role": "assistant", "content": message.text}) + else: + # Skip other message types like tool messages for now + pass + + return nvidia_messages + + def _convert_nvidia_response_to_chat_message(self, response: Dict[str, Any]) -> ChatMessage: + """ + Convert the response from the NVIDIA API to a ChatMessage. + + :param response: The response from the NVIDIA API. + :returns: A ChatMessage object. + """ + text = response.get("content", "") + message = ChatMessage.from_assistant(text=text) + + # Add metadata to the message + message._meta.update({ + "model": response.get("model", None), + "finish_reason": response.get("finish_reason", None), + "usage": response.get("usage", {}), + }) + + return message + + @component.output_types(replies=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + ): + """ + Invokes the NVIDIA NIM API with the given messages and generation kwargs. + + :param messages: A list of ChatMessage instances representing the input messages. + :returns: A dictionary with the following keys: + - `replies`: The responses from the model + """ + if self.backend is None: + msg = "The chat model has not been loaded. Call warm_up() before running." + raise RuntimeError(msg) + + # Convert messages to NVIDIA format + nvidia_messages = self._convert_messages_to_nvidia_format(messages) + + + # Call the backend and process response + assert self.backend is not None + + responses, _ = self.backend.generate_chat( + messages=nvidia_messages, + ) + + # Convert responses to ChatMessages + chat_messages = [self._convert_nvidia_response_to_chat_message(resp) for resp in responses] + return {"replies": chat_messages} diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/models.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/models.py index e996e9186..528cc99ef 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/models.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/models.py @@ -9,8 +9,8 @@ class Model: id: unique identifier for the model, passed as model parameter for requests model_type: API type (chat, vlm, embedding, ranking, completions) - client: client name, e.g. NvidiaGenerator, NVIDIAEmbeddings, - NVIDIARerank, NvidiaTextEmbedder, NvidiaDocumentEmbedder + client: client name, e.g. 
NvidiaGenerator, NvidiaChatGenerator, + NvidiaRanker, NvidiaTextEmbedder, NvidiaDocumentEmbedder endpoint: custom endpoint for the model aliases: list of aliases for the model @@ -19,7 +19,11 @@ class Model: id: str model_type: Optional[Literal["chat", "embedding", "ranking"]] = None - client: Optional[Literal["NvidiaGenerator", "NvidiaTextEmbedder", "NvidiaDocumentEmbedder", "NvidiaRanker"]] = None + client: Optional[ + Literal[ + "NvidiaGenerator", "NvidiaTextEmbedder", "NvidiaDocumentEmbedder", "NvidiaRanker", "NvidiaChatGenerator" + ] + ] = None endpoint: Optional[str] = None aliases: Optional[list] = None base_model: Optional[str] = None @@ -33,6 +37,7 @@ def validate(self): if self.client: supported = { "NvidiaGenerator": ("chat",), + "NvidiaChatGenerator": ("chat",), "NvidiaTextEmbedder": ("embedding",), "NvidiaDocumentEmbedder": ("embedding",), "NvidiaRanker": ("ranking",), @@ -51,7 +56,6 @@ def validate(self): "meta/codellama-70b": Model( id="meta/codellama-70b", model_type="chat", - client="NvidiaGenerator", aliases=[ "ai-codellama-70b", "playground_llama2_code_70b", @@ -65,13 +69,11 @@ def validate(self): "google/gemma-7b": Model( id="google/gemma-7b", model_type="chat", - client="NvidiaGenerator", aliases=["ai-gemma-7b", "playground_gemma_7b", "gemma_7b"], ), "meta/llama2-70b": Model( id="meta/llama2-70b", model_type="chat", - client="NvidiaGenerator", aliases=[ "ai-llama2-70b", "playground_llama2_70b", @@ -83,317 +85,263 @@ def validate(self): "mistralai/mistral-7b-instruct-v0.2": Model( id="mistralai/mistral-7b-instruct-v0.2", model_type="chat", - client="NvidiaGenerator", aliases=["ai-mistral-7b-instruct-v2", "playground_mistral_7b", "mistral_7b"], ), "mistralai/mixtral-8x7b-instruct-v0.1": Model( id="mistralai/mixtral-8x7b-instruct-v0.1", model_type="chat", - client="NvidiaGenerator", aliases=["ai-mixtral-8x7b-instruct", "playground_mixtral_8x7b", "mixtral_8x7b"], ), "google/codegemma-7b": Model( id="google/codegemma-7b", model_type="chat", - client="NvidiaGenerator", aliases=["ai-codegemma-7b"], ), "google/gemma-2b": Model( id="google/gemma-2b", model_type="chat", - client="NvidiaGenerator", aliases=["ai-gemma-2b", "playground_gemma_2b", "gemma_2b"], ), "google/recurrentgemma-2b": Model( id="google/recurrentgemma-2b", model_type="chat", - client="NvidiaGenerator", aliases=["ai-recurrentgemma-2b"], ), "mistralai/mistral-large": Model( id="mistralai/mistral-large", model_type="chat", - client="NvidiaGenerator", aliases=["ai-mistral-large"], ), "mistralai/mixtral-8x22b-instruct-v0.1": Model( id="mistralai/mixtral-8x22b-instruct-v0.1", model_type="chat", - client="NvidiaGenerator", aliases=["ai-mixtral-8x22b-instruct"], ), "meta/llama3-8b-instruct": Model( id="meta/llama3-8b-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-llama3-8b"], ), "meta/llama3-70b-instruct": Model( id="meta/llama3-70b-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-llama3-70b"], ), "microsoft/phi-3-mini-128k-instruct": Model( id="microsoft/phi-3-mini-128k-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-phi-3-mini"], ), "snowflake/arctic": Model( id="snowflake/arctic", model_type="chat", - client="NvidiaGenerator", aliases=["ai-arctic"], ), "databricks/dbrx-instruct": Model( id="databricks/dbrx-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-dbrx-instruct"], ), "microsoft/phi-3-mini-4k-instruct": Model( id="microsoft/phi-3-mini-4k-instruct", model_type="chat", - client="NvidiaGenerator", 
aliases=["ai-phi-3-mini-4k", "playground_phi2", "phi2"], ), "seallms/seallm-7b-v2.5": Model( id="seallms/seallm-7b-v2.5", model_type="chat", - client="NvidiaGenerator", aliases=["ai-seallm-7b"], ), "aisingapore/sea-lion-7b-instruct": Model( id="aisingapore/sea-lion-7b-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-sea-lion-7b-instruct"], ), "microsoft/phi-3-small-8k-instruct": Model( id="microsoft/phi-3-small-8k-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-phi-3-small-8k-instruct"], ), "microsoft/phi-3-small-128k-instruct": Model( id="microsoft/phi-3-small-128k-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-phi-3-small-128k-instruct"], ), "microsoft/phi-3-medium-4k-instruct": Model( id="microsoft/phi-3-medium-4k-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-phi-3-medium-4k-instruct"], ), "ibm/granite-8b-code-instruct": Model( id="ibm/granite-8b-code-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-granite-8b-code-instruct"], ), "ibm/granite-34b-code-instruct": Model( id="ibm/granite-34b-code-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-granite-34b-code-instruct"], ), "google/codegemma-1.1-7b": Model( id="google/codegemma-1.1-7b", model_type="chat", - client="NvidiaGenerator", aliases=["ai-codegemma-1.1-7b"], ), "mediatek/breeze-7b-instruct": Model( id="mediatek/breeze-7b-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-breeze-7b-instruct"], ), "upstage/solar-10.7b-instruct": Model( id="upstage/solar-10.7b-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-solar-10_7b-instruct"], ), "writer/palmyra-med-70b-32k": Model( id="writer/palmyra-med-70b-32k", model_type="chat", - client="NvidiaGenerator", aliases=["ai-palmyra-med-70b-32k"], ), "writer/palmyra-med-70b": Model( id="writer/palmyra-med-70b", model_type="chat", - client="NvidiaGenerator", aliases=["ai-palmyra-med-70b"], ), "mistralai/mistral-7b-instruct-v0.3": Model( id="mistralai/mistral-7b-instruct-v0.3", model_type="chat", - client="NvidiaGenerator", aliases=["ai-mistral-7b-instruct-v03"], ), "01-ai/yi-large": Model( id="01-ai/yi-large", model_type="chat", - client="NvidiaGenerator", aliases=["ai-yi-large"], ), "nvidia/nemotron-4-340b-instruct": Model( id="nvidia/nemotron-4-340b-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["qa-nemotron-4-340b-instruct"], ), "mistralai/codestral-22b-instruct-v0.1": Model( id="mistralai/codestral-22b-instruct-v0.1", model_type="chat", - client="NvidiaGenerator", aliases=["ai-codestral-22b-instruct-v01"], supports_structured_output=True, ), "google/gemma-2-9b-it": Model( id="google/gemma-2-9b-it", model_type="chat", - client="NvidiaGenerator", aliases=["ai-gemma-2-9b-it"], ), "google/gemma-2-27b-it": Model( id="google/gemma-2-27b-it", model_type="chat", - client="NvidiaGenerator", aliases=["ai-gemma-2-27b-it"], ), "microsoft/phi-3-medium-128k-instruct": Model( id="microsoft/phi-3-medium-128k-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-phi-3-medium-128k-instruct"], ), "deepseek-ai/deepseek-coder-6.7b-instruct": Model( id="deepseek-ai/deepseek-coder-6.7b-instruct", model_type="chat", - client="NvidiaGenerator", aliases=["ai-deepseek-coder-6_7b-instruct"], ), "nv-mistralai/mistral-nemo-12b-instruct": Model( id="nv-mistralai/mistral-nemo-12b-instruct", model_type="chat", - client="NvidiaGenerator", supports_tools=True, supports_structured_output=True, ), 
"meta/llama-3.1-8b-instruct": Model( id="meta/llama-3.1-8b-instruct", model_type="chat", - client="NvidiaGenerator", supports_tools=True, supports_structured_output=True, ), "meta/llama-3.1-70b-instruct": Model( id="meta/llama-3.1-70b-instruct", model_type="chat", - client="NvidiaGenerator", supports_tools=True, supports_structured_output=True, ), "meta/llama-3.1-405b-instruct": Model( id="meta/llama-3.1-405b-instruct", model_type="chat", - client="NvidiaGenerator", supports_tools=True, supports_structured_output=True, ), "nvidia/usdcode-llama3-70b-instruct": Model( id="nvidia/usdcode-llama3-70b-instruct", model_type="chat", - client="NvidiaGenerator", ), "mistralai/mamba-codestral-7b-v0.1": Model( id="mistralai/mamba-codestral-7b-v0.1", model_type="chat", - client="NvidiaGenerator", ), "writer/palmyra-fin-70b-32k": Model( id="writer/palmyra-fin-70b-32k", model_type="chat", - client="NvidiaGenerator", supports_structured_output=True, ), "google/gemma-2-2b-it": Model( id="google/gemma-2-2b-it", model_type="chat", - client="NvidiaGenerator", ), "mistralai/mistral-large-2-instruct": Model( id="mistralai/mistral-large-2-instruct", model_type="chat", - client="NvidiaGenerator", supports_tools=True, supports_structured_output=True, ), "mistralai/mathstral-7b-v0.1": Model( id="mistralai/mathstral-7b-v0.1", model_type="chat", - client="NvidiaGenerator", ), "rakuten/rakutenai-7b-instruct": Model( id="rakuten/rakutenai-7b-instruct", model_type="chat", - client="NvidiaGenerator", ), "rakuten/rakutenai-7b-chat": Model( id="rakuten/rakutenai-7b-chat", model_type="chat", - client="NvidiaGenerator", ), "baichuan-inc/baichuan2-13b-chat": Model( id="baichuan-inc/baichuan2-13b-chat", model_type="chat", - client="NvidiaGenerator", ), "thudm/chatglm3-6b": Model( id="thudm/chatglm3-6b", model_type="chat", - client="NvidiaGenerator", ), "microsoft/phi-3.5-mini-instruct": Model( id="microsoft/phi-3.5-mini-instruct", model_type="chat", - client="NvidiaGenerator", ), "microsoft/phi-3.5-moe-instruct": Model( id="microsoft/phi-3.5-moe-instruct", model_type="chat", - client="NvidiaGenerator", ), "nvidia/nemotron-mini-4b-instruct": Model( id="nvidia/nemotron-mini-4b-instruct", model_type="chat", - client="NvidiaGenerator", ), "ai21labs/jamba-1.5-large-instruct": Model( id="ai21labs/jamba-1.5-large-instruct", model_type="chat", - client="NvidiaGenerator", ), "ai21labs/jamba-1.5-mini-instruct": Model( id="ai21labs/jamba-1.5-mini-instruct", model_type="chat", - client="NvidiaGenerator", ), "yentinglin/llama-3-taiwan-70b-instruct": Model( id="yentinglin/llama-3-taiwan-70b-instruct", model_type="chat", - client="NvidiaGenerator", ), "tokyotech-llm/llama-3-swallow-70b-instruct-v0.1": Model( id="tokyotech-llm/llama-3-swallow-70b-instruct-v0.1", model_type="chat", - client="NvidiaGenerator", ), } diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py index 4c6632c2d..6d2a35558 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py @@ -27,7 +27,9 @@ def __init__( api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), model_kwargs: Optional[Dict[str, Any]] = None, client: Optional[ - Literal["NvidiaGenerator", "NvidiaTextEmbedder", "NvidiaDocumentEmbedder", "NvidiaRanker"] + Literal[ + "NvidiaGenerator", "NvidiaTextEmbedder", "NvidiaDocumentEmbedder", "NvidiaRanker", "NvidiaChatGenerator" + ] ] = 
None, timeout: Optional[float] = None, ): @@ -63,8 +65,6 @@ def __init__( self.model_kwargs = model_kwargs or {} self.client = client self.model_type = model_type - self.client = client - self.model_type = model_type if timeout is None: timeout = float(os.environ.get("NVIDIA_TIMEOUT", REQUEST_TIMEOUT)) self.timeout = timeout @@ -147,6 +147,55 @@ def generate(self, prompt: str) -> Tuple[List[str], List[Dict[str, Any]]]: return replies, meta + def generate_chat( + self, + messages: List[Dict[str, Any]], + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + """ + Generate chat completions from a list of messages. + + :param messages: List of messages in the format expected by the NIM API + :returns: A tuple containing the response and metadata + """ + url = f"{self.api_url}/chat/completions" + + json_data = { + "model": self.model, + "messages": messages, + **self.model_kwargs, + } + + try: + res = self.session.post( + url, + json=json_data, + timeout=self.timeout, + ) + res.raise_for_status() + + except requests.HTTPError as e: + logger.error("Error when calling NIM chat completion endpoint: Error - {error}", error=e.response.text) + msg = f"Failed to query chat completion endpoint: Error - {e.response.text}" + raise ValueError(msg) from e + + # Process the response to extract content and metadata + data = res.json() + responses = [] + for choice in data.get("choices", []): + message = choice.get("message", {}) + response_data = { + "content": message.get("content", ""), + "model": data.get("model", self.model), + "finish_reason": choice.get("finish_reason", None), + } + # Add usage information if available + if "usage" in data: + response_data["usage"] = data["usage"] + responses.append(response_data) + + meta = {"model": data.get("model", self.model)} + return responses, meta + def models(self) -> List[Model]: url = f"{self.api_url}/models" diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py index 987976027..3ecd2109d 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py @@ -84,15 +84,19 @@ def determine_model(name: str) -> Optional[Model]: def validate_hosted_model( model_name: str, - client: Optional[Literal["NvidiaGenerator", "NvidiaTextEmbedder", "NvidiaDocumentEmbedder", "NvidiaRanker"]] = None, + client: Optional[ + Literal[ + "NvidiaGenerator", "NvidiaTextEmbedder", "NvidiaDocumentEmbedder", "NvidiaRanker", "NvidiaChatGenerator" + ] + ] = None, ) -> Any: """ Checks if a given model is compatible with given client. Args: model_name (str): The name of the model. - client (str): client name, e.g. NvidiaGenerator, NVIDIAEmbeddings, - NVIDIARerank, NvidiaTextEmbedder, NvidiaDocumentEmbedder + client (str): client name, e.g. NvidiaGenerator, NvidiaChatGenerator, + NvidiaRanker, NvidiaTextEmbedder, NvidiaDocumentEmbedder Raises: ValueError: If the model is incompatible with the client or if the model is unknown. 
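Referring back to the `generate_chat` method added to `NimBackend` above, a minimal
usage sketch (assumes `NVIDIA_API_KEY` is set in the environment and the model name
is valid; the commented return values are illustrative):

```python
from haystack_integrations.utils.nvidia import NimBackend

backend = NimBackend(
    model="meta/llama3-70b-instruct",
    model_type="chat",
    api_url="https://integrate.api.nvidia.com/v1",
)
replies, meta = backend.generate_chat(
    messages=[{"role": "user", "content": "Hello!"}],
)
# replies: [{"content": "...", "model": "meta/llama3-70b-instruct",
#            "finish_reason": "stop", "usage": {...}}]
# meta: {"model": "meta/llama3-70b-instruct"}
```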
@@ -100,6 +104,7 @@ def validate_hosted_model( """ supported = { "NvidiaGenerator": ("chat",), + "NvidiaChatGenerator": ("chat",), "NvidiaTextEmbedder": ("embedding",), "NvidiaDocumentEmbedder": ("embedding",), "NvidiaRanker": ("ranking",), diff --git a/integrations/nvidia/tests/test_chat_generator.py b/integrations/nvidia/tests/test_chat_generator.py new file mode 100644 index 000000000..ceb3be320 --- /dev/null +++ b/integrations/nvidia/tests/test_chat_generator.py @@ -0,0 +1,303 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from unittest.mock import MagicMock, patch + +import pytest +from haystack.dataclasses import ChatMessage +from haystack.utils import Secret +from requests_mock import Mocker + +from haystack_integrations.components.generators.nvidia.chat.chat_generator import NvidiaChatGenerator + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant."), + ChatMessage.from_user("What is the answer to life, the universe, and everything?"), + ] + + +@pytest.fixture +def mock_backend(): + with patch("haystack_integrations.components.generators.nvidia.chat.chat_generator.NimBackend") as mock: + backend_instance = MagicMock() + backend_instance.model = "meta/llama3-70b-instruct" + backend_instance.models.return_value = [ + MagicMock(id="model1", base_model="model1"), + MagicMock(id="model2", base_model="model2"), + ] + # Mock the generate_chat method to return a sample response + backend_instance.generate_chat.return_value = ( + [{"content": "42", "model": "meta/llama3-70b-instruct", "finish_reason": "stop"}], + {"model": "meta/llama3-70b-instruct"}, + ) + mock.return_value = backend_instance + yield mock + + +@pytest.fixture +def mock_local_chat_completion(requests_mock: Mocker) -> None: + requests_mock.post( + "http://localhost:8080/v1/chat/completions", + json={ + "choices": [ + { + "message": {"content": "The answer is 42.", "role": "assistant"}, + "finish_reason": "stop", + "index": 0, + } + ], + "usage": { + "prompt_tokens": 25, + "total_tokens": 30, + "completion_tokens": 5, + }, + "model": "meta/llama3-70b-instruct", + }, + ) + + +class TestNvidiaChatGenerator: + def test_init_default(self, monkeypatch): + """Test default initialization""" + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = NvidiaChatGenerator("meta/llama3-70b-instruct") + + assert generator._api_key == Secret.from_env_var("NVIDIA_API_KEY") + assert generator._model == "meta/llama3-70b-instruct" + assert generator._model_arguments == {} + + def test_init_with_parameters(self): + """Test initialization with parameters""" + generator = NvidiaChatGenerator( + api_key=Secret.from_token("fake-api-key"), + model="meta/llama3-70b-instruct", + model_arguments={ + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + }, + ) + assert generator._api_key == Secret.from_token("fake-api-key") + assert generator._model == "meta/llama3-70b-instruct" + assert generator._model_arguments == { + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + } + + def test_init_fail_wo_api_key(self, monkeypatch): + """Test initialization fails without API key""" + monkeypatch.delenv("NVIDIA_API_KEY", raising=False) + generator = NvidiaChatGenerator("meta/llama3-70b-instruct") + with pytest.raises(ValueError): + generator.warm_up() + + def test_to_dict(self, monkeypatch): + """Test serialization to dictionary""" + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = 
NvidiaChatGenerator("meta/llama3-70b-instruct") + data = generator.to_dict() + assert data == { + "type": "haystack_integrations.components.generators.nvidia.chat.chat_generator.NvidiaChatGenerator", + "init_parameters": { + "api_url": "https://integrate.api.nvidia.com/v1", + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, + "model": "meta/llama3-70b-instruct", + "model_arguments": {}, + }, + } + + def test_to_dict_with_custom_init_parameters(self, monkeypatch): + """Test serialization with custom init parameters""" + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = NvidiaChatGenerator( + model="meta/llama3-70b-instruct", + api_url="https://my.url.com", + model_arguments={ + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + }, + ) + data = generator.to_dict() + assert data == { + "type": "haystack_integrations.components.generators.nvidia.chat.chat_generator.NvidiaChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, + "api_url": "https://my.url.com/v1", + "model": "meta/llama3-70b-instruct", + "model_arguments": { + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + }, + }, + } + + def test_from_dict(self, monkeypatch): + """Test deserialization from dictionary""" + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + data = { + "type": "haystack_integrations.components.generators.nvidia.chat.chat_generator.NvidiaChatGenerator", + "init_parameters": { + "api_url": "https://my.url.com/v1", + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, + "model": "meta/llama3-70b-instruct", + "model_arguments": { + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + }, + }, + } + generator = NvidiaChatGenerator.from_dict(data) + assert generator._api_key == Secret.from_env_var("NVIDIA_API_KEY") + assert generator._model == "meta/llama3-70b-instruct" + assert generator._model_arguments == { + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + } + assert generator.api_url == "https://my.url.com/v1" + + def test_warm_up_with_model(self, mock_backend, monkeypatch): + """Test warm_up initializes backend with model""" + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = NvidiaChatGenerator("meta/llama3-70b-instruct") + generator.warm_up() + + mock_backend.assert_called_once() + call_kwargs = mock_backend.call_args[1] + assert call_kwargs["model"] == "meta/llama3-70b-instruct" + assert call_kwargs["model_type"] == "chat" + assert call_kwargs["client"] == "NvidiaChatGenerator" + + def test_warm_up_without_model_local(self, mock_backend, monkeypatch): + """Test warm_up sets default model when none provided for local backend""" + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + + # Override the mock to set model to None to trigger default_model call + mock_backend.return_value.model = None + + generator = NvidiaChatGenerator(model=None, api_url="http://localhost:8080") + generator.is_hosted = False # Force local mode + + with pytest.warns(UserWarning, match="Default model is set as:"): + generator.warm_up() + + assert generator._model == "model1" # Should be set to first model in mocked models list + + def test_run(self, mock_backend, monkeypatch, chat_messages): + """Test run method with regular non-streaming response""" + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = NvidiaChatGenerator("meta/llama3-70b-instruct") + generator.warm_up() + + result = generator.run(messages=chat_messages) 
+ + mock_backend().generate_chat.assert_called_once() + + assert "replies" in result + assert len(result["replies"]) == 1 + assert result["replies"][0].text == "42" + assert result["replies"][0].meta["model"] == "meta/llama3-70b-instruct" + assert result["replies"][0].meta["finish_reason"] == "stop" + + def test_convert_messages_to_nvidia_format(self, monkeypatch): + """Test conversion of ChatMessages to NVIDIA format""" + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = NvidiaChatGenerator("meta/llama3-70b-instruct") + + messages = [ + ChatMessage.from_system("You are a helpful assistant."), + ChatMessage.from_user("What is the answer?"), + ChatMessage.from_assistant("The answer is 42."), + ] + + nvidia_messages = generator._convert_messages_to_nvidia_format(messages) + + assert len(nvidia_messages) == 3 + assert nvidia_messages[0] == {"role": "system", "content": "You are a helpful assistant."} + assert nvidia_messages[1] == {"role": "user", "content": "What is the answer?"} + assert nvidia_messages[2] == {"role": "assistant", "content": "The answer is 42."} + + def test_convert_nvidia_response_to_chat_message(self, monkeypatch): + """Test conversion of NVIDIA response to ChatMessage""" + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = NvidiaChatGenerator("meta/llama3-70b-instruct") + + nvidia_response = { + "content": "The answer is 42.", + "model": "meta/llama3-70b-instruct", + "finish_reason": "stop", + "usage": {"prompt_tokens": 25, "completion_tokens": 5, "total_tokens": 30}, + } + + chat_message = generator._convert_nvidia_response_to_chat_message(nvidia_response) + + assert chat_message.text == "The answer is 42." + assert chat_message.meta["model"] == "meta/llama3-70b-instruct" + assert chat_message.meta["finish_reason"] == "stop" + assert chat_message.meta["usage"] == {"prompt_tokens": 25, "completion_tokens": 5, "total_tokens": 30} + + def test_error_if_warm_up_not_called(self, monkeypatch, chat_messages): + """Test error is raised if warm_up not called""" + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = NvidiaChatGenerator("meta/llama3-70b-instruct") + + with pytest.raises(RuntimeError, match="The chat model has not been loaded"): + generator.run(messages=chat_messages) + + def test_setting_timeout(self, monkeypatch, mock_backend): + """Test timeout setting""" + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = NvidiaChatGenerator(timeout=10.0) + generator.warm_up() + + assert mock_backend.call_args[1]["timeout"] == 10.0 + + def test_setting_timeout_env(self, monkeypatch, mock_backend): + """Test timeout from environment variable""" + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + monkeypatch.setenv("NVIDIA_TIMEOUT", "45") + generator = NvidiaChatGenerator() + generator.warm_up() + + assert mock_backend.call_args[1]["timeout"] == 45.0 + + @pytest.mark.skipif( + not os.environ.get("NVIDIA_API_KEY", None), + reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.", + ) + @pytest.mark.integration + def test_run_integration_with_api_catalog(self): + """Integration test with NVIDIA API Catalog""" + generator = NvidiaChatGenerator( + model="meta/llama3-70b-instruct", + api_url="https://integrate.api.nvidia.com/v1", + api_key=Secret.from_env_var("NVIDIA_API_KEY"), + model_arguments={ + "temperature": 0.2, + "max_tokens": 50, + }, + ) + generator.warm_up() + + messages = [ + ChatMessage.from_system("You are a helpful assistant. 
Keep your answers brief."), + ChatMessage.from_user("What is the answer to life, the universe, and everything?"), + ] + + result = generator.run(messages=messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + assert isinstance(result["replies"][0], ChatMessage) + assert len(result["replies"][0].text) > 0 + assert result["replies"][0].meta["model"] is not None + assert result["replies"][0].meta["finish_reason"] is not None diff --git a/integrations/nvidia/tests/test_utils.py b/integrations/nvidia/tests/test_utils.py index 0560856d3..c7031862a 100644 --- a/integrations/nvidia/tests/test_utils.py +++ b/integrations/nvidia/tests/test_utils.py @@ -107,6 +107,6 @@ def test_validate_hosted_model_without_client() -> None: def test_validate_hosted_model_with_client() -> None: """Test when model's client matches the provided client.""" - model = validate_hosted_model("meta/codellama-70b", "NvidiaGenerator") + model = validate_hosted_model("nvidia/llama-3.2-nv-rerankqa-1b-v1", "NvidiaRanker") assert model is not None - assert model.client == "NvidiaGenerator" + assert model.client == "NvidiaRanker" From 9e385bb0474639953a1f63391d4b9ad56a88eef1 Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Tue, 18 Mar 2025 23:27:52 +0100 Subject: [PATCH 2/4] fmt --- .../generators/nvidia/chat/chat_generator.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/chat_generator.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/chat_generator.py index cf1acdafb..a47f5ff78 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/chat_generator.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/chat_generator.py @@ -207,11 +207,13 @@ def _convert_nvidia_response_to_chat_message(self, response: Dict[str, Any]) -> message = ChatMessage.from_assistant(text=text) # Add metadata to the message - message._meta.update({ - "model": response.get("model", None), - "finish_reason": response.get("finish_reason", None), - "usage": response.get("usage", {}), - }) + message._meta.update( + { + "model": response.get("model", None), + "finish_reason": response.get("finish_reason", None), + "usage": response.get("usage", {}), + } + ) return message @@ -234,7 +236,6 @@ def run( # Convert messages to NVIDIA format nvidia_messages = self._convert_messages_to_nvidia_format(messages) - # Call the backend and process response assert self.backend is not None From 8e228c29bf640d9c4bbdd0cde9f5167f6408f42a Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Wed, 19 Mar 2025 17:51:33 +0100 Subject: [PATCH 3/4] add alternative --- .../components/generators/nvidia/__init__.py | 3 +- .../nvidia/chat/alternative_chat_generator.py | 121 +++++++++++ .../tests/test_alternative_chat_generator.py | 203 ++++++++++++++++++ 3 files changed, 326 insertions(+), 1 deletion(-) create mode 100644 integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/alternative_chat_generator.py create mode 100644 integrations/nvidia/tests/test_alternative_chat_generator.py diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py index c53880870..72d334937 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py +++ 
b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py @@ -2,7 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 +from .chat.alternative_chat_generator import AlternativeNvidiaChatGenerator from .chat.chat_generator import NvidiaChatGenerator from .generator import NvidiaGenerator -__all__ = ["NvidiaChatGenerator", "NvidiaGenerator"] +__all__ = ["AlternativeNvidiaChatGenerator", "NvidiaChatGenerator", "NvidiaGenerator"] diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/alternative_chat_generator.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/alternative_chat_generator.py new file mode 100644 index 000000000..eec030d2f --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/alternative_chat_generator.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Any, Callable, Dict, Optional + +from haystack import component, default_to_dict +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import StreamingChunk +from haystack.utils import serialize_callable +from haystack.utils.auth import Secret + +from haystack_integrations.utils.nvidia import DEFAULT_API_URL + + +@component +class AlternativeNvidiaChatGenerator(OpenAIChatGenerator): + """ + Generates responses using generative chat models hosted with + [NVIDIA NIM](https://ai.nvidia.com) on the [NVIDIA API Catalog](https://build.nvidia.com/explore/discover). + + This component uses the ChatMessage format to communicate with NVIDIA NIM models that support chat completion. + + ### Usage example + + ```python + from haystack_integrations.components.generators.nvidia import NvidiaChatGenerator + from haystack.dataclasses import ChatMessage + + generator = AlternativeNvidiaChatGenerator( + model="meta/llama3-70b-instruct", + model_arguments={ + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + }, + ) + + messages = [ + ChatMessage.from_system("You are a helpful assistant."), + ChatMessage.from_user("What is the answer to life, the universe, and everything?") + ] + result = generator.run(messages=messages) + print(result["replies"]) + ``` + + You need an NVIDIA API key for this component to work. + """ + + def __init__( + self, + model: Optional[str] = None, + api_base_url: str = os.getenv("NVIDIA_API_URL", DEFAULT_API_URL), + api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, + ): + """ + Create a NvidiaChatGenerator component. + + :param model: + Name of the model to use for chat generation. + See the [NVIDIA NIMs](https://ai.nvidia.com) + for more information on the supported models. + `Note`: If no specific model along with locally hosted API URL is provided, + the system defaults to the available model found using /models API. + Check supported models at [NVIDIA NIM](https://ai.nvidia.com). + :param api_key: + API key for the NVIDIA NIM. Set it as the `NVIDIA_API_KEY` environment + variable or pass it here. + :param api_base_url: + Custom API URL for the NVIDIA NIM. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. 
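            For example, `streaming_callback=lambda chunk: print(chunk.content, end="")`
            prints tokens as they arrive (a sketch; any callable accepting a
            `StreamingChunk` works).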
+ :param generation_kwargs: + Additional arguments to pass to the model provider. These arguments are + specific to a model. + Search your model in the [NVIDIA NIM](https://ai.nvidia.com) + to find the arguments it accepts. + :param timeout: + Timeout for request calls, if not set it is inferred from the `NVIDIA_TIMEOUT` environment variable + or set to 60 by default. + """ + + if timeout is None: + timeout = float(os.environ.get("NVIDIA_TIMEOUT", "60.0")) + + super(AlternativeNvidiaChatGenerator, self).__init__( # noqa: UP008 + api_key=api_key, + model=model, + streaming_callback=streaming_callback, + api_base_url=api_base_url, + generation_kwargs=generation_kwargs, + timeout=timeout, + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + The serialized component as a dictionary. + """ + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + + # if we didn't implement the to_dict method here then the to_dict method of the superclass would be used + # which would serialize some fields that we don't want to serialize (e.g. the ones we don't have in + # the __init__) + # it would be hard to maintain the compatibility as superclass changes + return default_to_dict( + self, + model=self.model, + streaming_callback=callback_name, + api_base_url=self.api_base_url, + generation_kwargs=self.generation_kwargs, + api_key=self.api_key.to_dict(), + timeout=self.timeout, + ) diff --git a/integrations/nvidia/tests/test_alternative_chat_generator.py b/integrations/nvidia/tests/test_alternative_chat_generator.py new file mode 100644 index 000000000..dd997fb0b --- /dev/null +++ b/integrations/nvidia/tests/test_alternative_chat_generator.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from datetime import datetime +from unittest.mock import patch + +import pytest +from haystack.dataclasses import ChatMessage +from haystack.utils import Secret +from openai.types.chat import ChatCompletion +from requests_mock import Mocker + +from haystack_integrations.components.generators.nvidia.chat.alternative_chat_generator import ( + AlternativeNvidiaChatGenerator, +) + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant."), + ChatMessage.from_user("What is the answer to life, the universe, and everything?"), + ] + + +@pytest.fixture +def openai_mock_chat_completion(): + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + completion = ChatCompletion( + id="foo", + model="gpt-4", + object="chat.completion", + choices=[ + { + "finish_reason": "stop", + "logprobs": None, + "index": 0, + "message": {"content": "Hello world!", "role": "assistant"}, + } + ], + created=int(datetime.now().timestamp()), # noqa: DTZ005 + usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + ) + + mock_chat_completion_create.return_value = completion + yield mock_chat_completion_create + + +@pytest.fixture +def mock_local_chat_completion(requests_mock: Mocker) -> None: + requests_mock.post( + "http://localhost:8080/v1/chat/completions", + json={ + "choices": [ + { + "message": {"content": "The answer is 42.", "role": "assistant"}, + "finish_reason": "stop", + "index": 0, + } + ], + "usage": { + "prompt_tokens": 25, + "total_tokens": 30, + "completion_tokens": 5, + }, + "model": "meta/llama3-70b-instruct", + }, + ) + + +class 
TestAlternativeNvidiaChatGenerator: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = AlternativeNvidiaChatGenerator("meta/llama3-70b-instruct") + + assert generator.api_key == Secret.from_env_var("NVIDIA_API_KEY") + assert generator.model == "meta/llama3-70b-instruct" + assert generator.generation_kwargs == {} + + def test_init_with_parameters(self): + generator = AlternativeNvidiaChatGenerator( + api_key=Secret.from_token("fake-api-key"), + model="meta/llama3-70b-instruct", + generation_kwargs={ + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + }, + ) + assert generator.api_key == Secret.from_token("fake-api-key") + assert generator.model == "meta/llama3-70b-instruct" + assert generator.generation_kwargs == { + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + } + + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("NVIDIA_API_KEY", raising=False) + with pytest.raises(ValueError): + AlternativeNvidiaChatGenerator("meta/llama3-70b-instruct") + + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = AlternativeNvidiaChatGenerator("meta/llama3-70b-instruct") + data = generator.to_dict() + assert data == { + "type": "haystack_integrations.components.generators.nvidia.chat." + "alternative_chat_generator.AlternativeNvidiaChatGenerator", + "init_parameters": { + "api_base_url": "https://integrate.api.nvidia.com/v1", + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, + "model": "meta/llama3-70b-instruct", + "generation_kwargs": {}, + "streaming_callback": None, + "timeout": 60.0, + }, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + data = { + "type": "haystack_integrations.components.generators.nvidia.chat." 
+ "alternative_chat_generator.AlternativeNvidiaChatGenerator", + "init_parameters": { + "api_base_url": "https://my.url.com/v1", + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, + "model": "meta/llama3-70b-instruct", + "generation_kwargs": { + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + }, + "streaming_callback": None, + "timeout": 60.0, + }, + } + generator = AlternativeNvidiaChatGenerator.from_dict(data) + assert generator.api_key == Secret.from_env_var("NVIDIA_API_KEY") + assert generator.model == "meta/llama3-70b-instruct" + assert generator.generation_kwargs == { + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + } + assert generator.api_base_url == "https://my.url.com/v1" + + def test_run(self, chat_messages, openai_mock_chat_completion): + generator = AlternativeNvidiaChatGenerator( + model="meta/llama3-70b-instruct", api_key=Secret.from_token("test-api-key") + ) + + response = generator.run(chat_messages) + + _, kwargs = openai_mock_chat_completion.call_args + assert kwargs["model"] == "meta/llama3-70b-instruct" + + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + def test_run_with_generation_kwargs(self, chat_messages, openai_mock_chat_completion): + generator = AlternativeNvidiaChatGenerator( + model="meta/llama3-70b-instruct", + api_key=Secret.from_token("test-api-key"), + generation_kwargs={"temperature": 0.7}, + ) + + generator.run(messages=chat_messages, generation_kwargs={"max_tokens": 100, "temperature": 0.2}) + + # Verify parameters are merged correctly (run kwargs override init kwargs) + # and the component calls the API with the correct parameters + _, kwargs = openai_mock_chat_completion.call_args + assert kwargs["max_tokens"] == 100 + assert kwargs["temperature"] == 0.2 + + @pytest.mark.skipif( + not os.environ.get("NVIDIA_API_KEY", None), + reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.", + ) + @pytest.mark.integration + def test_run_integration(self): + generator = AlternativeNvidiaChatGenerator( + model="meta/llama3-70b-instruct", + api_key=Secret.from_env_var("NVIDIA_API_KEY"), + generation_kwargs={"temperature": 0.2, "max_tokens": 50}, + ) + + messages = [ + ChatMessage.from_system("You are a helpful assistant. 
Keep your answers brief."), + ChatMessage.from_user("What is the answer to life, the universe, and everything?"), + ] + + result = generator.run(messages=messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + assert isinstance(result["replies"][0], ChatMessage) + assert len(result["replies"][0].text) > 0 From 0f22a7576aa692e36052f2f7a35b2cd25cc84656 Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Wed, 19 Mar 2025 18:02:02 +0100 Subject: [PATCH 4/4] openai types --- integrations/nvidia/tests/test_alternative_chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/nvidia/tests/test_alternative_chat_generator.py b/integrations/nvidia/tests/test_alternative_chat_generator.py index dd997fb0b..d30e01e75 100644 --- a/integrations/nvidia/tests/test_alternative_chat_generator.py +++ b/integrations/nvidia/tests/test_alternative_chat_generator.py @@ -9,7 +9,7 @@ import pytest from haystack.dataclasses import ChatMessage from haystack.utils import Secret -from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletion # type: ignore from requests_mock import Mocker from haystack_integrations.components.generators.nvidia.chat.alternative_chat_generator import (
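
As a closing usage sketch for the new subclass (not part of the diffs above; assumes
`NVIDIA_API_KEY` is exported and the hosted model name is valid):

```python
from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.nvidia import AlternativeNvidiaChatGenerator

generator = AlternativeNvidiaChatGenerator(
    model="meta/llama3-70b-instruct",
    generation_kwargs={"temperature": 0.2, "max_tokens": 50},
    streaming_callback=lambda chunk: print(chunk.content, end="", flush=True),
)
result = generator.run(messages=[ChatMessage.from_user("Say hello in one sentence.")])
print(result["replies"][0].text)
```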