Skip to content

Commit f78e84e

Browse files
[OpenVINO Backend] support while loop
1 parent 771b001 commit f78e84e

File tree

3 files changed

+107
-7
lines changed

3 files changed

+107
-7
lines changed

keras/src/backend/openvino/core.py

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -668,10 +668,107 @@ def while_loop(
668668
loop_vars,
669669
maximum_iterations=None,
670670
):
671-
raise NotImplementedError(
672-
"`while_loop` is not supported with openvino backend"
671+
def flatten_structure(data):
672+
if isinstance(data, dict):
673+
return [v for k in sorted(data) for v in flatten_structure(data[k])]
674+
elif isinstance(data, (tuple, list)):
675+
return [v for item in data for v in flatten_structure(item)]
676+
else:
677+
return [data]
678+
679+
def pack_structure(template, flat):
680+
if isinstance(template, dict):
681+
keys = sorted(template)
682+
packed = {}
683+
for k in keys:
684+
value, flat = pack_structure(template[k], flat)
685+
packed[k] = value
686+
return packed, flat
687+
elif isinstance(template, (tuple, list)):
688+
packed = []
689+
for item in template:
690+
value, flat = pack_structure(item, flat)
691+
packed.append(value)
692+
return (
693+
tuple(packed) if isinstance(template, tuple) else packed
694+
), flat
695+
else:
696+
return flat[0], flat[1:]
697+
698+
is_scalar_input = _is_scalar(loop_vars)
699+
700+
if is_scalar_input:
701+
loop_vars = (loop_vars,)
702+
elif isinstance(loop_vars, (list, np.ndarray)):
703+
loop_vars = tuple(loop_vars)
704+
else:
705+
assert isinstance(loop_vars, (tuple, dict)), (
706+
f"Unsupported type {type(loop_vars)} for loop_vars"
707+
)
708+
709+
flat_loop_vars = flatten_structure(loop_vars)
710+
loop_vars_ov = [get_ov_output(var) for var in flat_loop_vars]
711+
712+
maximum_iterations = (
713+
ov_opset.constant(-1, Type.i32).output(0)
714+
if maximum_iterations is None
715+
else get_ov_output(maximum_iterations)
673716
)
674717

718+
trip_count = maximum_iterations
719+
execution_condition = ov_opset.constant(True, Type.boolean).output(0)
720+
loop = ov_opset.loop(trip_count, execution_condition)
721+
722+
shapes = [var.get_partial_shape() for var in loop_vars_ov]
723+
types = [var.get_element_type() for var in loop_vars_ov]
724+
params = [
725+
ov_opset.parameter(shape, dtype) for shape, dtype in zip(shapes, types)
726+
]
727+
param_tensors = [OpenVINOKerasTensor(p.output(0)) for p in params]
728+
729+
packed_args, _ = pack_structure(loop_vars, param_tensors)
730+
if isinstance(packed_args, dict):
731+
body_out = body(packed_args)
732+
else:
733+
body_out = body(*packed_args)
734+
735+
if not isinstance(body_out, (list, tuple, dict)):
736+
body_out = (body_out,)
737+
738+
flat_body_out = flatten_structure(body_out)
739+
if isinstance(packed_args, dict):
740+
cond_output = get_ov_output(cond(body_out))
741+
else:
742+
cond_output = get_ov_output(cond(*body_out))
743+
744+
if len(cond_output.get_partial_shape()) != 0:
745+
raise ValueError(
746+
"`cond` function must return a scalar boolean value, "
747+
"but got shape {}".format(cond_output.get_partial_shape())
748+
)
749+
750+
results = [cond_output] + [get_ov_output(x) for x in flat_body_out]
751+
body_func = Model(results=results, parameters=params)
752+
loop.set_function(body_func)
753+
loop.set_special_body_ports([0, 0])
754+
755+
for param, init_val, next_val in zip(params, loop_vars_ov, flat_body_out):
756+
loop.set_merged_input(param, init_val, get_ov_output(next_val))
757+
758+
outputs_flat = [
759+
OpenVINOKerasTensor(loop.get_iter_value(get_ov_output(val)))
760+
for val in flat_body_out
761+
]
762+
final_output, _ = pack_structure(loop_vars, outputs_flat)
763+
764+
if is_scalar_input:
765+
if isinstance(final_output, tuple):
766+
return final_output[0]
767+
else:
768+
return final_output
769+
else:
770+
return final_output
771+
675772

676773
def fori_loop(lower, upper, body_fun, init_val):
677774
raise NotImplementedError(

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,6 @@ CoreOpsCallsTests::test_slice_update_basic_call
166166
CoreOpsCallsTests::test_slice_with_non_symbolic_tensors
167167
CoreOpsCallsTests::test_switch_basic_call
168168
CoreOpsCallsTests::test_unstack_basic_functionality
169-
CoreOpsCallsTests::test_while_loop_basic_functionality
170-
CoreOpsCallsTests::test_while_loop_with_max_iterations
171169
CoreOpsCorrectnessTest::test_associative_scan
172170
CoreOpsCorrectnessTest::test_cond
173171
CoreOpsCorrectnessTest::test_dynamic_slice
@@ -180,7 +178,6 @@ CoreOpsCorrectnessTest::test_slice_update
180178
CoreOpsCorrectnessTest::test_switch
181179
CoreOpsCorrectnessTest::test_unstack
182180
CoreOpsCorrectnessTest::test_vectorized_map
183-
CoreOpsCorrectnessTest::test_while_loop
184181
CoreOpsDtypeTest::test_convert_to_tensor0
185182
CoreOpsDtypeTest::test_convert_to_tensor1
186183
CoreOpsDtypeTest::test_convert_to_tensor2

keras/src/ops/core_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,10 @@ def body(i):
11131113
# Initial loop variable (i = 0)
11141114
loop_vars = (0,)
11151115
result = while_loop.call(loop_vars)
1116-
self.assertEqual(result[0], 5)
1116+
if backend.backend() == "openvino":
1117+
self.assertEqual(ops.convert_to_numpy(result[0]), 5)
1118+
else:
1119+
self.assertEqual(result[0], 5)
11171120

11181121
def test_while_loop_output_spec(self):
11191122
# Define dummy cond and body functions
@@ -1139,7 +1142,10 @@ def body(i):
11391142

11401143
while_loop = core.WhileLoop(cond, body, maximum_iterations=5)
11411144
result = while_loop.call((0,))
1142-
self.assertEqual(result[0], 5)
1145+
if backend.backend() == "openvino":
1146+
self.assertEqual(ops.convert_to_numpy(result[0]), 5)
1147+
else:
1148+
self.assertEqual(result[0], 5)
11431149

11441150
def test_whileloop_compute_output_spec(self):
11451151
# Define loop variables with different shapes and data types

0 commit comments

Comments
 (0)