
Commit 0c29ccf

[ONNX] Add per-channel quantization support for QuantizeLinear op (#4092)
1 parent 21e6e12 commit 0c29ccf
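
For context (not part of the commit): the ONNX QuantizeLinear op computes y = saturate(round(x / y_scale) + y_zero_point). In per-axis (per-channel) mode, y_scale and y_zero_point are 1-D tensors whose length matches the input size along the `axis` attribute, so each channel gets its own scale and zero point. The C++ sketch below is a minimal standalone illustration of that formula for an si8 output; the function name, flattened-index layout, and the assumptions of round-to-nearest-even and int8 saturation are illustrative, not code from this patch.

// Minimal standalone sketch (not from the commit) of per-channel
// QuantizeLinear semantics for an si8 output: each element is divided by the
// scale of its channel along `axis`, rounded, shifted by that channel's
// zero point, and saturated to the int8 range.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int8_t> quantizePerChannel(const std::vector<float> &input,
                                       const std::vector<int64_t> &shape,
                                       const std::vector<float> &scale,
                                       const std::vector<int8_t> &zeroPoint,
                                       int64_t axis) {
  // Number of elements covered by one step along `axis` (row-major layout).
  int64_t inner = 1;
  for (size_t d = static_cast<size_t>(axis) + 1; d < shape.size(); ++d)
    inner *= shape[d];
  std::vector<int8_t> out(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    // Coordinate of element i along the quantization axis.
    int64_t channel = (i / inner) % shape[axis];
    // round(x / scale[c]) + zero_point[c], saturated to [-128, 127].
    float q = std::nearbyint(input[i] / scale[channel]) + zeroPoint[channel];
    out[i] = static_cast<int8_t>(std::clamp(q, -128.0f, 127.0f));
  }
  return out;
}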

File tree

2 files changed: +108 -33 lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 60 additions & 33 deletions
@@ -269,13 +269,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 
         auto resultETy = resultType.getDtype();
 
-        bool rank0 = scaleTy.getSizes().size() == 0;
-        bool length1 =
-            scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1;
-
-        if (!rank0 && !length1)
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "unimplemented: non-scalar scale");
+        int64_t scaleRank = scaleTy.getSizes().size();
+        if (scaleRank > 1)
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: only per-tensor or per-axis "
+                         "quantization supported");
 
         auto qTensorTy = getQTorchTypeFromTorchIntType(resultType);
         if (!qTensorTy) {
@@ -290,37 +288,66 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                     static_cast<int64_t>(torchqTy)));
 
-        scale = rewriter.create<Torch::AtenItemOp>(
-            loc, rewriter.getType<Torch::FloatType>(), scale);
-
         bool fpResult = isa<mlir::FloatType>(resultETy);
-        Type zeropointTy = rewriter.getType<Torch::IntType>();
-        if (fpResult)
-          zeropointTy = rewriter.getType<Torch::FloatType>();
-        zeropoint =
-            rewriter.create<Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
-
-        if (fpResult) {
-          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
-          Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
-          Value one = rewriter.create<Torch::ConstantFloatOp>(
-              loc, rewriter.getF64FloatAttr(1.0));
-          Value div = rewriter.create<Torch::AtenDivScalarOp>(
-              loc, operand.getType(), operand, scale);
-          Value add = rewriter.create<Torch::AtenAddScalarOp>(
-              loc, operand.getType(), div, zeropoint, one);
+        bool isPerTensorQuantization = false;
+        if (scaleRank == 0 ||
+            llvm::all_of(scaleTy.getSizes(), [](int64_t s) { return s == 1; }))
+          isPerTensorQuantization = true;
 
-          rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
-              binder.op, resultType, add, tyConst,
-              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
-              /*memory_format=*/none);
+        // (TODO) Case: Per-Channel Quantization for floating point output.
+        if (scaleRank == 1 && fpResult)
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: support for per-Channel Quantization "
+                         "for floating point output.");
+
+        if (isPerTensorQuantization) {
+          scale = rewriter.create<Torch::AtenItemOp>(
+              loc, rewriter.getType<Torch::FloatType>(), scale);
+
+          Type zeropointTy = rewriter.getType<Torch::IntType>();
+          if (fpResult)
+            zeropointTy = rewriter.getType<Torch::FloatType>();
+          zeropoint =
+              rewriter.create<Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
+        }
+
+        if (!fpResult) {
+          Value quantize;
+          // Case 1: Per-Tensor Quantization for non-floating point input.
+          if (isPerTensorQuantization) {
+            quantize = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+                loc, qTensorTy, operand, scale, zeropoint, tyConst);
+          } else {
+            // Case 2: Per-Channel Quantization for non-floating point input.
+            int64_t axis;
+            if (binder.s64IntegerAttr(axis, "axis", 1))
+              return failure();
+
+            Value cstAxis = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(axis));
+            quantize = rewriter.create<Torch::AtenQuantizePerChannelOp>(
+                loc, qTensorTy, operand, scale, zeropoint, cstAxis, tyConst);
+          }
+          rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(
+              binder.op, resultType, quantize);
           return success();
         }
 
-        auto quantize = rewriter.create<Torch::AtenQuantizePerTensorOp>(
-            loc, qTensorTy, operand, scale, zeropoint, tyConst);
-        rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
-                                                          quantize);
+        // Case 3: Per-Tensor Quantization for floating point input.
+        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+        Value one = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(1.0));
+        Value div = rewriter.create<Torch::AtenDivScalarOp>(
+            loc, operand.getType(), operand, scale);
+        Value add = rewriter.create<Torch::AtenAddScalarOp>(
+            loc, operand.getType(), div, zeropoint, one);
+
+        rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+            binder.op, resultType, add, tyConst,
+            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+            /*memory_format=*/none);
+
         return success();
       });
   patterns.onOp(
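
Read as control flow, the updated conversion now distinguishes: per-tensor quantization when the scale tensor is rank 0 or every dimension of its shape is 1, per-channel quantization driven by the ONNX `axis` attribute (default 1) for integer outputs, and the existing divide/add/to-dtype path for floating-point outputs; per-channel quantization with a floating-point output is still rejected as a TODO. The standalone C++ sketch below only summarizes that dispatch; the enum and helper function are hypothetical and are not torch-mlir API.

// Hypothetical summary of the dispatch added above. `scaleSizes` stands in
// for scaleTy.getSizes() and `fpResult` for a floating-point result dtype.
#include <algorithm>
#include <cstdint>
#include <vector>

enum class Lowering { PerTensorInt, PerChannelInt, PerTensorFloat, NotSupported };

Lowering chooseLowering(const std::vector<int64_t> &scaleSizes, bool fpResult) {
  int64_t scaleRank = scaleSizes.size();
  if (scaleRank > 1)
    return Lowering::NotSupported; // only per-tensor or per-axis scales
  bool perTensor =
      scaleRank == 0 || std::all_of(scaleSizes.begin(), scaleSizes.end(),
                                    [](int64_t s) { return s == 1; });
  if (scaleRank == 1 && fpResult)
    return Lowering::NotSupported; // per-channel + float output: TODO
  if (!fpResult)
    return perTensor ? Lowering::PerTensorInt    // -> aten.quantize_per_tensor
                     : Lowering::PerChannelInt;  // -> aten.quantize_per_channel
  return Lowering::PerTensorFloat;               // -> div + add + to_dtype
}

In the integer per-channel case, the lowering materializes the axis as a torch.constant.int and emits torch.aten.quantize_per_channel followed by torch.aten.int_repr, which is exactly what the new tests below check for.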

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 48 additions & 0 deletions
@@ -64,6 +64,54 @@ func.func @test_quantizelinear_f8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.
 
 // -----
 
+// CHECK-LABEL: @test_quantizelinear_per_channel_si8
+func.func @test_quantizelinear_per_channel_si8(%arg0: !torch.vtensor<[4,3,7,7],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],si8>) -> !torch.vtensor<[4,3,7,7],si8> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+  // CHECK: %[[DTYPE:.+]] = torch.constant.int 12
+  // CHECK: %[[AXIS:.+]] = torch.constant.int 1
+  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_channel %arg0, %arg1, %arg2, %[[AXIS]], %[[DTYPE]]
+  // CHECK: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]]
+  %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[4,3,7,7],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],si8>) -> !torch.vtensor<[4,3,7,7],si8>
+  return %0: !torch.vtensor<[4,3,7,7],si8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_quantizelinear_per_channel_ui8
+func.func @test_quantizelinear_per_channel_ui8(%arg0: !torch.vtensor<[4,3,7,7],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],ui8>) -> !torch.vtensor<[4,3,7,7],ui8> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+  // CHECK: %[[DTYPE:.+]] = torch.constant.int 13
+  // CHECK: %[[AXIS:.+]] = torch.constant.int 1
+  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_channel %arg0, %arg1, %arg2, %[[AXIS]], %[[DTYPE]]
+  // CHECK: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]]
+  %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[4,3,7,7],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],ui8>) -> !torch.vtensor<[4,3,7,7],ui8>
+  return %0: !torch.vtensor<[4,3,7,7],ui8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_quantizelinear_per_channel_si16
+func.func @test_quantizelinear_per_channel_si16(%arg0: !torch.vtensor<[4,3,7,7],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],si16>) -> !torch.vtensor<[4,3,7,7],si16> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+  // CHECK: %[[DTYPE:.+]] = torch.constant.int 27
+  // CHECK: %[[AXIS:.+]] = torch.constant.int 1
+  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_channel %arg0, %arg1, %arg2, %[[AXIS]], %[[DTYPE]]
+  // CHECK: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]]
+  %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[4,3,7,7],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],si16>) -> !torch.vtensor<[4,3,7,7],si16>
+  return %0: !torch.vtensor<[4,3,7,7],si16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_quantizelinear_per_channel_si32
+func.func @test_quantizelinear_per_channel_si32(%arg0: !torch.vtensor<[4,3,7,7],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],si32>) -> !torch.vtensor<[4,3,7,7],si32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+  // CHECK: %[[DTYPE:.+]] = torch.constant.int 14
+  // CHECK: %[[AXIS:.+]] = torch.constant.int 1
+  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_channel %arg0, %arg1, %arg2, %[[AXIS]], %[[DTYPE]]
+  // CHECK: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]]
+  %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[4,3,7,7],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],si32>) -> !torch.vtensor<[4,3,7,7],si32>
+  return %0: !torch.vtensor<[4,3,7,7],si32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_qlinearconv_nobias
 func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8>
