@@ -668,10 +668,107 @@ def while_loop(
668
668
loop_vars ,
669
669
maximum_iterations = None ,
670
670
):
671
- raise NotImplementedError (
672
- "`while_loop` is not supported with openvino backend"
671
+ def flatten_structure (data ):
672
+ if isinstance (data , dict ):
673
+ return [v for k in sorted (data ) for v in flatten_structure (data [k ])]
674
+ elif isinstance (data , (tuple , list )):
675
+ return [v for item in data for v in flatten_structure (item )]
676
+ else :
677
+ return [data ]
678
+
679
+ def pack_structure (template , flat ):
680
+ if isinstance (template , dict ):
681
+ keys = sorted (template )
682
+ packed = {}
683
+ for k in keys :
684
+ value , flat = pack_structure (template [k ], flat )
685
+ packed [k ] = value
686
+ return packed , flat
687
+ elif isinstance (template , (tuple , list )):
688
+ packed = []
689
+ for item in template :
690
+ value , flat = pack_structure (item , flat )
691
+ packed .append (value )
692
+ return (
693
+ tuple (packed ) if isinstance (template , tuple ) else packed
694
+ ), flat
695
+ else :
696
+ return flat [0 ], flat [1 :]
697
+
698
+ is_scalar_input = _is_scalar (loop_vars )
699
+
700
+ if is_scalar_input :
701
+ loop_vars = (loop_vars ,)
702
+ elif isinstance (loop_vars , (list , np .ndarray )):
703
+ loop_vars = tuple (loop_vars )
704
+ else :
705
+ assert isinstance (loop_vars , (tuple , dict )), (
706
+ f"Unsupported type { type (loop_vars )} for loop_vars"
707
+ )
708
+
709
+ flat_loop_vars = flatten_structure (loop_vars )
710
+ loop_vars_ov = [get_ov_output (var ) for var in flat_loop_vars ]
711
+
712
+ maximum_iterations = (
713
+ ov_opset .constant (- 1 , Type .i32 ).output (0 )
714
+ if maximum_iterations is None
715
+ else get_ov_output (maximum_iterations )
673
716
)
674
717
718
+ trip_count = maximum_iterations
719
+ execution_condition = ov_opset .constant (True , Type .boolean ).output (0 )
720
+ loop = ov_opset .loop (trip_count , execution_condition )
721
+
722
+ shapes = [var .get_partial_shape () for var in loop_vars_ov ]
723
+ types = [var .get_element_type () for var in loop_vars_ov ]
724
+ params = [
725
+ ov_opset .parameter (shape , dtype ) for shape , dtype in zip (shapes , types )
726
+ ]
727
+ param_tensors = [OpenVINOKerasTensor (p .output (0 )) for p in params ]
728
+
729
+ packed_args , _ = pack_structure (loop_vars , param_tensors )
730
+ if isinstance (packed_args , dict ):
731
+ body_out = body (packed_args )
732
+ else :
733
+ body_out = body (* packed_args )
734
+
735
+ if not isinstance (body_out , (list , tuple , dict )):
736
+ body_out = (body_out ,)
737
+
738
+ flat_body_out = flatten_structure (body_out )
739
+ if isinstance (packed_args , dict ):
740
+ cond_output = get_ov_output (cond (body_out ))
741
+ else :
742
+ cond_output = get_ov_output (cond (* body_out ))
743
+
744
+ if len (cond_output .get_partial_shape ()) != 0 :
745
+ raise ValueError (
746
+ "`cond` function must return a scalar boolean value, "
747
+ "but got shape {}" .format (cond_output .get_partial_shape ())
748
+ )
749
+
750
+ results = [cond_output ] + [get_ov_output (x ) for x in flat_body_out ]
751
+ body_func = Model (results = results , parameters = params )
752
+ loop .set_function (body_func )
753
+ loop .set_special_body_ports ([0 , 0 ])
754
+
755
+ for param , init_val , next_val in zip (params , loop_vars_ov , flat_body_out ):
756
+ loop .set_merged_input (param , init_val , get_ov_output (next_val ))
757
+
758
+ outputs_flat = [
759
+ OpenVINOKerasTensor (loop .get_iter_value (get_ov_output (val )))
760
+ for val in flat_body_out
761
+ ]
762
+ final_output , _ = pack_structure (loop_vars , outputs_flat )
763
+
764
+ if is_scalar_input :
765
+ if isinstance (final_output , tuple ):
766
+ return final_output [0 ]
767
+ else :
768
+ return final_output
769
+ else :
770
+ return final_output
771
+
675
772
676
773
def fori_loop (lower , upper , body_fun , init_val ):
677
774
raise NotImplementedError (
0 commit comments