@@ -1602,6 +1602,67 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
1602
1602
/* cudnn enabled */ boolFalse);
1603
1603
return success ();
1604
1604
});
1605
+ patterns.onOp (
1606
+ " MeanVarianceNormalization" , 13 ,
1607
+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
1608
+ Torch::ValueTensorType resultType;
1609
+ Value input;
1610
+ SmallVector<int64_t > axes;
1611
+
1612
+ if (binder.tensorOperand (input) ||
1613
+ binder.s64IntegerArrayAttr (axes, " axes" ,
1614
+ llvm::SmallVector<int64_t >({0 , 2 , 3 })) ||
1615
+ binder.tensorResultType (resultType)) {
1616
+ return failure ();
1617
+ }
1618
+ Location loc = binder.getLoc ();
1619
+ Value keepDim = rewriter.create <Torch::ConstantBoolOp>(loc, true );
1620
+ Value unBiased = rewriter.create <Torch::ConstantBoolOp>(loc, false );
1621
+ Value none = rewriter.create <Torch::ConstantNoneOp>(loc);
1622
+
1623
+ ArrayRef<int64_t > input_shape = resultType.getSizes ();
1624
+ SmallVector<int64_t > reduced_shape (input_shape);
1625
+ for (int64_t i : axes) {
1626
+ reduced_shape[i] = 1 ;
1627
+ }
1628
+
1629
+ Torch::ValueTensorType meanOutTy = Torch::ValueTensorType::get (
1630
+ resultType.getContext (), reduced_shape, resultType.getDtype ());
1631
+ SmallVector<Value> cstAxes;
1632
+ for (int64_t i : axes) {
1633
+ cstAxes.push_back (rewriter.create <Torch::ConstantIntOp>(
1634
+ loc, rewriter.getI64IntegerAttr (i)));
1635
+ }
1636
+ Value axes_list = rewriter.create <Torch::PrimListConstructOp>(
1637
+ loc,
1638
+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
1639
+ cstAxes);
1640
+ Value mean = rewriter.create <Torch::AtenMeanDimOp>(
1641
+ loc, meanOutTy, input, axes_list, keepDim, none);
1642
+
1643
+ Value variance = rewriter.create <Torch::AtenVarDimOp>(
1644
+ loc, meanOutTy, input, axes_list, unBiased, keepDim);
1645
+
1646
+ Value cstOne = rewriter.create <Torch::ConstantIntOp>(
1647
+ loc, rewriter.getI64IntegerAttr (1 ));
1648
+ Value cstEps = rewriter.create <Torch::ConstantFloatOp>(
1649
+ loc, rewriter.getF64FloatAttr (1e-9 ));
1650
+ Value positiveVariance = rewriter.create <Torch::AtenAddScalarOp>(
1651
+ loc, meanOutTy, variance, cstEps, cstOne);
1652
+
1653
+ Value sqrt = rewriter.create <Torch::AtenSqrtOp>(loc, meanOutTy,
1654
+ positiveVariance);
1655
+
1656
+ Value subValue = rewriter.create <Torch::AtenSubTensorOp>(
1657
+ loc, resultType, input, mean, cstOne);
1658
+
1659
+ Value meanVarNorm = rewriter.create <Torch::AtenDivTensorOp>(
1660
+ loc, resultType, subValue, sqrt);
1661
+
1662
+ rewriter.replaceOp (binder.op , meanVarNorm);
1663
+
1664
+ return success ();
1665
+ });
1605
1666
patterns.onOp (
1606
1667
" Max" , 1 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
1607
1668
Torch::ValueTensorType resultType;
0 commit comments