Skip to content

Commit 2c1db54

Browse files
Add Support for MatMulInteger (#2072)
* Add Support for MatMulInteger MatMulInteger was supported in ONNX opset v10 (not checked in proposed change, the error can be addressed on save), this specific type combination is support in TensorFlow, but the node type not identified and handled properly here. Handles #2071 Signed-off-by: Gregory Morse <[email protected]> * Update math.py Signed-off-by: Gregory Morse <[email protected]> * Update support_status.md Signed-off-by: Gregory Morse <[email protected]> * Update test_backend.py Signed-off-by: Gregory Morse <[email protected]> Signed-off-by: Gregory Morse <[email protected]> Co-authored-by: Jay Zhang <[email protected]>
1 parent 48e9015 commit 2c1db54

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

support_status.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
| AvgPool3D | 1 ~ 17 |
2828
| BatchMatMul | 1 ~ 17 |
2929
| BatchMatMulV2 | 1 ~ 17 |
30+
| BatchMatMulV3 | 1 ~ 17 |
3031
| BatchToSpaceND | 1 ~ 17 |
3132
| BiasAdd | 1 ~ 17 |
3233
| BiasAddV1 | 1 ~ 17 |

tests/test_backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,6 +1075,15 @@ def func(x, y):
10751075
return tf.identity(x_, name=_TFOUTPUT)
10761076
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_val}, rtol=1e-5)
10771077

1078+
@check_tf_min_version("2.6")
1079+
def test_matmulinteger(self):
1080+
x_val = np.array([1, 2, -3, -4], dtype=np.int8).reshape((2, 2))
1081+
y_val = np.array([1, 2, -3, -4], dtype=np.int8).reshape((2, 2))
1082+
def func(x, y):
1083+
x_ = tf.matmul(x, y, output_type=tf.int32)
1084+
return tf.identity(x_, name=_TFOUTPUT)
1085+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
1086+
10781087
@check_onnxruntime_incompatibility("Sub")
10791088
def test_sub(self):
10801089
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))

tf2onnx/onnx_opset/math.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,13 +363,13 @@ def version_1(cls, ctx, node, **kwargs):
363363
name=op_name, shapes=shapes, dtypes=dtypes)
364364

365365

366-
@tf_op(["MatMul", "BatchMatMul", "BatchMatMulV2"])
366+
@tf_op(["MatMul", "BatchMatMul", "BatchMatMulV2", "BatchMatMulV3"])
367367
class MatMul:
368368
@classmethod
369369
def version_1(cls, ctx, node, **kwargs):
370370
# tensorflow allows transpose and conjugated. If found, insert the required transpose.
371371
# We could use Gemm as well but tensorflow does not pass bias in matmul.
372-
node.type = "MatMul"
372+
if node.type != "MatMulInteger": node.type = "MatMul"
373373

374374
attrs = ["transpose_a", "transpose_b", "adjoint_a", "adjoint_b", "adj_x", "adj_y"]
375375
attrs_val = [node.get_attr(attr) for attr in attrs]
@@ -408,7 +408,19 @@ def version_1(cls, ctx, node, **kwargs):
408408
val = node.get_attr(i)
409409
if val is not None and val.i != 0:
410410
raise ValueError(node.type + " attribute " + i + " is not supported")
411-
411+
@classmethod
412+
def version_10(cls, ctx, node, **kwargs):
413+
if (ctx.get_dtype(node.input[0]) in [onnx_pb.TensorProto.INT8, onnx_pb.TensorProto.UINT8] and
414+
ctx.get_dtype(node.input[1]) in [onnx_pb.TensorProto.INT8, onnx_pb.TensorProto.UINT8] and
415+
ctx.get_dtype(node.output[0]) == onnx_pb.TensorProto.INT32):
416+
node.type = "MatMulInteger"
417+
zpdata_a = np.zeros(1, dtype=utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[0])))
418+
zero_point_node_a = ctx.make_const(utils.make_name("zero_point_a"), zpdata_a)
419+
zpdata_b = np.zeros(1, dtype=utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[1])))
420+
zero_point_node_b = ctx.make_const(utils.make_name("zero_point_b"), zpdata_b)
421+
ctx.replace_inputs(node, [node.input[0], node.input[1],
422+
zero_point_node_a.output[0], zero_point_node_b.output[0]])
423+
cls.version_1(ctx, node, **kwargs)
412424

413425
@tf_op("Erf")
414426
class Erf:

0 commit comments

Comments
 (0)