Skip to content

Commit 04e97e3

Browse files
committed
fix: improve reasoning traces validation for dialog rails
- Fix validation to properly handle both dictionary and Model object cases - Add proper handling of cases where some dialog rail tasks have dedicated models and others fall back to main model - Improve error messages to be more user-friendly: * Specify which model has the issue (main model or specific task model) * Reference config.yml in error messages * Provide clear YAML configuration instructions * Include specific task name when relevant - Update test cases to match new error messages and validation logic - Add proper handling of implicit dialog rails activation through user/bot messages and flows This change ensures that reasoning traces are properly disabled when dialog rails are present, regardless of whether they are explicitly configured or implicitly activated through user/bot messages or flows. The improved error messages make it easier for users to understand and fix configuration issues in their YAML files. add more tests
1 parent a6f620f commit 04e97e3

File tree

2 files changed

+379
-33
lines changed

2 files changed

+379
-33
lines changed

nemoguardrails/rails/llm/config.py

+58-28
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from nemoguardrails.colang import parse_colang_file, parse_flow_elements
3737
from nemoguardrails.colang.v2_x.lang.utils import format_colang_parsing_error_message
3838
from nemoguardrails.colang.v2_x.runtime.errors import ColangParsingError
39+
from nemoguardrails.llm.types import Task
3940

4041
log = logging.getLogger(__name__)
4142

@@ -1140,37 +1141,66 @@ def check_reasoning_traces_with_dialog_rails(cls, values):
11401141
rails = values.get("rails", {})
11411142
dialog_rails = rails.get("dialog", {})
11421143

1143-
# check if any model has reasoning traces enabled
1144-
# TODO: we must check for models that are used in a specific dialog task
1145-
has_reasoning_traces = False
1146-
for model in models:
1147-
if isinstance(model, dict):
1148-
reasoning_config = model.get("reasoning_config", {})
1149-
if not reasoning_config.get("remove_thinking_traces", True):
1150-
has_reasoning_traces = True
1151-
break
1152-
elif hasattr(model, "reasoning_config"):
1153-
if not model.reasoning_config.remove_thinking_traces:
1154-
has_reasoning_traces = True
1155-
break
1156-
1157-
# check if dialog rails are present (explicitly or implicitly)
1158-
has_dialog_rails = bool(dialog_rails)
1159-
1160-
# check implicit dialog rails through user messages, bot messages, or flows
1161-
if not has_dialog_rails:
1162-
has_dialog_rails = (
1163-
bool(values.get("user_messages"))
1164-
or bool(values.get("bot_messages"))
1165-
or bool(values.get("flows"))
1166-
)
1144+
# dialog rail tasks that should not have reasoning traces
1145+
dialog_rail_tasks = [
1146+
Task.GENERATE_BOT_MESSAGE,
1147+
Task.GENERATE_USER_INTENT,
1148+
Task.GENERATE_NEXT_STEPS,
1149+
Task.GENERATE_INTENT_STEPS_MESSAGE,
1150+
]
11671151

1168-
if has_reasoning_traces and has_dialog_rails:
1169-
raise ValueError(
1170-
"Reasoning traces cannot be enabled when dialog rails are present. "
1171-
"Please either disable reasoning traces or remove dialog rails."
1152+
# dialog rails are activated (explicitly or implicitly)
1153+
has_dialog_rails = (
1154+
bool(dialog_rails)
1155+
or bool(values.get("user_messages"))
1156+
or bool(values.get("bot_messages"))
1157+
or bool(values.get("flows"))
1158+
)
1159+
1160+
if has_dialog_rails:
1161+
# Get the main model if it exists
1162+
main_model = next(
1163+
(model for model in models if model.get("type") == "main"), None
11721164
)
11731165

1166+
violations = []
1167+
1168+
for task in dialog_rail_tasks:
1169+
# Check if there's a dedicated model for this task
1170+
task_model = next(
1171+
(model for model in models if model.get("type") == task.value), None
1172+
)
1173+
1174+
if task_model:
1175+
# Handle both dictionary and Model object cases
1176+
reasoning_config = (
1177+
task_model.reasoning_config
1178+
if hasattr(task_model, "reasoning_config")
1179+
else task_model.get("reasoning_config", {})
1180+
)
1181+
if not reasoning_config.get("remove_thinking_traces", True):
1182+
violations.append(
1183+
f"Model '{task_model.get('type')}' has reasoning traces enabled in config.yml. "
1184+
f"Reasoning traces must be disabled for dialog rail tasks. "
1185+
f"Please update your config.yml to set 'remove_thinking_traces: true' under reasoning_config for this model."
1186+
)
1187+
elif main_model:
1188+
# Handle both dictionary and Model object cases
1189+
reasoning_config = (
1190+
main_model.reasoning_config
1191+
if hasattr(main_model, "reasoning_config")
1192+
else main_model.get("reasoning_config", {})
1193+
)
1194+
if not reasoning_config.get("remove_thinking_traces", True):
1195+
violations.append(
1196+
f"Main model has reasoning traces enabled in config.yml and is being used for dialog rail task '{task.value}'. "
1197+
f"Reasoning traces must be disabled when dialog rails are present. "
1198+
f"Please update your config.yml to set 'remove_thinking_traces: true' under reasoning_config for the main model."
1199+
)
1200+
1201+
if violations:
1202+
raise ValueError("\n".join(violations))
1203+
11741204
return values
11751205

11761206
@root_validator(pre=True, allow_reuse=True)

0 commit comments

Comments
 (0)