Skip to content

Commit 51ea380

Browse files
committed
Fix duplicate argument handling in inductor lowering
Fixes #221 stack-info: PR: #222, branch: jansel/stack/68
1 parent ddcd924 commit 51ea380

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

helion/_compiler/inductor_lowering.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,20 @@ def prepare_node_lowering(
105105
node.meta["lowering"] = SympyExprLowering(val._sympy_())
106106
return
107107

108+
# Track arguments to reuse names for duplicates
109+
arg_to_name: dict[Node, str] = {}
110+
108111
def convert_arg(arg: Node) -> TensorBox:
109112
example = arg.meta["val"]
110-
input_names.append(name := f"{node.name}_input{len(input_names)}")
113+
114+
# Reuse existing name for duplicate arguments
115+
if arg in arg_to_name:
116+
name = arg_to_name[arg]
117+
else:
118+
name = f"{node.name}_input{len(input_names)}"
119+
arg_to_name[arg] = name
120+
input_names.append(name)
121+
111122
if isinstance(example, (torch.SymInt, torch.SymFloat, torch.SymBool)):
112123
dtype = {
113124
torch.SymInt: torch.int64,

test/test_misc.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,25 @@
1616

1717

1818
class TestMisc(TestCase):
19+
def test_binary_operation_duplicate_args(self):
20+
"""Test case to reproduce issue #221: binary operations with duplicate tensor references"""
21+
22+
@helion.kernel(use_default_config=True)
23+
def kernel_with_duplicate_refs(x: torch.Tensor) -> torch.Tensor:
24+
result = torch.empty_like(x)
25+
for tile in hl.tile(x.shape):
26+
val = x[tile]
27+
result[tile] = (
28+
val * val + val
29+
) # Multiple uses of same variable - triggers the bug
30+
return result
31+
32+
x = torch.randn([16, 16], device=DEVICE)
33+
expected = x * x + x
34+
35+
code, result = code_and_output(kernel_with_duplicate_refs, (x,))
36+
torch.testing.assert_close(result, expected)
37+
1938
def test_torch_alloc(self):
2039
@helion.kernel(config={"block_sizes": [64, 64]})
2140
def fn(x: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)