Skip to content

Commit 59f67c4

Browse files
committed
tests to show the expected behavior of dialog rails and dialog tasks
add tests to show the expected behavior of dialog tasks
1 parent 04e97e3 commit 59f67c4

File tree

1 file changed

+249
-0
lines changed

1 file changed

+249
-0
lines changed

tests/test_dialog_tasks.py

+249
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
from unittest.mock import Mock, patch
18+
19+
import pytest
20+
21+
from nemoguardrails import LLMRails, RailsConfig
22+
from nemoguardrails.llm.taskmanager import LLMTaskManager
23+
from nemoguardrails.llm.types import Task
24+
25+
try:
26+
import langchain_openai
27+
28+
has_langchain_openai = True
29+
except ImportError:
30+
has_langchain_openai = False
31+
32+
has_openai_key = bool(os.getenv("OPENAI_API_KEY"))
33+
34+
skip_if_no_openai = pytest.mark.skipif(
35+
not (has_langchain_openai and has_openai_key),
36+
reason="Requires langchain_openai and OPENAI_API_KEY environment variable",
37+
)
38+
39+
40+
@skip_if_no_openai
41+
def test_dialog_tasks_with_only_input_rails():
42+
"""Test that dialog tasks are not used when only input rails are present."""
43+
44+
config = RailsConfig.from_content(
45+
yaml_content="""
46+
models:
47+
- type: main
48+
engine: openai
49+
model: gpt-3.5-turbo-instruct
50+
rails:
51+
input:
52+
flows:
53+
- self check input
54+
prompts:
55+
- task: self_check_input
56+
content: "Check if input is safe"
57+
""",
58+
)
59+
60+
assert not config.user_messages
61+
assert not config.bot_messages
62+
assert not config.flows
63+
assert not config.rails.dialog.single_call.enabled
64+
65+
rails = LLMRails(config=config)
66+
67+
assert not rails.config.rails.dialog.single_call.enabled
68+
# with just input rails, some basic flows and messages are created
69+
# but they are not for actual dialog processing
70+
assert rails.config.bot_messages
71+
assert rails.config.flows
72+
# even there should be no user messages defined
73+
assert not rails.config.user_messages
74+
75+
76+
@skip_if_no_openai
77+
def test_dialog_tasks_with_only_output_rails():
78+
"""Test that dialog tasks are not used when only output rails are present."""
79+
80+
config = RailsConfig.from_content(
81+
yaml_content="""
82+
models:
83+
- type: main
84+
engine: openai
85+
model: gpt-3.5-turbo-instruct
86+
rails:
87+
output:
88+
flows:
89+
- self check output
90+
prompts:
91+
- task: self_check_output
92+
content: "Check if output is safe"
93+
""",
94+
)
95+
96+
assert not config.user_messages
97+
assert not config.bot_messages
98+
assert not config.flows
99+
assert not config.rails.dialog.single_call.enabled
100+
101+
rails = LLMRails(config=config)
102+
103+
assert not rails.config.rails.dialog.single_call.enabled
104+
assert rails.config.bot_messages
105+
assert rails.config.flows
106+
assert not rails.config.user_messages
107+
108+
109+
@skip_if_no_openai
110+
def test_dialog_tasks_with_dialog_rails():
111+
"""Test that dialog tasks are used when dialog rails are present."""
112+
113+
config = RailsConfig.from_content(
114+
yaml_content="""
115+
models:
116+
- type: main
117+
engine: openai
118+
model: gpt-3.5-turbo-instruct
119+
rails:
120+
dialog:
121+
single_call:
122+
enabled: true
123+
""",
124+
colang_content="""
125+
define user express greeting
126+
"hello"
127+
"hi"
128+
129+
define bot express greeting
130+
"Hello there!"
131+
132+
define flow
133+
user express greeting
134+
bot express greeting
135+
""",
136+
)
137+
138+
assert config.user_messages
139+
assert config.bot_messages
140+
assert config.flows
141+
assert config.rails.dialog.single_call.enabled
142+
143+
rails = LLMRails(config=config)
144+
145+
assert rails.config.rails.dialog.single_call.enabled
146+
assert rails.config.user_messages
147+
assert rails.config.bot_messages
148+
assert rails.config.flows
149+
150+
151+
@skip_if_no_openai
152+
def test_dialog_tasks_with_implicit_dialog_rails():
153+
"""Test that dialog tasks are used when dialog rails are implicitly present through user/bot messages."""
154+
155+
config = RailsConfig.from_content(
156+
yaml_content="""
157+
models:
158+
- type: main
159+
engine: openai
160+
model: gpt-3.5-turbo-instruct
161+
""",
162+
colang_content="""
163+
define user express greeting
164+
"hello"
165+
"hi"
166+
167+
define bot express greeting
168+
"Hello there!"
169+
170+
define flow
171+
user express greeting
172+
bot express greeting
173+
""",
174+
)
175+
176+
assert config.user_messages
177+
assert config.bot_messages
178+
assert config.flows
179+
180+
assert config.user_messages == {"express greeting": ["hello", "hi"]}
181+
assert config.bot_messages == {"express greeting": ["Hello there!"]}
182+
assert len(config.bot_messages) == 1
183+
assert len(config.flows) == 1
184+
185+
assert not config.rails.dialog.single_call.enabled
186+
187+
rails = LLMRails(config=config)
188+
189+
assert rails.config.user_messages
190+
assert len(rails.config.user_messages) == 1
191+
assert rails.config.bot_messages
192+
assert len(rails.config.bot_messages) > 1
193+
assert rails.config.flows
194+
195+
196+
@skip_if_no_openai
197+
def test_dialog_tasks_with_mixed_rails():
198+
"""Test that dialog tasks are used when dialog rails are present along with other rails."""
199+
200+
config = RailsConfig.from_content(
201+
yaml_content="""
202+
models:
203+
- type: main
204+
engine: openai
205+
model: gpt-3.5-turbo-instruct
206+
rails:
207+
input:
208+
flows:
209+
- self check input
210+
output:
211+
flows:
212+
- self check output
213+
dialog:
214+
single_call:
215+
enabled: true
216+
prompts:
217+
- task: self_check_input
218+
content: "Check if input is safe"
219+
- task: self_check_output
220+
content: "Check if output is safe"
221+
""",
222+
colang_content="""
223+
define user express greeting
224+
"hello"
225+
"hi"
226+
227+
define bot express greeting
228+
"Hello there!"
229+
230+
define flow
231+
user express greeting
232+
bot express greeting
233+
""",
234+
)
235+
assert config.rails.dialog.single_call.enabled
236+
assert config.user_messages
237+
assert config.bot_messages
238+
assert config.flows
239+
assert config.rails.input.flows
240+
assert config.rails.output.flows
241+
242+
rails = LLMRails(config=config)
243+
244+
assert rails.config.rails.dialog.single_call.enabled
245+
assert rails.config.user_messages
246+
assert rails.config.bot_messages
247+
assert rails.config.flows
248+
assert rails.config.rails.input.flows
249+
assert rails.config.rails.output.flows

0 commit comments

Comments
 (0)