Skip to content

Commit 6540bb5

Browse files
authored
fix: use ValueError in TaskPrompt to resolve TypeError raised by Pydantic (#1132)
* feat(config): add validation for prompt fields Added `ge=1` validation to `max_length` and `max_tokens` fields in the TaskPrompt class to ensure values are greater than or equal to 1. * fix(config): replace ValidationError with ValueError Replaced the use of Pydantic's ValidationError with ValueError in the TaskPrompt class's root_validator. This resolves a TypeError caused by incorrect instantiation of ValidationError during validation. * refactor(config): remove unused Flow import * feat(tests): add unit tests for TaskPrompt validation
1 parent 3041025 commit 6540bb5

File tree

2 files changed

+129
-5
lines changed

2 files changed

+129
-5
lines changed

nemoguardrails/rails/llm/config.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434

3535
from nemoguardrails import utils
3636
from nemoguardrails.colang import parse_colang_file, parse_flow_elements
37-
from nemoguardrails.colang.v2_x.lang.colang_ast import Flow
3837
from nemoguardrails.colang.v2_x.lang.utils import format_colang_parsing_error_message
3938
from nemoguardrails.colang.v2_x.runtime.errors import ColangParsingError
4039

@@ -305,6 +304,7 @@ class TaskPrompt(BaseModel):
305304
max_length: Optional[int] = Field(
306305
default=16000,
307306
description="The maximum length of the prompt in number of characters.",
307+
ge=1,
308308
)
309309
mode: Optional[str] = Field(
310310
default=_default_config["prompting_mode"],
@@ -318,17 +318,16 @@ class TaskPrompt(BaseModel):
318318
max_tokens: Optional[int] = Field(
319319
default=None,
320320
description="The maximum number of tokens that can be generated in the chat completion.",
321+
ge=1,
321322
)
322323

323324
@root_validator(pre=True, allow_reuse=True)
324325
def check_fields(cls, values):
325326
if not values.get("content") and not values.get("messages"):
326-
raise ValidationError("One of `content` or `messages` must be provided.")
327+
raise ValueError("One of `content` or `messages` must be provided.")
327328

328329
if values.get("content") and values.get("messages"):
329-
raise ValidationError(
330-
"Only one of `content` or `messages` must be provided."
331-
)
330+
raise ValueError("Only one of `content` or `messages` must be provided.")
332331

333332
return values
334333

tests/rails/llm/test_config.py

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import pytest
17+
from pydantic import ValidationError
18+
19+
from nemoguardrails.rails.llm.config import TaskPrompt
20+
21+
22+
def test_task_prompt_valid_content():
23+
prompt = TaskPrompt(task="example_task", content="This is a valid prompt.")
24+
assert prompt.task == "example_task"
25+
assert prompt.content == "This is a valid prompt."
26+
assert prompt.messages is None
27+
28+
29+
def test_task_prompt_valid_messages():
30+
prompt = TaskPrompt(task="example_task", messages=["Hello", "How can I help you?"])
31+
assert prompt.task == "example_task"
32+
assert prompt.messages == ["Hello", "How can I help you?"]
33+
assert prompt.content is None
34+
35+
36+
def test_task_prompt_missing_content_and_messages():
37+
with pytest.raises(ValidationError) as excinfo:
38+
TaskPrompt(task="example_task")
39+
assert "One of `content` or `messages` must be provided." in str(excinfo.value)
40+
41+
42+
def test_task_prompt_both_content_and_messages():
43+
with pytest.raises(ValidationError) as excinfo:
44+
TaskPrompt(
45+
task="example_task",
46+
content="This is a prompt.",
47+
messages=["Hello", "How can I help you?"],
48+
)
49+
assert "Only one of `content` or `messages` must be provided." in str(excinfo.value)
50+
51+
52+
def test_task_prompt_models_validation():
53+
prompt = TaskPrompt(
54+
task="example_task",
55+
content="Test prompt",
56+
models=["openai", "openai/gpt-3.5-turbo"],
57+
)
58+
assert prompt.models == ["openai", "openai/gpt-3.5-turbo"]
59+
60+
prompt = TaskPrompt(task="example_task", content="Test prompt", models=[])
61+
assert prompt.models == []
62+
63+
prompt = TaskPrompt(task="example_task", content="Test prompt", models=None)
64+
assert prompt.models is None
65+
66+
67+
def test_task_prompt_max_length_validation():
68+
prompt = TaskPrompt(task="example_task", content="Test prompt")
69+
assert prompt.max_length == 16000
70+
71+
prompt = TaskPrompt(task="example_task", content="Test prompt", max_length=1000)
72+
assert prompt.max_length == 1000
73+
74+
with pytest.raises(ValidationError) as excinfo:
75+
TaskPrompt(task="example_task", content="Test prompt", max_length=0)
76+
assert "Input should be greater than or equal to 1" in str(excinfo.value)
77+
78+
with pytest.raises(ValidationError) as excinfo:
79+
TaskPrompt(task="example_task", content="Test prompt", max_length=-1)
80+
assert "Input should be greater than or equal to 1" in str(excinfo.value)
81+
82+
83+
def test_task_prompt_mode_validation():
84+
prompt = TaskPrompt(task="example_task", content="Test prompt")
85+
# default mode is "standard"
86+
assert prompt.mode == "standard"
87+
88+
prompt = TaskPrompt(task="example_task", content="Test prompt", mode="chat")
89+
assert prompt.mode == "chat"
90+
91+
prompt = TaskPrompt(task="example_task", content="Test prompt", mode=None)
92+
assert prompt.mode is None
93+
94+
95+
def test_task_prompt_stop_tokens_validation():
96+
prompt = TaskPrompt(
97+
task="example_task", content="Test prompt", stop=["\n", "Human:", "Assistant:"]
98+
)
99+
assert prompt.stop == ["\n", "Human:", "Assistant:"]
100+
101+
prompt = TaskPrompt(task="example_task", content="Test prompt", stop=[])
102+
assert prompt.stop == []
103+
104+
prompt = TaskPrompt(task="example_task", content="Test prompt", stop=None)
105+
assert prompt.stop is None
106+
107+
with pytest.raises(ValidationError) as excinfo:
108+
TaskPrompt(task="example_task", content="Test prompt", stop=[1, 2, 3])
109+
assert "Input should be a valid string" in str(excinfo.value)
110+
111+
112+
def test_task_prompt_max_tokens_validation():
113+
prompt = TaskPrompt(task="example_task", content="Test prompt")
114+
assert prompt.max_tokens is None
115+
116+
prompt = TaskPrompt(task="example_task", content="Test prompt", max_tokens=1000)
117+
assert prompt.max_tokens == 1000
118+
119+
with pytest.raises(ValidationError) as excinfo:
120+
TaskPrompt(task="example_task", content="Test prompt", max_tokens=0)
121+
assert "Input should be greater than or equal to 1" in str(excinfo.value)
122+
123+
with pytest.raises(ValidationError) as excinfo:
124+
TaskPrompt(task="example_task", content="Test prompt", max_tokens=-1)
125+
assert "Input should be greater than or equal to 1" in str(excinfo.value)

0 commit comments

Comments
 (0)