|
36 | 36 | from nemoguardrails.colang import parse_colang_file, parse_flow_elements
|
37 | 37 | from nemoguardrails.colang.v2_x.lang.utils import format_colang_parsing_error_message
|
38 | 38 | from nemoguardrails.colang.v2_x.runtime.errors import ColangParsingError
|
| 39 | +from nemoguardrails.llm.types import Task |
39 | 40 |
|
40 | 41 | log = logging.getLogger(__name__)
|
41 | 42 |
|
@@ -1140,37 +1141,66 @@ def check_reasoning_traces_with_dialog_rails(cls, values):
|
1140 | 1141 | rails = values.get("rails", {})
|
1141 | 1142 | dialog_rails = rails.get("dialog", {})
|
1142 | 1143 |
|
1143 |
| - # check if any model has reasoning traces enabled |
1144 |
| - # TODO: we must check for models that are used in a specific dialog task |
1145 |
| - has_reasoning_traces = False |
1146 |
| - for model in models: |
1147 |
| - if isinstance(model, dict): |
1148 |
| - reasoning_config = model.get("reasoning_config", {}) |
1149 |
| - if not reasoning_config.get("remove_thinking_traces", True): |
1150 |
| - has_reasoning_traces = True |
1151 |
| - break |
1152 |
| - elif hasattr(model, "reasoning_config"): |
1153 |
| - if not model.reasoning_config.remove_thinking_traces: |
1154 |
| - has_reasoning_traces = True |
1155 |
| - break |
1156 |
| - |
1157 |
| - # check if dialog rails are present (explicitly or implicitly) |
1158 |
| - has_dialog_rails = bool(dialog_rails) |
1159 |
| - |
1160 |
| - # check implicit dialog rails through user messages, bot messages, or flows |
1161 |
| - if not has_dialog_rails: |
1162 |
| - has_dialog_rails = ( |
1163 |
| - bool(values.get("user_messages")) |
1164 |
| - or bool(values.get("bot_messages")) |
1165 |
| - or bool(values.get("flows")) |
1166 |
| - ) |
| 1144 | + # dialog rail tasks that should not have reasoning traces |
| 1145 | + dialog_rail_tasks = [ |
| 1146 | + Task.GENERATE_BOT_MESSAGE, |
| 1147 | + Task.GENERATE_USER_INTENT, |
| 1148 | + Task.GENERATE_NEXT_STEPS, |
| 1149 | + Task.GENERATE_INTENT_STEPS_MESSAGE, |
| 1150 | + ] |
1167 | 1151 |
|
1168 |
| - if has_reasoning_traces and has_dialog_rails: |
1169 |
| - raise ValueError( |
1170 |
| - "Reasoning traces cannot be enabled when dialog rails are present. " |
1171 |
| - "Please either disable reasoning traces or remove dialog rails." |
| 1152 | + # dialog rails are activated (explicitly or implicitly) |
| 1153 | + has_dialog_rails = ( |
| 1154 | + bool(dialog_rails) |
| 1155 | + or bool(values.get("user_messages")) |
| 1156 | + or bool(values.get("bot_messages")) |
| 1157 | + or bool(values.get("flows")) |
| 1158 | + ) |
| 1159 | + |
| 1160 | + if has_dialog_rails: |
| 1161 | + # Get the main model if it exists |
| 1162 | + main_model = next( |
| 1163 | + (model for model in models if model.get("type") == "main"), None |
1172 | 1164 | )
|
1173 | 1165 |
|
| 1166 | + violations = [] |
| 1167 | + |
| 1168 | + for task in dialog_rail_tasks: |
| 1169 | + # Check if there's a dedicated model for this task |
| 1170 | + task_model = next( |
| 1171 | + (model for model in models if model.get("type") == task.value), None |
| 1172 | + ) |
| 1173 | + |
| 1174 | + if task_model: |
| 1175 | + # Handle both dictionary and Model object cases |
| 1176 | + reasoning_config = ( |
| 1177 | + task_model.reasoning_config |
| 1178 | + if hasattr(task_model, "reasoning_config") |
| 1179 | + else task_model.get("reasoning_config", {}) |
| 1180 | + ) |
| 1181 | + if not reasoning_config.get("remove_thinking_traces", True): |
| 1182 | + violations.append( |
| 1183 | + f"Model '{task_model.get('type')}' has reasoning traces enabled in config.yml. " |
| 1184 | + f"Reasoning traces must be disabled for dialog rail tasks. " |
| 1185 | + f"Please update your config.yml to set 'remove_thinking_traces: true' under reasoning_config for this model." |
| 1186 | + ) |
| 1187 | + elif main_model: |
| 1188 | + # Handle both dictionary and Model object cases |
| 1189 | + reasoning_config = ( |
| 1190 | + main_model.reasoning_config |
| 1191 | + if hasattr(main_model, "reasoning_config") |
| 1192 | + else main_model.get("reasoning_config", {}) |
| 1193 | + ) |
| 1194 | + if not reasoning_config.get("remove_thinking_traces", True): |
| 1195 | + violations.append( |
| 1196 | + f"Main model has reasoning traces enabled in config.yml and is being used for dialog rail task '{task.value}'. " |
| 1197 | + f"Reasoning traces must be disabled when dialog rails are present. " |
| 1198 | + f"Please update your config.yml to set 'remove_thinking_traces: true' under reasoning_config for the main model." |
| 1199 | + ) |
| 1200 | + |
| 1201 | + if violations: |
| 1202 | + raise ValueError("\n".join(violations)) |
| 1203 | + |
1174 | 1204 | return values
|
1175 | 1205 |
|
1176 | 1206 | @root_validator(pre=True, allow_reuse=True)
|
|
0 commit comments