Skip to content

Commit 20177c3

Browse files
committed
Lower to torch dialect without expansion
Signed-off-by: Zahid Wakeel <[email protected]>
1 parent 716303a commit 20177c3

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,6 +1602,67 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
16021602
/* cudnn enabled */ boolFalse);
16031603
return success();
16041604
});
1605+
patterns.onOp(
1606+
"MeanVarianceNormalization", 13,
1607+
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
1608+
Torch::ValueTensorType resultType;
1609+
Value input;
1610+
SmallVector<int64_t> axes;
1611+
1612+
if (binder.tensorOperand(input) ||
1613+
binder.s64IntegerArrayAttr(axes, "axes",
1614+
llvm::SmallVector<int64_t>({0, 2, 3})) ||
1615+
binder.tensorResultType(resultType)) {
1616+
return failure();
1617+
}
1618+
Location loc = binder.getLoc();
1619+
Value keepDim = rewriter.create<Torch::ConstantBoolOp>(loc, true);
1620+
Value unBiased = rewriter.create<Torch::ConstantBoolOp>(loc, false);
1621+
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
1622+
1623+
ArrayRef<int64_t> input_shape = resultType.getSizes();
1624+
SmallVector<int64_t> reduced_shape(input_shape);
1625+
for (int64_t i : axes) {
1626+
reduced_shape[i] = 1;
1627+
}
1628+
1629+
Torch::ValueTensorType meanOutTy = Torch::ValueTensorType::get(
1630+
resultType.getContext(), reduced_shape, resultType.getDtype());
1631+
SmallVector<Value> cstAxes;
1632+
for (int64_t i : axes) {
1633+
cstAxes.push_back(rewriter.create<Torch::ConstantIntOp>(
1634+
loc, rewriter.getI64IntegerAttr(i)));
1635+
}
1636+
Value axes_list = rewriter.create<Torch::PrimListConstructOp>(
1637+
loc,
1638+
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
1639+
cstAxes);
1640+
Value mean = rewriter.create<Torch::AtenMeanDimOp>(
1641+
loc, meanOutTy, input, axes_list, keepDim, none);
1642+
1643+
Value variance = rewriter.create<Torch::AtenVarDimOp>(
1644+
loc, meanOutTy, input, axes_list, unBiased, keepDim);
1645+
1646+
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
1647+
loc, rewriter.getI64IntegerAttr(1));
1648+
Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
1649+
loc, rewriter.getF64FloatAttr(1e-9));
1650+
Value positiveVariance = rewriter.create<Torch::AtenAddScalarOp>(
1651+
loc, meanOutTy, variance, cstEps, cstOne);
1652+
1653+
Value sqrt = rewriter.create<Torch::AtenSqrtOp>(loc, meanOutTy,
1654+
positiveVariance);
1655+
1656+
Value subValue = rewriter.create<Torch::AtenSubTensorOp>(
1657+
loc, resultType, input, mean, cstOne);
1658+
1659+
Value meanVarNorm = rewriter.create<Torch::AtenDivTensorOp>(
1660+
loc, resultType, subValue, sqrt);
1661+
1662+
rewriter.replaceOp(binder.op, meanVarNorm);
1663+
1664+
return success();
1665+
});
16051666
patterns.onOp(
16061667
"Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
16071668
Torch::ValueTensorType resultType;

0 commit comments

Comments
 (0)