diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py
index 24c7b4c8..eb59971f 100644
--- a/helion/_compiler/device_ir.py
+++ b/helion/_compiler/device_ir.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import ast
+import builtins
 from collections.abc import Callable
 import contextlib
 import dataclasses
@@ -495,7 +496,7 @@ def _extract_tile_begin_end(self, for_node: ast.For) -> tuple[object, object]:
         assert isinstance(func_node, ExtendedAST)
         func_type = func_node._type_info
         assert isinstance(func_type, CallableType)
-        assert func_type.value in (hl.tile, hl.grid)
+        assert func_type.value in (hl.tile, hl.grid, builtins.range)
         args = call_node.args
         assert len(args) >= 1
         if len(args) == 1:
diff --git a/helion/language/loops.py b/helion/language/loops.py
index 6f3ba28f..438bc84a 100644
--- a/helion/language/loops.py
+++ b/helion/language/loops.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import ast
+import builtins
 from typing import TYPE_CHECKING
 from typing import Iterator
 from typing import Sequence
@@ -286,6 +287,7 @@ def _codegen_loop_helper(
 
 
 @overload
+@_decorators.device_func_replacement(builtins.range)
 @_decorators.api(
     is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
 )
@@ -298,6 +300,7 @@ def grid(
 
 
 @overload
+@_decorators.device_func_replacement(builtins.range)
 @_decorators.api(
     is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
 )
@@ -309,6 +312,7 @@ def grid(
 ) -> Iterator[Sequence[torch.SymInt]]: ...
 
 
+@_decorators.device_func_replacement(builtins.range)
 @_decorators.api(
     is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
 )
diff --git a/test/test_errors.py b/test/test_errors.py
index 74e7e8d2..186c96b0 100644
--- a/test/test_errors.py
+++ b/test/test_errors.py
@@ -118,7 +118,7 @@ def fn(x: torch.Tensor) -> torch.Tensor:
             batch = x.size(0)
             out = x.new_empty(batch)
             for tile_batch in hl.tile(batch):
-                for i in range(10):
+                for i in [1, 2, 3]:
                     out[tile_batch] = x[tile_batch] + i
             return out
 
diff --git a/test/test_grid.py b/test/test_grid.py
index 0a024abd..9196e5b5 100644
--- a/test/test_grid.py
+++ b/test/test_grid.py
@@ -716,6 +716,127 @@ def _tile_begin_end_make_precompiler(x: torch.Tensor):
     return make_precompiler(_tile_begin_end_kernel)(x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
         )
 
+    def test_range_as_grid_basic(self):
+        """Test that range() works as an alias for hl.grid() in device code."""
+
+        @helion.kernel(use_default_config=True)
+        def range_kernel(x: torch.Tensor) -> torch.Tensor:
+            batch = x.size(0)
+            out = x.new_zeros(batch)
+            for tile_batch in hl.tile(batch):
+                for i in range(10):  # range(10) in device code is handled as an alias for hl.grid(10)
+                    out[tile_batch] += x[tile_batch] + i
+            return out
+
+        x = torch.randn(35, device=DEVICE)
+
+        # Reference: sum over i of (x + i) = 10*x + sum(0..9) = 10*x + 45
+        expected = 10 * x + 45
+
+        code, result = code_and_output(range_kernel, (x,))
+        torch.testing.assert_close(result, expected)
+
+    def test_range_with_begin_end(self):
+        """Test that range(begin, end) works as alias for hl.grid(begin, end)."""
+
+        @helion.kernel(use_default_config=True)
+        def range_begin_end_kernel(x: torch.Tensor) -> torch.Tensor:
+            batch = x.size(0)
+            out = x.new_zeros(batch)
+            for tile_batch in hl.tile(batch):
+                for i in range(2, 7):  # range(begin, end)
+                    out[tile_batch] += x[tile_batch] * i
+            return out
+
+        x = torch.randn(20, device=DEVICE)
+
+        # Reference: x * sum(range(2, 7)) = x * (2 + 3 + 4 + 5 + 6) = x * 20
+        expected = x * 20
+
+        code, result = code_and_output(range_begin_end_kernel, (x,))
+        torch.testing.assert_close(result, expected)
+
+    def test_range_with_step(self):
+        """Test that range(begin, end, step) works as alias for hl.grid(begin, end, step)."""
+
+        @helion.kernel(use_default_config=True)
+        def range_step_kernel(x: torch.Tensor) -> torch.Tensor:
+            batch = x.size(0)
+            out = x.new_zeros(batch)
+            for tile_batch in hl.tile(batch):
+                for i in range(1, 10, 2):  # range(begin, end, step)
+                    out[tile_batch] += x[tile_batch] / i
+            return out
+
+        x = torch.randn(6, device=DEVICE)
+
+        # Reference: x * sum(1/i for i in range(1, 10, 2))
+        #          = x * (1/1 + 1/3 + 1/5 + 1/7 + 1/9)
+        reciprocal_sum = sum(1.0 / i for i in range(1, 10, 2))
+        expected = x * reciprocal_sum
+
+        code, result = code_and_output(range_step_kernel, (x,))
+        torch.testing.assert_close(result, expected)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _range_step_kernel_kernel(out, x, out_stride_0, x_stride_0, batch, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < batch
+    for offset_1 in range(1, 10, _BLOCK_SIZE_1):
+        load = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+        load_1 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+        v_0 = offset_1.to(tl.float32)
+        v_1 = load_1 / v_0
+        v_2 = load + v_1
+        tl.store(out + indices_0 * out_stride_0, v_2, mask_0)
+
+def range_step_kernel(x: torch.Tensor):
+    batch = x.size(0)
+    out = x.new_zeros(batch)
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 2
+    _range_step_kernel_kernel[triton.cdiv(batch, _BLOCK_SIZE_0),](out, x, out.stride(0), x.stride(0), batch, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+def _range_step_kernel_make_precompiler(x: torch.Tensor):
+    batch = x.size(0)
+    out = x.new_zeros(batch)
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 2
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_range_step_kernel_kernel)(out, x, out.stride(0), x.stride(0), batch, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
+        )
+
+    def test_range_with_tensor_size(self):
+        """Test that range(tensor.size(dim)) works with dynamic tensor dimensions."""
+
+        @helion.kernel(use_default_config=True)
+        def range_tensor_size_kernel(x: torch.Tensor) -> torch.Tensor:
+            batch = x.size(0)
+            out = x.new_zeros(batch)
+            for tile_batch in hl.tile(batch):
+                for _ in range(x.size(1)):  # Use tensor dimension in range
+                    out[tile_batch] += x[tile_batch, 0]  # Just use first column
+            return out
+
+        x = torch.randn(8, 5, device=DEVICE)  # 8 rows, 5 columns
+
+        # Reference: each row accumulates x[row, 0] a total of x.size(1) times, i.e. x[:, 0] * x.size(1)
+        expected = x[:, 0] * x.size(1)
+
+        code, result = code_and_output(range_tensor_size_kernel, (x,))
+        torch.testing.assert_close(result, expected)
+
 
 if __name__ == "__main__":
     unittest.main()