diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py index c2c881e77..f149ae88d 100644 --- a/tf2onnx/tfonnx.py +++ b/tf2onnx/tfonnx.py @@ -74,7 +74,7 @@ def rewrite_constant_fold(g, ops): func_map = { "Add": np.add, "GreaterEqual": np.greater_equal, - "Cast": np.cast, + "Cast": np.asarray, "ConcatV2": np.concatenate, "Less": np.less, "ListDiff": np.setdiff1d, @@ -107,7 +107,7 @@ def rewrite_constant_fold(g, ops): if op.type == "Cast": dst = op.get_attr_int("to") np_type = tf2onnx.utils.map_onnx_to_numpy_type(dst) - val = np.cast[np_type](*inputs) + val = np.asarray(*inputs, dtype=np_type) elif op.type == "ConcatV2": axis = inputs[-1] values = inputs[:-1]