Skip to content

Commit 2736ff4

Browse files
authored
Add lowering for Constant assignment (#187)
1 parent 58fff40 commit 2736ff4

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

helion/_compiler/inductor_lowering.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,8 @@ def proxy_arg(self, i: int) -> object:
988988

989989
def ast_arg(self, i: int) -> ast.AST:
990990
rv = self.ast_args[i]
991+
if isinstance(rv, int | float | bool):
992+
rv = ast.Constant(value=rv)
991993
assert isinstance(rv, ast.AST), "TODO: convert nested/defaults"
992994
return rv
993995

test/test_indexing.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,21 @@ def test_block_size_access(x: torch.Tensor) -> torch.Tensor:
279279
expected = torch.full_like(x, 1, dtype=torch.int32)
280280
torch.testing.assert_close(result, expected)
281281

282+
def test_assign_int(self):
283+
@helion.kernel
284+
def fn(x: torch.Tensor) -> torch.Tensor:
285+
for tile in hl.tile(x.size(0)):
286+
x[tile] = 1
287+
return x
288+
289+
x = torch.zeros([200], device=DEVICE)
290+
expected = torch.ones_like(x)
291+
code, result = code_and_output(
292+
fn,
293+
(x,),
294+
)
295+
torch.testing.assert_close(result, expected)
296+
282297
def test_atomic_add_symint(self):
283298
@helion.kernel(config={"block_size": 32})
284299
def fn(x: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)