Skip to content

Commit 2a4f576

Browse files
authored
feat(tests): add tests for models and options and api (#1111)
* feat(tests): add tests for models and options and api * refactor(tests): split invalid policy validation into new test * feat(tests): add validation tests for models * feat(tests): enhance policy validation test coverage - Add test for multiple interactions with same policy - Add test for multiple policies in single interaction - Add test for duplicate policy IDs (currently allowed) - Document missing validations for future improvement
1 parent 6540bb5 commit 2a4f576

File tree

4 files changed

+601
-2
lines changed

4 files changed

+601
-2
lines changed

tests/eval/test_models.py

+342
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
18+
import pytest
19+
from pydantic import ValidationError
20+
21+
from nemoguardrails.eval.models import (
22+
ComplianceCheckLog,
23+
EvalConfig,
24+
EvalOutput,
25+
ExpectedOutput,
26+
InteractionSet,
27+
Policy,
28+
)
29+
30+
ROOT = os.path.dirname(__file__)  # directory of this test module; used to locate fixture configs
31+
32+
33+
def test_interaction_set_expected_output_instantiation():
    """ExpectedOutput entries are instantiated according to their `type` field."""
    # Both supported type discriminators go through the same construction path.
    for output_type in ("string", "dict"):
        payload = {"type": output_type, "policy": "test_policy"}
        interaction_set = InteractionSet.model_validate(
            {"id": "test_id", "inputs": ["test input"], "expected_output": [payload]}
        )
        assert len(interaction_set.expected_output) == 1
        entry = interaction_set.expected_output[0]
        assert entry.type == output_type
        assert entry.policy == "test_policy"
53+
54+
55+
def test_eval_config_from_path():
    """An evaluation config can be loaded from the fixture directory."""
    loaded = EvalConfig.from_path(os.path.join(ROOT, "config_yml"))
    assert loaded is not None
    assert len(loaded.policies) > 0
61+
62+
63+
def test_compliance_check_log():
    """ComplianceCheckLog accepts an id together with an empty llm_calls list."""
    entry = ComplianceCheckLog.model_validate({"id": "test_id", "llm_calls": []})
    assert entry.id == "test_id"
    assert entry.llm_calls == []
68+
69+
70+
def test_eval_output():
    """EvalOutput round-trips a single result carrying one compliance check."""
    # Build the payload bottom-up so each nested structure is named.
    check_payload = {
        "id": "check_id",
        "created_at": "2024-01-01T00:00:00",
        "method": "test_method",
        "compliance": {"policy1": True},
        "details": "test details",
        "interaction_id": "test_id",
    }
    result_payload = {
        "id": "test_id",
        "input": "test input",
        "output": "test_output",
        "compliance_checks": [check_payload],
    }
    parsed = EvalOutput.model_validate({"results": [result_payload], "logs": []})

    assert len(parsed.results) == 1
    result = parsed.results[0]
    assert result.id == "test_id"
    assert result.input == "test input"
    assert result.output == "test_output"
    assert len(result.compliance_checks) == 1
    assert result.compliance_checks[0].id == "check_id"
    assert result.compliance_checks[0].interaction_id == "test_id"
101+
102+
103+
def test_eval_config_policy_validation_empty_lists():
    """A config with no policies and no interactions is still valid."""
    parsed = EvalConfig.model_validate({"policies": [], "interactions": []})
    assert len(parsed.policies) == 0
    assert len(parsed.interactions) == 0
113+
114+
115+
def test_eval_config_policy_validation_invalid_policy_format_missing_description():
    """A policy missing the required ``description`` field is rejected.

    Catch ``ValidationError`` rather than the broad ``ValueError`` for
    consistency with the other model tests in this file; pydantic's
    ``ValidationError`` is a ``ValueError`` subclass, so the check is a
    strict tightening, not a behavior change.
    """
    with pytest.raises(ValidationError):
        EvalConfig.model_validate(
            {
                "policies": [{"id": "policy1"}],
                "interactions": [],
            }
        )
124+
125+
126+
def test_eval_config_policy_validation_invalid_interaction_format_missing_inputs():
    """An interaction set missing the required ``inputs`` field is rejected.

    Catch ``ValidationError`` rather than the broad ``ValueError`` for
    consistency with the other model tests in this file; pydantic's
    ``ValidationError`` is a ``ValueError`` subclass, so the check is a
    strict tightening, not a behavior change.
    """
    with pytest.raises(ValidationError):
        EvalConfig.model_validate(
            {
                "policies": [{"id": "policy1", "description": "Test policy"}],
                "interactions": [
                    {
                        "id": "test_id",
                        "expected_output": [{"type": "string", "policy": "policy1"}],
                    }
                ],
            }
        )
140+
141+
142+
def test_interaction_set_empty_expected_output():
    """An empty expected_output list is accepted unchanged."""
    parsed = InteractionSet.model_validate(
        {"id": "test_id", "inputs": ["test input"], "expected_output": []}
    )
    assert len(parsed.expected_output) == 0
148+
149+
150+
def test_interaction_set_invalid_format():
    """An expected_output entry missing the ``policy`` field is rejected."""
    bad_payload = {
        "id": "test_id",
        "inputs": ["test input"],
        "expected_output": [{"type": "string"}],
    }
    with pytest.raises(ValueError):
        InteractionSet.model_validate(bad_payload)

    # TODO: the model does not yet validate the *values* of the type field.
    # Enable this check once type validation is implemented:
    # with pytest.raises(ValueError):
    #     InteractionSet.model_validate(
    #         {
    #             "id": "test_id",
    #             "inputs": ["test input"],
    #             "expected_output": [{"type": "invalid_type", "policy": "test_policy"}],
    #         }
    #     )
171+
172+
173+
def test_compliance_check_log_invalid_format():
    """Malformed ComplianceCheckLog payloads are rejected."""
    # Missing required fields entirely.
    with pytest.raises(ValueError):
        ComplianceCheckLog.model_validate({})

    # llm_calls must be a list, not a string.
    with pytest.raises(ValueError):
        ComplianceCheckLog.model_validate({"id": "test_id", "llm_calls": "invalid"})
181+
182+
183+
def test_policy_creation():
    """All explicitly supplied Policy fields are stored unchanged."""
    built = Policy(
        id="policy_1",
        description="Test policy description",
        weight=50,
        apply_to_all=False,
    )
    assert built.id == "policy_1"
    assert built.description == "Test policy description"
    assert built.weight == 50
    assert not built.apply_to_all
194+
195+
196+
def test_policy_default_values():
    """Omitted Policy fields fall back to their declared defaults."""
    built = Policy(id="policy_2", description="Another test policy")
    assert built.weight == 100
    assert built.apply_to_all
203+
204+
205+
def test_policy_invalid_weight():
    """A non-numeric weight value raises a ValidationError."""
    with pytest.raises(ValidationError):
        Policy(
            id="policy_3",
            description="Invalid weight test",
            weight="invalid_weight",
        )
212+
213+
214+
def test_expected_output_creation():
    """ExpectedOutput keeps the supplied type and policy values."""
    expected = ExpectedOutput(type="refusal", policy="policy_1")
    assert expected.type == "refusal"
    assert expected.policy == "policy_1"
221+
222+
223+
def test_expected_output_missing_field():
    """Omitting the required ``policy`` field raises a ValidationError."""
    with pytest.raises(ValidationError):
        ExpectedOutput(type="refusal")
228+
229+
230+
def test_eval_config_policy_validation_valid():
    """A config whose interactions reference only known policy ids validates."""
    payload = {
        "policies": [{"id": "policy1", "description": "Test policy"}],
        "interactions": [
            {
                "id": "test_id",
                "inputs": ["test input"],
                "expected_output": [{"type": "string", "policy": "policy1"}],
            }
        ],
    }
    parsed = EvalConfig.model_validate(payload)
    assert len(parsed.policies) == 1
    assert len(parsed.interactions) == 1
247+
248+
249+
def test_eval_config_policy_validation_invalid_policy_not_found():
    """Referencing an unknown policy id from an interaction set must fail."""
    payload = {
        "policies": [{"id": "policy1", "description": "Test policy"}],
        "interactions": [
            {
                "id": "test_id",
                "inputs": ["test input"],
                "expected_output": [{"type": "string", "policy": "policy2"}],
            }
        ],
    }
    # The match string pins the model's cross-reference error message.
    with pytest.raises(
        ValueError, match="Invalid policy id policy2 used in interaction set"
    ):
        EvalConfig.model_validate(payload)
271+
272+
273+
def test_eval_config_policy_validation_multiple_interactions():
    """Several interaction sets may reference the same policy."""
    interactions = [
        {
            "id": f"test_id{index}",
            "inputs": [f"test input {index}"],
            "expected_output": [{"type": "string", "policy": "policy1"}],
        }
        for index in (1, 2)
    ]
    parsed = EvalConfig.model_validate(
        {
            "policies": [{"id": "policy1", "description": "Test policy"}],
            "interactions": interactions,
        }
    )
    assert len(parsed.interactions) == 2
293+
294+
295+
def test_eval_config_policy_validation_multiple_policies():
    """One interaction may reference several distinct policies."""
    payload = {
        "policies": [
            {"id": "policy1", "description": "Test policy 1"},
            {"id": "policy2", "description": "Test policy 2"},
        ],
        "interactions": [
            {
                "id": "test_id",
                "inputs": ["test input"],
                "expected_output": [
                    {"type": "string", "policy": "policy1"},
                    {"type": "string", "policy": "policy2"},
                ],
            }
        ],
    }
    parsed = EvalConfig.model_validate(payload)
    assert len(parsed.policies) == 2
    assert len(parsed.interactions[0].expected_output) == 2
317+
318+
319+
def test_eval_config_policy_validation_duplicate_policy_ids():
    """Duplicate policy ids are currently accepted by the model.

    NOTE: no uniqueness check exists today; update this test if duplicate
    policy id validation is ever added.
    """
    parsed = EvalConfig.model_validate(
        {
            "policies": [
                {"id": "policy1", "description": "Test policy 1"},
                {"id": "policy1", "description": "Test policy 2"},
            ],
            "interactions": [
                {
                    "id": "test_id",
                    "inputs": ["test input"],
                    "expected_output": [{"type": "string", "policy": "policy1"}],
                }
            ],
        }
    )
    assert len(parsed.policies) == 2
    assert all(policy.id == "policy1" for policy in parsed.policies)

0 commit comments

Comments
 (0)