Skip to content

Support range() loops (alias for hl.grid) #212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import ast
import builtins
from collections.abc import Callable
import contextlib
import dataclasses
Expand Down Expand Up @@ -495,7 +496,7 @@ def _extract_tile_begin_end(self, for_node: ast.For) -> tuple[object, object]:
assert isinstance(func_node, ExtendedAST)
func_type = func_node._type_info
assert isinstance(func_type, CallableType)
assert func_type.value in (hl.tile, hl.grid)
assert func_type.value in (hl.tile, hl.grid, builtins.range)
args = call_node.args
assert len(args) >= 1
if len(args) == 1:
Expand Down
4 changes: 4 additions & 0 deletions helion/language/loops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import ast
import builtins
from typing import TYPE_CHECKING
from typing import Iterator
from typing import Sequence
Expand Down Expand Up @@ -286,6 +287,7 @@ def _codegen_loop_helper(


@overload
@_decorators.device_func_replacement(builtins.range)
@_decorators.api(
is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
)
Expand All @@ -298,6 +300,7 @@ def grid(


@overload
@_decorators.device_func_replacement(builtins.range)
@_decorators.api(
is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
)
Expand All @@ -309,6 +312,7 @@ def grid(
) -> Iterator[Sequence[torch.SymInt]]: ...


@_decorators.device_func_replacement(builtins.range)
@_decorators.api(
is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
)
Expand Down
2 changes: 1 addition & 1 deletion test/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def fn(x: torch.Tensor) -> torch.Tensor:
batch = x.size(0)
out = x.new_empty(batch)
for tile_batch in hl.tile(batch):
for i in range(10):
for i in [1, 2, 3]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean the literal [1,2,3] is not supported? should we support it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, it is not supported. We could support it by unrolling the loop, but I'm not sure whether that is allowed in Triton.

out[tile_batch] = x[tile_batch] + i
return out

Expand Down
121 changes: 121 additions & 0 deletions test/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,127 @@ def _tile_begin_end_make_precompiler(x: torch.Tensor):
return make_precompiler(_tile_begin_end_kernel)(x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
)

def test_range_as_grid_basic(self):
    """A bare range(10) in device code should behave exactly like hl.grid(10)."""

    @helion.kernel(use_default_config=True)
    def range_kernel(x: torch.Tensor) -> torch.Tensor:
        n = x.size(0)
        out = x.new_zeros(n)
        for tile_batch in hl.tile(n):
            # Builtin range() is expected to lower the same way hl.grid(10) would.
            for step in range(10):
                out[tile_batch] += x[tile_batch] + step
        return out

    inp = torch.randn(35, device=DEVICE)

    code, result = code_and_output(range_kernel, (inp,))
    # Each element accumulates (x + i) for i in 0..9: 10*x + sum(0..9) = 10*x + 45.
    torch.testing.assert_close(result, 10 * inp + 45)

def test_range_with_begin_end(self):
    """A two-argument range(begin, end) should behave like hl.grid(begin, end)."""

    @helion.kernel(use_default_config=True)
    def range_begin_end_kernel(x: torch.Tensor) -> torch.Tensor:
        n = x.size(0)
        out = x.new_zeros(n)
        for tile_batch in hl.tile(n):
            # Two-argument builtin range, expected to lower like hl.grid(2, 7).
            for factor in range(2, 7):
                out[tile_batch] += x[tile_batch] * factor
        return out

    inp = torch.randn(20, device=DEVICE)

    code, result = code_and_output(range_begin_end_kernel, (inp,))
    # sum(range(2, 7)) == 2 + 3 + 4 + 5 + 6 == 20, so every element ends at x * 20.
    torch.testing.assert_close(result, inp * 20)

def test_range_with_step(self):
    """Test that range(begin, end, step) works as alias for hl.grid(begin, end, step).

    Also pins the generated Triton source via assertExpectedInline, so this test
    checks both numerical correctness and the exact lowering of the stepped range.
    """

    @helion.kernel(use_default_config=True)
    def range_step_kernel(x: torch.Tensor) -> torch.Tensor:
        batch = x.size(0)
        out = x.new_zeros(batch)
        for tile_batch in hl.tile(batch):
            for i in range(1, 10, 2):  # range(begin, end, step)
                out[tile_batch] += x[tile_batch] / i
        return out

    x = torch.randn(6, device=DEVICE)

    # Reference: x * sum(1/i for i in range(1, 10, 2)) = x * sum(1/1, 1/3, 1/5, 1/7, 1/9)
    # = x * (1 + 1/3 + 1/5 + 1/7 + 1/9) = x * sum([1, 1/3, 1/5, 1/7, 1/9])
    reciprocal_sum = sum(1.0 / i for i in range(1, 10, 2))
    expected = x * reciprocal_sum

    code, result = code_and_output(range_step_kernel, (x,))
    torch.testing.assert_close(result, expected)
    # The expected source below must match the codegen byte-for-byte; note the
    # device loop `for offset_1 in range(1, 10, _BLOCK_SIZE_1)` produced from
    # the Python-level range(1, 10, 2).
    self.assertExpectedInline(
        code,
        """\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _range_step_kernel_kernel(out, x, out_stride_0, x_stride_0, batch, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    mask_0 = indices_0 < batch
    for offset_1 in range(1, 10, _BLOCK_SIZE_1):
        load = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
        load_1 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
        v_0 = offset_1.to(tl.float32)
        v_1 = load_1 / v_0
        v_2 = load + v_1
        tl.store(out + indices_0 * out_stride_0, v_2, mask_0)

def range_step_kernel(x: torch.Tensor):
    batch = x.size(0)
    out = x.new_zeros(batch)
    _BLOCK_SIZE_0 = 8
    _BLOCK_SIZE_1 = 2
    _range_step_kernel_kernel[triton.cdiv(batch, _BLOCK_SIZE_0),](out, x, out.stride(0), x.stride(0), batch, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out

def _range_step_kernel_make_precompiler(x: torch.Tensor):
    batch = x.size(0)
    out = x.new_zeros(batch)
    _BLOCK_SIZE_0 = 8
    _BLOCK_SIZE_1 = 2
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_range_step_kernel_kernel)(out, x, out.stride(0), x.stride(0), batch, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
    )

def test_range_with_tensor_size(self):
    """range(tensor.size(dim)) should work when the bound is a dynamic dimension."""

    @helion.kernel(use_default_config=True)
    def range_tensor_size_kernel(x: torch.Tensor) -> torch.Tensor:
        n = x.size(0)
        out = x.new_zeros(n)
        for tile_batch in hl.tile(n):
            # The trip count comes from a tensor dimension, not a constant.
            for _ in range(x.size(1)):
                # Accumulate the first column once per iteration.
                out[tile_batch] += x[tile_batch, 0]
        return out

    inp = torch.randn(8, 5, device=DEVICE)  # 8 rows, 5 columns

    code, result = code_and_output(range_tensor_size_kernel, (inp,))
    # The body runs x.size(1) times, so each row ends at x[row, 0] * x.size(1).
    torch.testing.assert_close(result, inp[:, 0] * inp.size(1))


if __name__ == "__main__":
    # Allow running this test file directly (e.g. `python test/test_grid.py`).
    unittest.main()
Loading