From 030df03c0bec998354e90224c354dd2149611064 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 20 Jun 2025 21:17:32 -0700 Subject: [PATCH] Add support for hl.grid(begin, end, step) stack-info: PR: https://github.com/pytorch-labs/helion/pull/211, branch: jansel/stack/62 --- helion/_compiler/program_id.py | 2 +- helion/_compiler/tile_strategy.py | 132 ++++++----- helion/_compiler/type_propagation.py | 48 ++-- helion/language/loops.py | 99 +++++--- test/test_grid.py | 324 +++++++++++++++++++++++++++ 5 files changed, 498 insertions(+), 107 deletions(-) diff --git a/helion/_compiler/program_id.py b/helion/_compiler/program_id.py index 38c12cc9..6d85ca67 100644 --- a/helion/_compiler/program_id.py +++ b/helion/_compiler/program_id.py @@ -119,7 +119,7 @@ def combined_device_cdiv(self, state: CodegenState) -> str: return " * ".join(pid.device_cdiv(state) for pid in self.pids) def combined_host_cdiv(self) -> str: - return " * ".join(pid.host_cdiv() for pid in self.pids) + return " * ".join(f"({pid.host_cdiv()})" for pid in self.pids) def codegen(self, state: CodegenState) -> None: pid_var = self.shared_pid_var or "tl.program_id(0)" diff --git a/helion/_compiler/tile_strategy.py b/helion/_compiler/tile_strategy.py index 2bd87fea..c41db5d7 100644 --- a/helion/_compiler/tile_strategy.py +++ b/helion/_compiler/tile_strategy.py @@ -138,6 +138,45 @@ def codegen_preamble(self, state: CodegenState) -> None: def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]: raise NotImplementedError + def _create_block_id_info_dict( + self, state: CodegenState, use_proxy_ends: bool = False + ) -> dict[int, LoopDimInfo]: + """Helper to create block_id_to_info dictionary with end bounds.""" + env = CompileEnvironment.current() + block_id_to_info = {} + + if use_proxy_ends: + _, _, proxy_ends, _ = state.proxy_args + assert isinstance(proxy_ends, list) + for block_idx, end in zip(self.block_ids, proxy_ends, strict=True): + if isinstance(end, (int, torch.SymInt)): + end_expr = _to_sympy(end) + else: + end_expr = None + block_id_to_info[block_idx] = LoopDimInfo( + end_var_name=None, end_expr=end_expr + ) + else: + for block_id in self.block_ids: + end_expr = env.block_sizes[block_id].numel + end_var_name = state.sympy_expr(end_expr) + block_id_to_info[block_id] = LoopDimInfo( + end_var_name=end_var_name, end_expr=end_expr + ) + + return block_id_to_info + + def _setup_block_size_constexpr( + self, state: CodegenState, block_size_var: str, block_size: SymIntLike + ) -> None: + """Helper to setup constexpr block size variable on host.""" + if state.device_function.constexpr_arg(block_size_var): + state.codegen.host_statements.append( + statement_from_string( + f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}" + ) + ) + class BlockSizeTileStrategy(TileStrategy): def __init__( @@ -265,19 +304,15 @@ def block_size_var(self, block_idx: int) -> str: def _codegen_common( self, state: CodegenState ) -> tuple[str, str, sympy.Expr, list[ast.AST]]: + offsets_var = self.new_var("offsets", dce=True) + block_size_var = self.block_size_var(-1) + self._setup_block_size_constexpr(state, block_size_var, self.block_size) block_ids = self.block_ids env = CompileEnvironment.current() total_numel = sympy.S.One - offsets_var = self.new_var("offsets", dce=True) - block_size_var = self.block_size_var(-1) statements = [] - if state.device_function.constexpr_arg(block_size_var): - block_size_str = HostFunction.current().literal_expr(self.block_size) - state.codegen.host_statements.append( - statement_from_string(f"{block_size_var} = {block_size_str}") - ) + for i, block_idx in enumerate(self._reorder(block_ids)): - # need to get the block size numel = env.block_sizes[block_idx].numel block_index_var = self.index_var(block_idx) expr = offsets_var @@ -316,13 +351,7 @@ def codegen_grid(self) -> ast.AST: state.device_function.set_pid(TmpPid()) - block_id_to_info = {} - for block_id in self.block_ids: - end_expr = env.block_sizes[block_id].numel - end_var_name = state.sympy_expr(end_expr) - block_id_to_info[block_id] = LoopDimInfo( - end_var_name=end_var_name, end_expr=end_expr - ) + block_id_to_info = self._create_block_id_info_dict(state) return DeviceGridState(self, block_id_to_info=block_id_to_info) def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState: @@ -348,18 +377,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState: orelse=[], type_comment=None, ) - # Create block_id_to_info with end bounds - block_id_to_info = {} - _, _, ends, _ = state.proxy_args - assert isinstance(ends, list) - for block_idx, end in zip(self.block_ids, ends, strict=True): - if isinstance(end, (int, torch.SymInt)): - end_expr = _to_sympy(end) - else: - end_expr = None - block_id_to_info[block_idx] = LoopDimInfo( - end_var_name=None, end_expr=end_expr - ) + block_id_to_info = self._create_block_id_info_dict(state, use_proxy_ends=True) return DeviceLoopState( self, @@ -430,8 +448,6 @@ def __init__( def codegen_grid(self, state: CodegenState) -> DeviceGridState: block_ids = self.block_ids env = CompileEnvironment.current() - device_function = state.device_function - dtype = env.triton_index_type() block_sizes = self.block_size assert len(block_sizes) == len(block_ids) if isinstance(state.device_function.pid, SharedProgramID): @@ -440,31 +456,47 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState: pids = self.select_pid_strategy() if isinstance(state.device_function.pid, SharedProgramID): pids.shared_pid_var = state.device_function.pid.shared_pid_var - for i, (block_idx, block_size) in enumerate( - reversed(self._reorder([*zip(block_ids, block_sizes, strict=True)])) + + assert state.ast_args is None + assert len(state.proxy_args) == 3 + if state.proxy_args[1] is None: + begins = [0] * len(block_ids) + else: + begins = state.proxy_args[0] + if not isinstance(begins, (list, tuple)): + begins = [begins] + assert len(begins) == len(block_ids) + + for i, (block_idx, block_size, begin) in enumerate( + reversed(self._reorder([*zip(block_ids, block_sizes, begins, strict=True)])) ): numel = env.block_sizes[block_idx].numel + device_function = state.device_function + dtype = env.triton_index_type() offset_var = self.offset_var(block_idx) index_var = self.index_var(block_idx) pid_var = device_function.new_var(f"pid_{i}", dce=True) + + begin_offset_expr = "" + if begin != 0: + begin_ast = self._to_ast(begin, to_dtype=dtype) + begin_offset_expr = ( + f"{state.codegen.lift(begin_ast, dce=True, prefix='begin').id} + " + ) + if block_size != 1: block_size_var = self.block_size_var(block_idx) assert block_size_var is not None - # TODO(jansel): need to check for conflict with user variable names since block_size_var is on host - if state.device_function.constexpr_arg(block_size_var): - state.codegen.host_statements.append( - statement_from_string( - f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}" - ) - ) - state.add_statement(f"{offset_var} = {pid_var} * {block_size_var}") + self._setup_block_size_constexpr(state, block_size_var, block_size) + state.add_statement( + f"{offset_var} = {begin_offset_expr}{pid_var} * {block_size_var}" + ) state.add_statement( f"{index_var} = ({offset_var} + tl.arange(0, ({block_size_var}))).to({dtype})" ) else: block_size_var = "1" - dtype = env.triton_index_type() - state.add_statement(f"{offset_var} = {pid_var}") + state.add_statement(f"{offset_var} = {begin_offset_expr}{pid_var}") state.add_statement( f"{index_var} = {offset_var} + tl.zeros([1], {dtype})" ) @@ -483,14 +515,7 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState: else: state.device_function.set_pid(pids) - # Extract end_var_name from end bound expressions - block_id_to_info = {} - for block_id in self.block_ids: - end_expr = env.block_sizes[block_id].numel - end_var_name = state.sympy_expr(end_expr) - block_id_to_info[block_id] = LoopDimInfo( - end_var_name=end_var_name, end_expr=end_expr - ) + block_id_to_info = self._create_block_id_info_dict(state) return DeviceGridState(self, block_id_to_info=block_id_to_info) def select_pid_strategy(self) -> ProgramIDs: @@ -509,6 +534,8 @@ def _to_ast(self, x: object, to_dtype: str | None = None) -> ast.AST: from .device_function import DeviceFunction return expr_from_string(DeviceFunction.current().sympy_expr(x)) + if isinstance(x, torch.SymInt): + return self._to_ast(x._sympy_()) raise NotImplementedError(f"{type(x)} is not implemented.") def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState: @@ -534,12 +561,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState: if block_size != 1: block_size_var = self.block_size_var(block_idx) assert block_size_var is not None - if state.device_function.constexpr_arg(block_size_var): - state.codegen.host_statements.append( - statement_from_string( - f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}" - ) - ) + self._setup_block_size_constexpr(state, block_size_var, block_size) else: block_size_var = "1" end_var_name = state.codegen.lift( diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index 476e7bc5..3619c32a 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -982,29 +982,24 @@ def proxy(self) -> object: @staticmethod def allocate( - numel: int | torch.SymInt | AutoSize | None, origin: Origin - ) -> TileIndexType: - env = CompileEnvironment.current() - block_id = env.allocate_block_size(numel, source=LoopSpecBlockSizeSource()) - env.config_spec.block_sizes.append( - BlockSizeSpec( - block_id=block_id, - size_hint=_get_hint(numel), - ) - ) - return TileIndexType(origin, block_id) - - @staticmethod - def allocate_fixed( numel: int | torch.SymInt | AutoSize | None, - block_size: int | torch.SymInt, origin: Origin, + block_size: int | torch.SymInt | None = None, ) -> TileIndexType: env = CompileEnvironment.current() - return TileIndexType( - origin, - env.allocate_block_size(numel, source=FixedBlockSizeSource(block_size)), - ) + if block_size is None: + block_id = env.allocate_block_size(numel, source=LoopSpecBlockSizeSource()) + env.config_spec.block_sizes.append( + BlockSizeSpec( + block_id=block_id, + size_hint=_get_hint(numel), + ) + ) + else: + block_id = env.allocate_block_size( + numel, source=FixedBlockSizeSource(block_size) + ) + return TileIndexType(origin, block_id) def merge(self, other: TypeInfo) -> TypeInfo: if isinstance(other, TileIndexType): @@ -1024,7 +1019,12 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo: class GridIndexType(SymIntType): block_id: int - def __init__(self, origin: Origin, sym: torch.SymInt, block_id: int) -> None: + def __init__( + self, + origin: Origin, + sym: torch.SymInt, + block_id: int, + ) -> None: super().__init__(origin, sym) self.block_id = block_id @@ -1032,13 +1032,17 @@ def __str__(self) -> str: # pragma: no cover – debug helper return f"{type(self).__name__}({self.block_id})" @staticmethod - def allocate(numel: int | torch.SymInt, origin: Origin) -> GridIndexType: + def allocate( + numel: int | torch.SymInt, + origin: Origin, + step: int | torch.SymInt = 1, + ) -> GridIndexType: from .._compiler.compile_environment import CompileEnvironment from .host_function import HostFunction from .host_function import SymbolOrigin env = CompileEnvironment.current() - block_id = env.allocate_block_size(numel, source=FixedBlockSizeSource(1)) + block_id = env.allocate_block_size(numel, source=FixedBlockSizeSource(step)) # assign this a new unbacked symbol since this should be treated like a scalar rather than a tile sym = env.create_unbacked_symint() HostFunction.current().expr_to_origin[sym._sympy_()] = SymbolOrigin( diff --git a/helion/language/loops.py b/helion/language/loops.py index 821a524c..6f3ba28f 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -200,13 +200,13 @@ def _( if bs is None: results.append(TileIndexType.allocate(size, origin)) elif isinstance(bs, int): - results.append(TileIndexType.allocate_fixed(size, bs, origin)) + results.append(TileIndexType.allocate(size, origin, bs)) elif isinstance(bs, torch.SymInt): from helion._compiler.compile_environment import CompileEnvironment index = CompileEnvironment.current().get_block_id(bs) if index is None: - results.append(TileIndexType.allocate_fixed(size, bs, origin)) + results.append(TileIndexType.allocate(size, origin, bs)) else: results.append(TileIndexType(origin=origin, block_id=index)) CompileEnvironment.current().block_sizes[index].mark_alternate_size( @@ -289,63 +289,104 @@ def _codegen_loop_helper( @_decorators.api( is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True ) -def grid(sizes: int, /) -> Iterator[torch.SymInt]: ... +def grid( + begin_or_end: int | torch.Tensor, + end_or_none: int | torch.Tensor | None = None, + /, + step: object = None, +) -> Iterator[torch.SymInt]: ... @overload @_decorators.api( is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True ) -def grid(sizes: Sequence[int], /) -> Iterator[Sequence[torch.SymInt]]: ... +def grid( + begin_or_end: Sequence[int | torch.Tensor], + end_or_none: Sequence[int | torch.Tensor] | None = None, + /, + step: object = None, +) -> Iterator[Sequence[torch.SymInt]]: ... @_decorators.api( is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True ) def grid( - sizes: int | Sequence[int], + begin_or_end: int | torch.Tensor | Sequence[int | torch.Tensor], + end_or_none: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None, /, + step: object = None, ) -> Iterator[torch.SymInt] | Iterator[Sequence[torch.SymInt]]: # type: ignore[type-arg] - """Iterate over *individual* indices of the given iteration space. + """Iterate over individual indices of the given iteration space. Semantics are equivalent to - for i in hl.tile(size, block_size=1): + for i in hl.tile(...): ... but `i` will be a scalar (`torch.SymInt`), not a 1-element tensor. - """ + When used at the top level of a function, this becomes the grid of the kernel. + Otherwise, it becomes a loop in the output kernel. + + Similar to `range()` there are multiple forms of this function: + grid(end) iterates from 0 to `end - 1`, with step size 1. + grid(begin, end) iterates from `begin` to `end - 1`, with step size 1. + grid(begin, end, step) iterates from `begin` to `end - 1`, with the given step size. + grid(end, step=step) iterates from 0 to `end - 1`, with the given step size. + """ raise exc.NotInsideKernel @_decorators.type_propagation(grid) -def _(sizes: TypeInfo, *, origin: Origin) -> TypeInfo: +def _( + begin_or_end: TypeInfo, + end_or_none: TypeInfo | None = None, + /, + step: TypeInfo | None = None, + *, + origin: Origin, +) -> TypeInfo: parent = ExtendedAST.current()[-2] if not isinstance(parent, ast.For): raise exc.LoopFunctionNotInFor("grid") - try: - proxy_sizes = sizes.proxy() - if not ( - isinstance(proxy_sizes, (int, torch.SymInt)) - or ( - isinstance(proxy_sizes, (list, tuple)) - and all(isinstance(x, (int, torch.SymInt)) for x in proxy_sizes) - ) - ): - raise NotImplementedError - except NotImplementedError: - raise exc.TypeInferenceError( - f"grid() expected int or list[int], got {sizes!s}" - ) from None + begin, end = _normalize_begin_end(begin_or_end, end_or_none, origin=origin) + proxy_begin = _to_proxy(begin) + proxy_end = _to_proxy(end) + _check_matching(proxy_begin, proxy_end) + if _not_none(step): + proxy_step = Tile._tiles_to_sizes(_to_proxy(step)) + _check_matching(proxy_end, proxy_step) + else: + proxy_step = begin.tree_map(lambda n: None) - if isinstance(proxy_sizes, (int, torch.SymInt)): - return IterType(origin, GridIndexType.allocate(proxy_sizes, origin)) + if unpack := not isinstance(proxy_end, (list, tuple)): + proxy_begin = [proxy_begin] + proxy_end = [proxy_end] + proxy_step = [proxy_step] + + results = [] + for begin_part, end_part, step_part in zip( + proxy_begin, proxy_end, proxy_step, strict=True + ): + size = end_part - begin_part + if isinstance(size, torch.Tensor): + size = None # data dependent size + if step_part is None: + step_part = 1 + results.append(GridIndexType.allocate(size, origin, step_part)) - assert isinstance(proxy_sizes, (list, tuple)) - elements = [GridIndexType.allocate(s, origin) for s in proxy_sizes] - _add_config_choices([x.block_id for x in elements]) - return IterType(origin, SequenceType(origin, elements)) + _add_config_choices( + [x.block_id for x in results], + is_tile=False, + has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin), + ) + if unpack: + (result,) = results + else: + result = SequenceType(origin, results) + return IterType(origin, result) @_decorators.codegen(grid) diff --git a/test/test_grid.py b/test/test_grid.py index 0ade4047..0a024abd 100644 --- a/test/test_grid.py +++ b/test/test_grid.py @@ -392,6 +392,330 @@ def _grid_2d_idx_nested_make_precompiler(x: torch.Tensor, y: torch.Tensor): return make_precompiler(_grid_2d_idx_nested_kernel)(x, y, out, _BLOCK_SIZE_3, _BLOCK_SIZE_2, _BLOCK_SIZE_4, num_warps=4, num_stages=3)""", ) + def test_grid_begin_end(self): + @helion.kernel(use_default_config=True) + def grid_begin_end(x: torch.Tensor) -> torch.Tensor: + n = x.size(0) + out = torch.zeros_like(x) + for i in hl.grid(2, n - 2): # grid(begin, end) + out[i] = x[i] * 2 + return out + + def grid_begin_end_pytorch(x: torch.Tensor) -> torch.Tensor: + n = x.size(0) + out = torch.zeros_like(x) + for i in range(2, n - 2): + out[i] = x[i] * 2 + return out + + x = torch.randn([16], device=DEVICE, dtype=torch.float32) + code, result = code_and_output(grid_begin_end, (x,)) + torch.testing.assert_close(result, grid_begin_end_pytorch(x)) + self.assertExpectedInline( + code, + """\ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +@triton.jit +def _grid_begin_end_kernel(x, out, out_stride_0, x_stride_0): + pid_0 = tl.program_id(0) + begin_0 = 2 + offset_0 = begin_0 + pid_0 + load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None) + v_0 = 2.0 + v_1 = load * v_0 + tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None) + +def grid_begin_end(x: torch.Tensor): + n = x.size(0) + out = torch.zeros_like(x) + _grid_begin_end_kernel[-4 + n,](x, out, out.stride(0), x.stride(0), num_warps=4, num_stages=3) + return out + +def _grid_begin_end_make_precompiler(x: torch.Tensor): + n = x.size(0) + out = torch.zeros_like(x) + from helion.runtime.precompile_shim import make_precompiler + return make_precompiler(_grid_begin_end_kernel)(x, out, out.stride(0), x.stride(0), num_warps=4, num_stages=3)""", + ) + + def test_grid_begin_end_step(self): + @helion.kernel(use_default_config=True) + def grid_begin_end_step(x: torch.Tensor) -> torch.Tensor: + n = x.size(0) + out = torch.zeros_like(x) + for i in hl.grid(0, n, 2): # grid(begin, end, step) + out[i] = x[i] * 2 + return out + + def grid_begin_end_step_pytorch(x: torch.Tensor) -> torch.Tensor: + n = x.size(0) + out = torch.zeros_like(x) + for i in range(0, n, 2): + out[i] = x[i] * 2 + return out + + x = torch.randn([16], device=DEVICE, dtype=torch.float32) + code, result = code_and_output(grid_begin_end_step, (x,)) + torch.testing.assert_close(result, grid_begin_end_step_pytorch(x)) + self.assertExpectedInline( + code, + """\ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +@triton.jit +def _grid_begin_end_step_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None) + v_0 = 2.0 + v_1 = load * v_0 + tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None) + +def grid_begin_end_step(x: torch.Tensor): + n = x.size(0) + out = torch.zeros_like(x) + _BLOCK_SIZE_0 = 2 + _grid_begin_end_step_kernel[triton.cdiv(n, _BLOCK_SIZE_0),](x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return out + +def _grid_begin_end_step_make_precompiler(x: torch.Tensor): + n = x.size(0) + out = torch.zeros_like(x) + _BLOCK_SIZE_0 = 2 + from helion.runtime.precompile_shim import make_precompiler + return make_precompiler(_grid_begin_end_step_kernel)(x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""", + ) + + def test_grid_end_step_kwarg(self): + @helion.kernel(use_default_config=True) + def grid_end_step_kwarg(x: torch.Tensor) -> torch.Tensor: + n = x.size(0) + out = torch.zeros_like(x) + for i in hl.grid(n, step=2): # grid(end, step=step) + out[i] = x[i] * 2 + return out + + def grid_end_step_kwarg_pytorch(x: torch.Tensor) -> torch.Tensor: + n = x.size(0) + out = torch.zeros_like(x) + for i in range(0, n, 2): + out[i] = x[i] * 2 + return out + + x = torch.randn([16], device=DEVICE, dtype=torch.float32) + code, result = code_and_output(grid_end_step_kwarg, (x,)) + torch.testing.assert_close(result, grid_end_step_kwarg_pytorch(x)) + self.assertExpectedInline( + code, + """\ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +@triton.jit +def _grid_end_step_kwarg_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None) + v_0 = 2.0 + v_1 = load * v_0 + tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None) + +def grid_end_step_kwarg(x: torch.Tensor): + n = x.size(0) + out = torch.zeros_like(x) + _BLOCK_SIZE_0 = 2 + _grid_end_step_kwarg_kernel[triton.cdiv(n, _BLOCK_SIZE_0),](x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return out + +def _grid_end_step_kwarg_make_precompiler(x: torch.Tensor): + n = x.size(0) + out = torch.zeros_like(x) + _BLOCK_SIZE_0 = 2 + from helion.runtime.precompile_shim import make_precompiler + return make_precompiler(_grid_end_step_kwarg_kernel)(x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""", + ) + + def test_grid_multidim_begin_end(self): + @helion.kernel(use_default_config=True) + def grid_multidim_begin_end(x: torch.Tensor) -> torch.Tensor: + m, n = x.size() + out = torch.zeros_like(x) + for i, j in hl.grid( + [1, 1], [m - 1, n - 1] + ): # multidimensional grid(begin, end) + out[i, j] = x[i, j] * 2 + return out + + def grid_multidim_begin_end_pytorch(x: torch.Tensor) -> torch.Tensor: + m, n = x.size() + out = torch.zeros_like(x) + for i in range(1, m - 1): + for j in range(1, n - 1): + out[i, j] = x[i, j] * 2 + return out + + x = torch.randn([8, 8], device=DEVICE, dtype=torch.float32) + code, result = code_and_output(grid_multidim_begin_end, (x,)) + torch.testing.assert_close(result, grid_multidim_begin_end_pytorch(x)) + self.assertExpectedInline( + code, + """\ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +@triton.jit +def _grid_multidim_begin_end_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, m): + num_blocks_0 = -2 + m + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + begin_0 = 1 + offset_0 = begin_0 + pid_0 + begin_1 = 1 + offset_1 = begin_1 + pid_1 + load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None) + v_0 = 2.0 + v_1 = load * v_0 + tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None) + +def grid_multidim_begin_end(x: torch.Tensor): + m, n = x.size() + out = torch.zeros_like(x) + _grid_multidim_begin_end_kernel[(-2 + m) * (-2 + n),](x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), m, num_warps=4, num_stages=3) + return out + +def _grid_multidim_begin_end_make_precompiler(x: torch.Tensor): + m, n = x.size() + out = torch.zeros_like(x) + from helion.runtime.precompile_shim import make_precompiler + return make_precompiler(_grid_multidim_begin_end_kernel)(x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), m, num_warps=4, num_stages=3)""", + ) + + def test_grid_multidim_begin_end_step(self): + @helion.kernel(use_default_config=True) + def grid_multidim_begin_end_step(x: torch.Tensor) -> torch.Tensor: + m, n = x.size() + out = torch.zeros_like(x) + for i, j in hl.grid( + [0, 0], [m, n], [2, 3] + ): # multidimensional grid(begin, end, step) + out[i, j] = x[i, j] * 2 + return out + + def grid_multidim_begin_end_step_pytorch(x: torch.Tensor) -> torch.Tensor: + m, n = x.size() + out = torch.zeros_like(x) + for i in range(0, m, 2): + for j in range(0, n, 3): + out[i, j] = x[i, j] * 2 + return out + + x = torch.randn([8, 9], device=DEVICE, dtype=torch.float32) + code, result = code_and_output(grid_multidim_begin_end_step, (x,)) + torch.testing.assert_close(result, grid_multidim_begin_end_step_pytorch(x)) + self.assertExpectedInline( + code, + """\ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +@triton.jit +def _grid_multidim_begin_end_step_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, m, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + offset_1 = pid_1 * _BLOCK_SIZE_1 + load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None) + v_0 = 2.0 + v_1 = load * v_0 + tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None) + +def grid_multidim_begin_end_step(x: torch.Tensor): + m, n = x.size() + out = torch.zeros_like(x) + _BLOCK_SIZE_0 = 2 + _BLOCK_SIZE_1 = 3 + _grid_multidim_begin_end_step_kernel[triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),](x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), m, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + return out + +def _grid_multidim_begin_end_step_make_precompiler(x: torch.Tensor): + m, n = x.size() + out = torch.zeros_like(x) + _BLOCK_SIZE_0 = 2 + _BLOCK_SIZE_1 = 3 + from helion.runtime.precompile_shim import make_precompiler + return make_precompiler(_grid_multidim_begin_end_step_kernel)(x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), m, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""", + ) + + def test_tile_begin_end(self): + @helion.kernel(use_default_config=True) + def tile_begin_end(x: torch.Tensor) -> torch.Tensor: + out = torch.zeros_like(x) + for tile in hl.tile(2, 10): # tile(begin, end) - simple range [2, 10) + out[tile] = x[tile] * 2 + return out + + def tile_begin_end_pytorch(x: torch.Tensor) -> torch.Tensor: + out = torch.zeros_like(x) + # Tile should process all indices in range [2, 10) in chunks + for i in range(2, 10): + out[i] = x[i] * 2 + return out + + x = torch.randn([15], device=DEVICE, dtype=torch.float32) + code, result = code_and_output(tile_begin_end, (x,), block_size=4) + torch.testing.assert_close(result, tile_begin_end_pytorch(x)) + self.assertExpectedInline( + code, + """\ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +@triton.jit +def _tile_begin_end_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + begin_0 = 2 + offset_0 = begin_0 + pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + load = tl.load(x + indices_0 * x_stride_0, None) + v_0 = 2.0 + v_1 = load * v_0 + tl.store(out + indices_0 * out_stride_0, v_1, None) + +def tile_begin_end(x: torch.Tensor): + out = torch.zeros_like(x) + _BLOCK_SIZE_0 = 4 + _tile_begin_end_kernel[triton.cdiv(8, _BLOCK_SIZE_0),](x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return out + +def _tile_begin_end_make_precompiler(x: torch.Tensor): + out = torch.zeros_like(x) + _BLOCK_SIZE_0 = 4 + from helion.runtime.precompile_shim import make_precompiler + return make_precompiler(_tile_begin_end_kernel)(x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""", + ) + if __name__ == "__main__": unittest.main()