From 3d9eb22d2c6b878bbc42400962537a80010122cf Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Tue, 22 Apr 2025 09:17:51 +0200 Subject: [PATCH 1/4] feat(config): add validation for prompt fields Added `ge=1` validation to `max_length` and `max_tokens` fields in the TaskPrompt class to ensure values are greater than or equal to 1. --- nemoguardrails/rails/llm/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index acb5d5419..9243a5b91 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -305,6 +305,7 @@ class TaskPrompt(BaseModel): max_length: Optional[int] = Field( default=16000, description="The maximum length of the prompt in number of characters.", + ge=1, ) mode: Optional[str] = Field( default=_default_config["prompting_mode"], @@ -318,6 +319,7 @@ class TaskPrompt(BaseModel): max_tokens: Optional[int] = Field( default=None, description="The maximum number of tokens that can be generated in the chat completion.", + ge=1, ) @root_validator(pre=True, allow_reuse=True) From f6b3d89853c641db1e62016da4094e5238f24478 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Tue, 22 Apr 2025 09:18:42 +0200 Subject: [PATCH 2/4] fix(config): replace ValidationError with ValueError Replaced the use of Pydantic's ValidationError with ValueError in the TaskPrompt class's root_validator. This resolves a TypeError caused by incorrect instantiation of ValidationError during validation. --- nemoguardrails/rails/llm/config.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 9243a5b91..1c3d5c0e5 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -325,12 +325,10 @@ class TaskPrompt(BaseModel): @root_validator(pre=True, allow_reuse=True) def check_fields(cls, values): if not values.get("content") and not values.get("messages"): - raise ValidationError("One of `content` or `messages` must be provided.") + raise ValueError("One of `content` or `messages` must be provided.") if values.get("content") and values.get("messages"): - raise ValidationError( - "Only one of `content` or `messages` must be provided." - ) + raise ValueError("Only one of `content` or `messages` must be provided.") return values From 6e0f7d220c853ba5903cc1d658e45dffad66d93f Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Tue, 22 Apr 2025 09:19:05 +0200 Subject: [PATCH 3/4] refactor(config): remove unused Flow import --- nemoguardrails/rails/llm/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 1c3d5c0e5..a8768a8bd 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -34,7 +34,6 @@ from nemoguardrails import utils from nemoguardrails.colang import parse_colang_file, parse_flow_elements -from nemoguardrails.colang.v2_x.lang.colang_ast import Flow from nemoguardrails.colang.v2_x.lang.utils import format_colang_parsing_error_message from nemoguardrails.colang.v2_x.runtime.errors import ColangParsingError From 81236d63ea5b9d3fb8b80e85343615794fecc72a Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Tue, 22 Apr 2025 09:19:55 +0200 Subject: [PATCH 4/4] feat(tests): add unit tests for TaskPrompt validation --- tests/rails/llm/test_config.py | 125 +++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 tests/rails/llm/test_config.py diff --git a/tests/rails/llm/test_config.py b/tests/rails/llm/test_config.py new file mode 100644 index 000000000..7213c56cc --- /dev/null +++ b/tests/rails/llm/test_config.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from pydantic import ValidationError + +from nemoguardrails.rails.llm.config import TaskPrompt + + +def test_task_prompt_valid_content(): + prompt = TaskPrompt(task="example_task", content="This is a valid prompt.") + assert prompt.task == "example_task" + assert prompt.content == "This is a valid prompt." + assert prompt.messages is None + + +def test_task_prompt_valid_messages(): + prompt = TaskPrompt(task="example_task", messages=["Hello", "How can I help you?"]) + assert prompt.task == "example_task" + assert prompt.messages == ["Hello", "How can I help you?"] + assert prompt.content is None + + +def test_task_prompt_missing_content_and_messages(): + with pytest.raises(ValidationError) as excinfo: + TaskPrompt(task="example_task") + assert "One of `content` or `messages` must be provided." in str(excinfo.value) + + +def test_task_prompt_both_content_and_messages(): + with pytest.raises(ValidationError) as excinfo: + TaskPrompt( + task="example_task", + content="This is a prompt.", + messages=["Hello", "How can I help you?"], + ) + assert "Only one of `content` or `messages` must be provided." in str(excinfo.value) + + +def test_task_prompt_models_validation(): + prompt = TaskPrompt( + task="example_task", + content="Test prompt", + models=["openai", "openai/gpt-3.5-turbo"], + ) + assert prompt.models == ["openai", "openai/gpt-3.5-turbo"] + + prompt = TaskPrompt(task="example_task", content="Test prompt", models=[]) + assert prompt.models == [] + + prompt = TaskPrompt(task="example_task", content="Test prompt", models=None) + assert prompt.models is None + + +def test_task_prompt_max_length_validation(): + prompt = TaskPrompt(task="example_task", content="Test prompt") + assert prompt.max_length == 16000 + + prompt = TaskPrompt(task="example_task", content="Test prompt", max_length=1000) + assert prompt.max_length == 1000 + + with pytest.raises(ValidationError) as excinfo: + TaskPrompt(task="example_task", content="Test prompt", max_length=0) + assert "Input should be greater than or equal to 1" in str(excinfo.value) + + with pytest.raises(ValidationError) as excinfo: + TaskPrompt(task="example_task", content="Test prompt", max_length=-1) + assert "Input should be greater than or equal to 1" in str(excinfo.value) + + +def test_task_prompt_mode_validation(): + prompt = TaskPrompt(task="example_task", content="Test prompt") + # default mode is "standard" + assert prompt.mode == "standard" + + prompt = TaskPrompt(task="example_task", content="Test prompt", mode="chat") + assert prompt.mode == "chat" + + prompt = TaskPrompt(task="example_task", content="Test prompt", mode=None) + assert prompt.mode is None + + +def test_task_prompt_stop_tokens_validation(): + prompt = TaskPrompt( + task="example_task", content="Test prompt", stop=["\n", "Human:", "Assistant:"] + ) + assert prompt.stop == ["\n", "Human:", "Assistant:"] + + prompt = TaskPrompt(task="example_task", content="Test prompt", stop=[]) + assert prompt.stop == [] + + prompt = TaskPrompt(task="example_task", content="Test prompt", stop=None) + assert prompt.stop is None + + with pytest.raises(ValidationError) as excinfo: + TaskPrompt(task="example_task", content="Test prompt", stop=[1, 2, 3]) + assert "Input should be a valid string" in str(excinfo.value) + + +def test_task_prompt_max_tokens_validation(): + prompt = TaskPrompt(task="example_task", content="Test prompt") + assert prompt.max_tokens is None + + prompt = TaskPrompt(task="example_task", content="Test prompt", max_tokens=1000) + assert prompt.max_tokens == 1000 + + with pytest.raises(ValidationError) as excinfo: + TaskPrompt(task="example_task", content="Test prompt", max_tokens=0) + assert "Input should be greater than or equal to 1" in str(excinfo.value) + + with pytest.raises(ValidationError) as excinfo: + TaskPrompt(task="example_task", content="Test prompt", max_tokens=-1) + assert "Input should be greater than or equal to 1" in str(excinfo.value)