diff --git a/nemoguardrails/colang/v2_x/runtime/eval.py b/nemoguardrails/colang/v2_x/runtime/eval.py index 4c5998e8c..1d76266f7 100644 --- a/nemoguardrails/colang/v2_x/runtime/eval.py +++ b/nemoguardrails/colang/v2_x/runtime/eval.py @@ -40,20 +40,11 @@ class ComparisonExpression: """An expression to compare to values.""" def __init__(self, operator: Callable[[Any], bool], value: Any) -> None: - if not isinstance(value, (int, float)): - raise ColangValueError( - f"Comparison operators don't support values of type '{type(value)}'" - ) self.value = value self.operator = operator def compare(self, value: Any) -> bool: """Compare given value with the expression's value.""" - if not isinstance(value, type(self.value)): - raise ColangValueError( - "Comparing variables of different types is not supported!" - ) - return self.operator(value) @@ -168,6 +159,7 @@ def eval_expression(expr: str, context: dict) -> Any: "greater_than": _greater_than_operator, "equal_greater_than": _equal_or_greater_than_operator, "not_equal_to": _not_equal_to_operator, + "is_in": _is_in, "list": list, } ) @@ -286,6 +278,11 @@ def _not_equal_to_operator(v_ref: Any) -> ComparisonExpression: return ComparisonExpression(lambda val, val_ref=v_ref: val != val_ref, v_ref) +def _is_in(v_ref: Any) -> ComparisonExpression: + """Create a is in comparison expression.""" + return ComparisonExpression(lambda val, val_ref=v_ref: val in val_ref, v_ref) + + def _flows_info(state: State, flow_instance_uid: Optional[str] = None) -> dict: """Return a summary of the provided state, or all states by default.""" if flow_instance_uid is not None and flow_instance_uid in state.flow_states: