Expose tile.id #188

Merged · 1 commit · Jun 17, 2025
1 change: 1 addition & 0 deletions helion/language/__init__.py
@@ -13,6 +13,7 @@
from .tile_ops import tile_begin as tile_begin
from .tile_ops import tile_block_size as tile_block_size
from .tile_ops import tile_end as tile_end
from .tile_ops import tile_id as tile_id
from .tile_ops import tile_index as tile_index
from .tile_proxy import Tile as Tile
from .tunable_ops import register_block_size as register_block_size
36 changes: 36 additions & 0 deletions helion/language/tile_ops.py
@@ -129,3 +129,39 @@ def _(tile: torch.SymInt) -> torch.SymInt:

# since we return tile above, no codegen is needed for this function.
# codegen is handled in _get_symnode()


@_decorators.api(tiles_as_sizes=True)
def tile_id(tile: Tile) -> int:
    """
    Retrieve tile_id of a given tile or list of tiles.
    This is equivalent to `tile.begin // tile.block_size`.
    This can also be written as: `tile.id`.
    """
    raise exc.NotInsideKernel


@_decorators.register_fake(tile_id)
def _(tile: torch.SymInt) -> torch.SymInt:
    assert isinstance(tile, torch.SymInt)
    return CompileEnvironment.current().create_unbacked_symint()


@_decorators.codegen(tile_id)
def _(state: CodegenState) -> ast.AST:
    t = state.proxy_arg(0)
    env = CompileEnvironment.current()
    assert isinstance(t, torch.SymInt)
    index = env.get_block_id(t)
    assert index is not None
    # disable_flatten:
    # The functions in this file can't be used in flattened loops.
    env.config_spec.flatten_loops.disable_block_id(index)
    offset = state.codegen.offset_var(index)

    block_size = state.device_function.block_size_var(index)
    if block_size is None:
        expr_str = offset
    else:
        expr_str = f"{offset} // {block_size}"
    return expr_from_string(expr_str)
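
For orientation (not part of the diff): a minimal usage sketch of the new API, modeled on the tests added later in this PR. The `import helion.language as hl` alias and the 64-element / block-size-16 numbers are assumptions taken from that test file.

import torch
import helion
import helion.language as hl


@helion.kernel
def tile_ids(x: torch.Tensor) -> torch.Tensor:
    # One output element per input element, filled with the id of the
    # tile that covered it.
    out = torch.zeros_like(x, dtype=torch.int32)
    for tile in hl.tile(x.size(0)):
        # tile.id is sugar for hl.tile_id(tile), i.e. tile.begin // tile.block_size
        out[tile] = tile.id
    return out


# With 64 elements and block_size=16 this is expected to produce
# [0]*16 + [1]*16 + [2]*16 + [3]*16.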
11 changes: 10 additions & 1 deletion helion/language/tile_proxy.py
@@ -34,7 +34,7 @@ class Tile(torch.Tensor):
    Tiles can be used as indices to tensors, e.g. `tensor[tile]`. Tiles
    can also be used as sizes for allocations, e.g. `torch.empty([tile])`.
    There are also properties such as `tile.index`, `tile.begin`,
    `tile.end`, `tile.id`, and `tile.block_size` that can be used to retrieve various
    information about the tile.

    Masking is implicit for tiles, so if the final tile is smaller than
@@ -133,6 +133,15 @@ def block_size(self) -> int:

        return tile_block_size(self)

    @property
    def id(self) -> int:
        """
        Alias for hl.tile_id, which retrieves the id of a tile.
        """
        from .tile_ops import tile_id

        return tile_id(self)


class _CheckForIndexCalls:
    """
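To make the `tile.id` == `tile.begin // tile.block_size` relationship concrete, a plain-Python illustration with hypothetical values (64 elements tiled with block size 16, as in the tests below); this is not Helion code, just the arithmetic:

n, block_size = 64, 16
for ordinal in range(n // block_size):   # 4 tiles
    begin = ordinal * block_size         # tile.begin for this tile
    # tile.id == tile.begin // tile.block_size == the tile's ordinal position
    assert begin // block_size == ordinal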
83 changes: 83 additions & 0 deletions test/test_indexing.py
@@ -294,6 +294,89 @@ def fn(x: torch.Tensor) -> torch.Tensor:
        )
        torch.testing.assert_close(result, expected)

    def test_tile_id(self):
        @helion.kernel
        def test_tile_id_access(x: torch.Tensor) -> torch.Tensor:
            out = torch.zeros_like(x, dtype=torch.int32)
            for tile in hl.tile(x.size(0)):
                out[tile] = tile.id
            return out

        x = torch.randn([64], device=DEVICE)
        code, result = code_and_output(
            test_tile_id_access,
            (x,),
            block_size=16,
        )
        expected = torch.arange(4, device=DEVICE, dtype=torch.int32).repeat_interleave(
            repeats=16
        )
        torch.testing.assert_close(result, expected)
        code, result = code_and_output(
            test_tile_id_access,
            (x,),
            block_size=1,
        )
        expected = torch.arange(64, device=DEVICE, dtype=torch.int32)
        torch.testing.assert_close(result, expected)

    def test_tile_id_1d_indexing(self):
        @helion.kernel
        def test_tile_id_atomic_add(x: torch.Tensor) -> torch.Tensor:
            out = torch.zeros_like(x, dtype=torch.int32)
            for tile_m in hl.tile(x.size(0)):
                hl.atomic_add(out, [tile_m.id], 1)
            return out

        x = torch.randn(64, device=DEVICE)
        code, result = code_and_output(
            test_tile_id_atomic_add,
            (x,),
            block_size=[
                16,
            ],
        )

        expected = torch.zeros(64, device=DEVICE, dtype=torch.int32)
        expected[:4] = 1
        torch.testing.assert_close(result, expected)
        code, result = code_and_output(
            test_tile_id_atomic_add,
            (x,),
            block_size=[
                1,
            ],
        )
        expected = torch.ones(64, device=DEVICE, dtype=torch.int32)
        torch.testing.assert_close(result, expected)

    @unittest.skip("flatten_loops config assert. issue#185")
    def test_tile_id_2d_indexing(self):
        @helion.kernel
        def test_tile_id_index_st(x: torch.Tensor) -> torch.Tensor:
            out = torch.zeros_like(x, dtype=torch.int32)
            for tile_m, tile_n in hl.tile(x.size()):
                out[tile_m.id, tile_n.id] = 1
            return out

        x = torch.randn(64, 64, device=DEVICE)
        code, result = code_and_output(
            test_tile_id_index_st,
            (x,),
            block_size=[16, 16],
        )

        expected = torch.zeros(64, 64, device=DEVICE, dtype=torch.int32)
        expected[:4, :4] = 1
        torch.testing.assert_close(result, expected)
        code, result = code_and_output(
            test_tile_id_index_st,
            (x,),
            block_size=[1, 1],
        )
        expected = torch.ones(64, 64, device=DEVICE, dtype=torch.int32)
        torch.testing.assert_close(result, expected)

    def test_atomic_add_symint(self):
        @helion.kernel(config={"block_size": 32})
        def fn(x: torch.Tensor) -> torch.Tensor: