
Commit 8d8c30b (1 parent: 712b20f)

addressed the comments

4 files changed: +21 −23 lines

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 8 additions & 9 deletions

@@ -10561,6 +10561,7 @@ namespace {
 //           | input,                  if target == 1
 // loss(x) = |
 //           | max(0, margin - input), if target == -1
+// target tensor may have values other than 1 and -1
 class DecomposeHingeEmbeddingLoss
     : public OpRewritePattern<AtenHingeEmbeddingLossOp> {
   using OpRewritePattern<AtenHingeEmbeddingLossOp>::OpRewritePattern;
@@ -10577,6 +10578,12 @@ class DecomposeHingeEmbeddingLoss
     auto targetTy = dyn_cast<ValueTensorType>(target.getType());
     if (!targetTy.hasDtype() || !targetTy.hasSizes())
       return rewriter.notifyMatchFailure(op, "target must have dtype and size");
+
+    int64_t reduction;
+    if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "reduction should be a constant int!");
+    }
     auto resultTy = dyn_cast<ValueTensorType>(op.getType());
     Value minusOne = getConstantWithGivenDtypeAndValue(rewriter, loc, -1,
                                                        targetTy.getDtype());
@@ -10613,22 +10620,14 @@ class DecomposeHingeEmbeddingLoss
     // Add : outputMargin + outputSelf
     auto output = rewriter.create<AtenAddTensorOp>(loc, inputTy, outputMargin,
                                                    outputSelf, /*alpha=*/alpha);
-    int64_t reduction;
-    if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction))) {
-      return rewriter.notifyMatchFailure(op,
-                                         "reduction should be a constant int!");
-    }
-    Value loss;
+    Value loss = output;
     Value none = rewriter.create<ConstantNoneOp>(loc);
     // reduction: mean
     if (reduction == 1) {
       loss = rewriter.create<AtenMeanOp>(loc, resultTy, output, none);
     } else if (reduction == 2) {
       // reduction: sum
       loss = rewriter.create<AtenSumOp>(loc, resultTy, output, none);
-    } else {
-      // reduction: none
-      loss = output;
     }
     rewriter.replaceOp(op, loss);
     return success();
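For context, the computation this pattern emits follows the eager semantics sketched below; this is a minimal PyTorch reference (the helper name and test shapes are illustrative, not part of the patch), using PyTorch's reduction encoding of 0 = none, 1 = mean, 2 = sum:

import torch

def hinge_embedding_loss_ref(input, target, margin=1.0, reduction=1):
    # | input,                  if target == 1
    # | max(0, margin - input), if target == -1
    # Positions where target is neither 1 nor -1 contribute 0.
    output_self = torch.where(target == 1, input, torch.zeros_like(input))
    output_margin = torch.where(
        target == -1, (margin - input).clamp(min=0), torch.zeros_like(input)
    )
    output = output_margin + output_self
    if reduction == 1:      # mean
        return output.mean()
    if reduction == 2:      # sum
        return output.sum()
    return output           # none

input = torch.randn(8, 5)
target = torch.where(torch.rand(8, 5) > 0.5, torch.tensor(1.0), torch.tensor(-1.0))
assert torch.allclose(
    hinge_embedding_loss_ref(input, target),
    torch.ops.aten.hinge_embedding_loss(input, target),
)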

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 3 deletions

@@ -1797,7 +1797,7 @@
     "L1LossSumReductionModule_basic",
     "HingeEmbeddingLossReductionMeanModule_basic",
     "HingeEmbeddingLossReductionSumModule_basic",
-    "HingeEmbeddingLossWithoutReductionModule_basic",
+    "HingeEmbeddingLossReductionNoneModule_basic",
     "PixelShuffleModuleStaticRank3Int64_basic",
     "PixelShuffleModuleStaticRank4Float32_basic",
     "RandIntLowModule_basic",
@@ -2974,7 +2974,7 @@
     "HingeEmbeddingLossBasicModule_basic",
     "HingeEmbeddingLossReductionMeanModule_basic",
     "HingeEmbeddingLossReductionSumModule_basic",
-    "HingeEmbeddingLossWithoutReductionModule_basic",
+    "HingeEmbeddingLossReductionNoneModule_basic",
     "HstackBasicComplexModule_basic",
     "HstackBasicFloatModule_basic",
     "HstackBasicIntFloatModule_basic",
@@ -3988,7 +3988,7 @@
     "HingeEmbeddingLossBasicModule_basic",
     "HingeEmbeddingLossReductionMeanModule_basic",
     "HingeEmbeddingLossReductionSumModule_basic",
-    "HingeEmbeddingLossWithoutReductionModule_basic",
+    "HingeEmbeddingLossReductionNoneModule_basic",
     "Exp2StaticModule_basic",
     "ElementwiseRreluWithNoiseEvalModule_basic",
     "ElementwiseRreluWithNoiseEvalStaticModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 1 addition & 2 deletions

@@ -2189,8 +2189,7 @@ def aten〇nll_loss_backward〡shape(grad_output: List[int], self: List[int], ta
 def aten〇hinge_embedding_loss〡shape(self: List[int], target: List[int], margin: float = 1., reduction: int = 1) -> List[int]:
     if reduction in [1,2]:
         return []
-    else:
-        return upstream_shape_functions.unary(self)
+    return upstream_shape_functions.unary(self)

 # TODO: upstream this
 def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int = 1) -> List[int]:
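The simplification keeps the same behavior: mean and sum produce a rank-0 result ([]), anything else keeps the input shape. A quick eager-mode check of those shapes (the tensor sizes here are illustrative):

import torch

x = torch.randn(8, 5)
y = torch.ones(8, 5)
# reduction = 0 (none): output keeps the input shape -> unary(self)
assert torch.ops.aten.hinge_embedding_loss(x, y, 1.0, 0).shape == torch.Size([8, 5])
# reduction = 1 (mean) or 2 (sum): output is a scalar -> []
assert torch.ops.aten.hinge_embedding_loss(x, y, 1.0, 1).shape == torch.Size([])
assert torch.ops.aten.hinge_embedding_loss(x, y, 1.0, 2).shape == torch.Size([])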

projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

Lines changed: 9 additions & 9 deletions

@@ -2486,8 +2486,8 @@ def __init__(self):
     @annotate_args(
         [
             None,
-            ([8, 1], torch.float32, True),
-            ([1, 1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
         ]
     )
     def forward(self, input, target):
@@ -2507,8 +2507,8 @@ def __init__(self):
     @annotate_args(
         [
             None,
-            ([2, 5], torch.float32, True),
-            ([1, 1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
         ]
     )
     def forward(self, input, target):
@@ -2520,24 +2520,24 @@ def HingeEmbeddingLossReductionSumModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5), tu.rand(1, 1))


-class HingeEmbeddingLossWithoutReductionModule(torch.nn.Module):
+class HingeEmbeddingLossReductionNoneModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

     @export
     @annotate_args(
         [
             None,
-            ([8, 5], torch.float32, True),
-            ([1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
+            ([-1], torch.float32, True),
         ]
     )
     def forward(self, input, target):
         return torch.ops.aten.hinge_embedding_loss(input, target, margin=1.0)


-@register_test_case(module_factory=lambda: HingeEmbeddingLossWithoutReductionModule())
-def HingeEmbeddingLossWithoutReductionModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: HingeEmbeddingLossReductionNoneModule())
+def HingeEmbeddingLossReductionNoneModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(8, 5), tu.rand(1))
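Note the annotation change from fixed sizes to -1: in this e2e test harness, a -1 in an annotate_args shape marks that dimension as dynamic, so the decomposition is now exercised with sizes unknown at compile time rather than baked in. A minimal sketch of the convention (the module name is illustrative; the import path follows the suite's other tests):

import torch
from torch_mlir_e2e_test.annotations import annotate_args, export

class DynamicHingeExample(torch.nn.Module):
    @export
    @annotate_args(
        [
            None,                             # the module itself
            ([-1, -1], torch.float32, True),  # 2-D float tensor, both dims dynamic
        ]
    )
    def forward(self, x):
        return torch.ops.aten.hinge_embedding_loss(x, torch.ones_like(x))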