Skip to content

Commit 33ff9f7

Browse files
committed
feat(tests): add unit tests for TaskPrompt validation
1 parent 29c1ac0 commit 33ff9f7

File tree

1 file changed

+125
-0
lines changed

1 file changed

+125
-0
lines changed

tests/rails/llm/test_config.py

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import pytest
17+
from pydantic import ValidationError
18+
19+
from nemoguardrails.rails.llm.config import TaskPrompt
20+
21+
22+
def test_task_prompt_valid_content():
23+
prompt = TaskPrompt(task="example_task", content="This is a valid prompt.")
24+
assert prompt.task == "example_task"
25+
assert prompt.content == "This is a valid prompt."
26+
assert prompt.messages is None
27+
28+
29+
def test_task_prompt_valid_messages():
30+
prompt = TaskPrompt(task="example_task", messages=["Hello", "How can I help you?"])
31+
assert prompt.task == "example_task"
32+
assert prompt.messages == ["Hello", "How can I help you?"]
33+
assert prompt.content is None
34+
35+
36+
def test_task_prompt_missing_content_and_messages():
37+
with pytest.raises(ValidationError) as excinfo:
38+
TaskPrompt(task="example_task")
39+
assert "One of `content` or `messages` must be provided." in str(excinfo.value)
40+
41+
42+
def test_task_prompt_both_content_and_messages():
43+
with pytest.raises(ValidationError) as excinfo:
44+
TaskPrompt(
45+
task="example_task",
46+
content="This is a prompt.",
47+
messages=["Hello", "How can I help you?"],
48+
)
49+
assert "Only one of `content` or `messages` must be provided." in str(excinfo.value)
50+
51+
52+
def test_task_prompt_models_validation():
53+
prompt = TaskPrompt(
54+
task="example_task",
55+
content="Test prompt",
56+
models=["openai", "openai/gpt-3.5-turbo"],
57+
)
58+
assert prompt.models == ["openai", "openai/gpt-3.5-turbo"]
59+
60+
prompt = TaskPrompt(task="example_task", content="Test prompt", models=[])
61+
assert prompt.models == []
62+
63+
prompt = TaskPrompt(task="example_task", content="Test prompt", models=None)
64+
assert prompt.models is None
65+
66+
67+
def test_task_prompt_max_length_validation():
68+
prompt = TaskPrompt(task="example_task", content="Test prompt")
69+
assert prompt.max_length == 16000
70+
71+
prompt = TaskPrompt(task="example_task", content="Test prompt", max_length=1000)
72+
assert prompt.max_length == 1000
73+
74+
with pytest.raises(ValidationError) as excinfo:
75+
TaskPrompt(task="example_task", content="Test prompt", max_length=0)
76+
assert "Input should be greater than or equal to 1" in str(excinfo.value)
77+
78+
with pytest.raises(ValidationError) as excinfo:
79+
TaskPrompt(task="example_task", content="Test prompt", max_length=-1)
80+
assert "Input should be greater than or equal to 1" in str(excinfo.value)
81+
82+
83+
def test_task_prompt_mode_validation():
84+
prompt = TaskPrompt(task="example_task", content="Test prompt")
85+
# default mode is "standard"
86+
assert prompt.mode == "standard"
87+
88+
prompt = TaskPrompt(task="example_task", content="Test prompt", mode="chat")
89+
assert prompt.mode == "chat"
90+
91+
prompt = TaskPrompt(task="example_task", content="Test prompt", mode=None)
92+
assert prompt.mode is None
93+
94+
95+
def test_task_prompt_stop_tokens_validation():
96+
prompt = TaskPrompt(
97+
task="example_task", content="Test prompt", stop=["\n", "Human:", "Assistant:"]
98+
)
99+
assert prompt.stop == ["\n", "Human:", "Assistant:"]
100+
101+
prompt = TaskPrompt(task="example_task", content="Test prompt", stop=[])
102+
assert prompt.stop == []
103+
104+
prompt = TaskPrompt(task="example_task", content="Test prompt", stop=None)
105+
assert prompt.stop is None
106+
107+
with pytest.raises(ValidationError) as excinfo:
108+
TaskPrompt(task="example_task", content="Test prompt", stop=[1, 2, 3])
109+
assert "Input should be a valid string" in str(excinfo.value)
110+
111+
112+
def test_task_prompt_max_tokens_validation():
113+
prompt = TaskPrompt(task="example_task", content="Test prompt")
114+
assert prompt.max_tokens is None
115+
116+
prompt = TaskPrompt(task="example_task", content="Test prompt", max_tokens=1000)
117+
assert prompt.max_tokens == 1000
118+
119+
with pytest.raises(ValidationError) as excinfo:
120+
TaskPrompt(task="example_task", content="Test prompt", max_tokens=0)
121+
assert "Input should be greater than or equal to 1" in str(excinfo.value)
122+
123+
with pytest.raises(ValidationError) as excinfo:
124+
TaskPrompt(task="example_task", content="Test prompt", max_tokens=-1)
125+
assert "Input should be greater than or equal to 1" in str(excinfo.value)

0 commit comments

Comments
 (0)