diff --git a/helion/language/__init__.py b/helion/language/__init__.py
index c23d0593..56dcb8da 100644
--- a/helion/language/__init__.py
+++ b/helion/language/__init__.py
@@ -13,6 +13,7 @@
 from .tile_ops import tile_begin as tile_begin
 from .tile_ops import tile_block_size as tile_block_size
 from .tile_ops import tile_end as tile_end
+from .tile_ops import tile_id as tile_id
 from .tile_ops import tile_index as tile_index
 from .tile_proxy import Tile as Tile
 from .tunable_ops import register_block_size as register_block_size
diff --git a/helion/language/tile_ops.py b/helion/language/tile_ops.py
index 9da4b0a7..0ee02dc8 100644
--- a/helion/language/tile_ops.py
+++ b/helion/language/tile_ops.py
@@ -129,3 +129,39 @@ def _(tile: torch.SymInt) -> torch.SymInt:
     # since we return tile above, no codegen is needed for this function.
     # codegen is handled in _get_symnode()
+
+
+@_decorators.api(tiles_as_sizes=True)
+def tile_id(tile: Tile) -> int:
+    """
+    Retrieve tile_id of a given tile or list of tiles.
+    This is equivalent to `tile.begin // tile.block_size`.
+    This can also be written as: `tile.id`.
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.register_fake(tile_id)
+def _(tile: torch.SymInt) -> torch.SymInt:
+    assert isinstance(tile, torch.SymInt)
+    return CompileEnvironment.current().create_unbacked_symint()
+
+
+@_decorators.codegen(tile_id)
+def _(state: CodegenState) -> ast.AST:
+    t = state.proxy_arg(0)
+    env = CompileEnvironment.current()
+    assert isinstance(t, torch.SymInt)
+    index = env.get_block_id(t)
+    assert index is not None
+    # disable_flatten:
+    # The functions in this file can't be used in flattened loops.
+    env.config_spec.flatten_loops.disable_block_id(index)
+    offset = state.codegen.offset_var(index)
+
+    block_size = state.device_function.block_size_var(index)
+    if block_size is None:
+        expr_str = offset
+    else:
+        expr_str = f"{offset} // {block_size}"
+    return expr_from_string(expr_str)
diff --git a/helion/language/tile_proxy.py b/helion/language/tile_proxy.py
index d3221c47..1dc69aff 100644
--- a/helion/language/tile_proxy.py
+++ b/helion/language/tile_proxy.py
@@ -34,7 +34,7 @@ class Tile(torch.Tensor):
     Tile's can be used as indices to tensors, e.g. `tensor[tile]`.
     Tile's can also be use as sizes for allocations, e.g. `torch.empty([tile])`.
     There are also properties such as `tile.index`, `tile.begin`,
-    `tile.end`, and `tile.block_size` that can be used to retrieve various
+    `tile.end`, `tile.id` and `tile.block_size` that can be used to retrieve various
     information about the tile.
 
     Masking is implicit for tiles, so if the final tile is smaller than
@@ -133,6 +133,15 @@ def block_size(self) -> int:
         return tile_block_size(self)
+
+    @property
+    def id(self) -> int:
+        """
+        Alias for hl.tile_id, which retrieves the id of a tile.
+        """
+        from .tile_ops import tile_id
+
+        return tile_id(self)
+
 
 class _CheckForIndexCalls:
     """
diff --git a/test/test_indexing.py b/test/test_indexing.py
index e1f385df..1963627b 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -294,6 +294,89 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         )
         torch.testing.assert_close(result, expected)
 
+    def test_tile_id(self):
+        @helion.kernel
+        def test_tile_id_access(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x, dtype=torch.int32)
+            for tile in hl.tile(x.size(0)):
+                out[tile] = tile.id
+            return out
+
+        x = torch.randn([64], device=DEVICE)
+        code, result = code_and_output(
+            test_tile_id_access,
+            (x,),
+            block_size=16,
+        )
+        expected = torch.arange(4, device=DEVICE, dtype=torch.int32).repeat_interleave(
+            repeats=16
+        )
+        torch.testing.assert_close(result, expected)
+        code, result = code_and_output(
+            test_tile_id_access,
+            (x,),
+            block_size=1,
+        )
+        expected = torch.arange(64, device=DEVICE, dtype=torch.int32)
+        torch.testing.assert_close(result, expected)
+
+    def test_tile_id_1d_indexing(self):
+        @helion.kernel
+        def test_tile_id_atomic_add(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x, dtype=torch.int32)
+            for tile_m in hl.tile(x.size(0)):
+                hl.atomic_add(out, [tile_m.id], 1)
+            return out
+
+        x = torch.randn(64, device=DEVICE)
+        code, result = code_and_output(
+            test_tile_id_atomic_add,
+            (x,),
+            block_size=[
+                16,
+            ],
+        )
+
+        expected = torch.zeros(64, device=DEVICE, dtype=torch.int32)
+        expected[:4] = 1
+        torch.testing.assert_close(result, expected)
+        code, result = code_and_output(
+            test_tile_id_atomic_add,
+            (x,),
+            block_size=[
+                1,
+            ],
+        )
+        expected = torch.ones(64, device=DEVICE, dtype=torch.int32)
+        torch.testing.assert_close(result, expected)
+
+    @unittest.skip("flatten_loops config assert. issue#185")
+    def test_tile_id_2d_indexing(self):
+        @helion.kernel
+        def test_tile_id_index_st(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x, dtype=torch.int32)
+            for tile_m, tile_n in hl.tile(x.size()):
+                out[tile_m.id, tile_n.id] = 1
+            return out
+
+        x = torch.randn(64, 64, device=DEVICE)
+        code, result = code_and_output(
+            test_tile_id_index_st,
+            (x,),
+            block_size=[16, 16],
+        )
+
+        expected = torch.zeros(64, 64, device=DEVICE, dtype=torch.int32)
+        expected[:4, :4] = 1
+        torch.testing.assert_close(result, expected)
+        code, result = code_and_output(
+            test_tile_id_index_st,
+            (x,),
+            block_size=[1, 1],
+        )
+        expected = torch.ones(64, 64, device=DEVICE, dtype=torch.int32)
+        torch.testing.assert_close(result, expected)
+
     def test_atomic_add_symint(self):
         @helion.kernel(config={"block_size": 32})
         def fn(x: torch.Tensor) -> torch.Tensor: