Skip to content

Commit 3e7f664

Browse files
committed
Support range() loops (alias for hl.grid)
Fixes #207 stack-info: PR: #212, branch: jansel/stack/63
1 parent 751b0c5 commit 3e7f664

File tree

4 files changed

+128
-2
lines changed

4 files changed

+128
-2
lines changed

helion/_compiler/device_ir.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import builtins
45
from collections.abc import Callable
56
import contextlib
67
import dataclasses
@@ -495,7 +496,7 @@ def _extract_tile_begin_end(self, for_node: ast.For) -> tuple[object, object]:
495496
assert isinstance(func_node, ExtendedAST)
496497
func_type = func_node._type_info
497498
assert isinstance(func_type, CallableType)
498-
assert func_type.value in (hl.tile, hl.grid)
499+
assert func_type.value in (hl.tile, hl.grid, builtins.range)
499500
args = call_node.args
500501
assert len(args) >= 1
501502
if len(args) == 1:

helion/language/loops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import builtins
45
from typing import TYPE_CHECKING
56
from typing import Iterator
67
from typing import Sequence
@@ -286,6 +287,7 @@ def _codegen_loop_helper(
286287

287288

288289
@overload
290+
@_decorators.device_func_replacement(builtins.range)
289291
@_decorators.api(
290292
is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
291293
)
@@ -298,6 +300,7 @@ def grid(
298300

299301

300302
@overload
303+
@_decorators.device_func_replacement(builtins.range)
301304
@_decorators.api(
302305
is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
303306
)
@@ -309,6 +312,7 @@ def grid(
309312
) -> Iterator[Sequence[torch.SymInt]]: ...
310313

311314

315+
@_decorators.device_func_replacement(builtins.range)
312316
@_decorators.api(
313317
is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
314318
)

test/test_errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def fn(x: torch.Tensor) -> torch.Tensor:
118118
batch = x.size(0)
119119
out = x.new_empty(batch)
120120
for tile_batch in hl.tile(batch):
121-
for i in range(10):
121+
for i in [1, 2, 3]:
122122
out[tile_batch] = x[tile_batch] + i
123123
return out
124124

test/test_grid.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,127 @@ def _tile_begin_end_make_precompiler(x: torch.Tensor):
716716
return make_precompiler(_tile_begin_end_kernel)(x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
717717
)
718718

719+
def test_range_as_grid_basic(self):
720+
"""Test that range() works as an alias for hl.grid() in device code."""
721+
722+
@helion.kernel(use_default_config=True)
723+
def range_kernel(x: torch.Tensor) -> torch.Tensor:
724+
batch = x.size(0)
725+
out = x.new_zeros(batch)
726+
for tile_batch in hl.tile(batch):
727+
for i in range(10): # range() in device code is an alias for hl.grid(10)
728+
out[tile_batch] += x[tile_batch] + i
729+
return out
730+
731+
x = torch.randn(35, device=DEVICE)
732+
733+
# Reference: sum of (x + i) for i in range(10) = 10*x + sum(range(10)) = 10*x + 45
734+
expected = 10 * x + 45
735+
736+
code, result = code_and_output(range_kernel, (x,))
737+
torch.testing.assert_close(result, expected)
738+
739+
def test_range_with_begin_end(self):
740+
"""Test that range(begin, end) works as alias for hl.grid(begin, end)."""
741+
742+
@helion.kernel(use_default_config=True)
743+
def range_begin_end_kernel(x: torch.Tensor) -> torch.Tensor:
744+
batch = x.size(0)
745+
out = x.new_zeros(batch)
746+
for tile_batch in hl.tile(batch):
747+
for i in range(2, 7): # range(begin, end)
748+
out[tile_batch] += x[tile_batch] * i
749+
return out
750+
751+
x = torch.randn(20, device=DEVICE)
752+
753+
# Reference: x * sum(range(2, 7)) = x * (2 + 3 + 4 + 5 + 6) = x * 20
754+
expected = x * 20
755+
756+
code, result = code_and_output(range_begin_end_kernel, (x,))
757+
torch.testing.assert_close(result, expected)
758+
759+
def test_range_with_step(self):
760+
"""Test that range(begin, end, step) works as alias for hl.grid(begin, end, step)."""
761+
762+
@helion.kernel(use_default_config=True)
763+
def range_step_kernel(x: torch.Tensor) -> torch.Tensor:
764+
batch = x.size(0)
765+
out = x.new_zeros(batch)
766+
for tile_batch in hl.tile(batch):
767+
for i in range(1, 10, 2): # range(begin, end, step)
768+
out[tile_batch] += x[tile_batch] / i
769+
return out
770+
771+
x = torch.randn(6, device=DEVICE)
772+
773+
# Reference: x * sum(1/i for i in range(1, 10, 2)) = x * sum(1/1, 1/3, 1/5, 1/7, 1/9)
774+
# = x * (1 + 1/3 + 1/5 + 1/7 + 1/9)
775+
reciprocal_sum = sum(1.0 / i for i in range(1, 10, 2))
776+
expected = x * reciprocal_sum
777+
778+
code, result = code_and_output(range_step_kernel, (x,))
779+
torch.testing.assert_close(result, expected)
780+
self.assertExpectedInline(
781+
code,
782+
"""\
783+
from __future__ import annotations
784+
785+
import torch
786+
import triton
787+
import triton.language as tl
788+
789+
@triton.jit
790+
def _range_step_kernel_kernel(out, x, out_stride_0, x_stride_0, batch, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
791+
pid_0 = tl.program_id(0)
792+
offset_0 = pid_0 * _BLOCK_SIZE_0
793+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
794+
mask_0 = indices_0 < batch
795+
for offset_1 in range(1, 10, _BLOCK_SIZE_1):
796+
load = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
797+
load_1 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
798+
v_0 = offset_1.to(tl.float32)
799+
v_1 = load_1 / v_0
800+
v_2 = load + v_1
801+
tl.store(out + indices_0 * out_stride_0, v_2, mask_0)
802+
803+
def range_step_kernel(x: torch.Tensor):
804+
batch = x.size(0)
805+
out = x.new_zeros(batch)
806+
_BLOCK_SIZE_0 = 8
807+
_BLOCK_SIZE_1 = 2
808+
_range_step_kernel_kernel[triton.cdiv(batch, _BLOCK_SIZE_0),](out, x, out.stride(0), x.stride(0), batch, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
809+
return out
810+
811+
def _range_step_kernel_make_precompiler(x: torch.Tensor):
812+
batch = x.size(0)
813+
out = x.new_zeros(batch)
814+
_BLOCK_SIZE_0 = 8
815+
_BLOCK_SIZE_1 = 2
816+
from helion.runtime.precompile_shim import make_precompiler
817+
return make_precompiler(_range_step_kernel_kernel)(out, x, out.stride(0), x.stride(0), batch, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
818+
)
819+
820+
def test_range_with_tensor_size(self):
821+
"""Test that range(tensor.size(dim)) works with dynamic tensor dimensions."""
822+
823+
@helion.kernel(use_default_config=True)
824+
def range_tensor_size_kernel(x: torch.Tensor) -> torch.Tensor:
825+
batch = x.size(0)
826+
out = x.new_zeros(batch)
827+
for tile_batch in hl.tile(batch):
828+
for _ in range(x.size(1)): # Use tensor dimension in range
829+
out[tile_batch] += x[tile_batch, 0] # Just use first column
830+
return out
831+
832+
x = torch.randn(8, 5, device=DEVICE) # 8 rows, 5 columns
833+
834+
# Reference: Each row adds x[row, 0] for x.size(1) times = x[:, 0] * x.size(1)
835+
expected = x[:, 0] * x.size(1)
836+
837+
code, result = code_and_output(range_tensor_size_kernel, (x,))
838+
torch.testing.assert_close(result, expected)
839+
719840

720841
if __name__ == "__main__":
721842
unittest.main()

0 commit comments

Comments
 (0)