Commit feb86dc

Expose tile.id (#188)
1 parent 2736ff4 commit feb86dc

File tree (4 files changed: +130 -1 lines):

helion/language/__init__.py
helion/language/tile_ops.py
helion/language/tile_proxy.py
test/test_indexing.py

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@
 from .tile_ops import tile_begin as tile_begin
 from .tile_ops import tile_block_size as tile_block_size
 from .tile_ops import tile_end as tile_end
+from .tile_ops import tile_id as tile_id
 from .tile_ops import tile_index as tile_index
 from .tile_proxy import Tile as Tile
 from .tunable_ops import register_block_size as register_block_size

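A quick note on what this one-line re-export enables (illustrative only, assuming the `import helion.language as hl` convention used in the tests below):

    import helion.language as hl

    # tile_id is now reachable from the public namespace, alongside the
    # other tile accessors re-exported here: hl.tile_begin, hl.tile_end,
    # hl.tile_block_size, and hl.tile_index.
    print(hl.tile_id)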
helion/language/tile_ops.py

Lines changed: 36 additions & 0 deletions

@@ -129,3 +129,39 @@ def _(tile: torch.SymInt) -> torch.SymInt:
 
 # since we return tile above, no codegen is needed for this function.
 # codegen is handled in _get_symnode()
+
+
+@_decorators.api(tiles_as_sizes=True)
+def tile_id(tile: Tile) -> int:
+    """
+    Retrieve tile_id of a given tile or list of tiles.
+    This is equivalent to `tile.begin // tile.block_size`.
+    This can also be written as: `tile.id`.
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.register_fake(tile_id)
+def _(tile: torch.SymInt) -> torch.SymInt:
+    assert isinstance(tile, torch.SymInt)
+    return CompileEnvironment.current().create_unbacked_symint()
+
+
+@_decorators.codegen(tile_id)
+def _(state: CodegenState) -> ast.AST:
+    t = state.proxy_arg(0)
+    env = CompileEnvironment.current()
+    assert isinstance(t, torch.SymInt)
+    index = env.get_block_id(t)
+    assert index is not None
+    # disable_flatten:
+    # The functions in this file can't be used in flattened loops.
+    env.config_spec.flatten_loops.disable_block_id(index)
+    offset = state.codegen.offset_var(index)
+
+    block_size = state.device_function.block_size_var(index)
+    if block_size is None:
+        expr_str = offset
+    else:
+        expr_str = f"{offset} // {block_size}"
+    return expr_from_string(expr_str)

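The identity in `tile_id`'s docstring, `tile.begin // tile.block_size`, is exactly what the codegen above emits: the per-block offset variable floor-divided by the block-size variable, or the offset alone when no block-size variable exists (block size 1, where `offset // 1 == offset`). A minimal sketch of that arithmetic in plain Python, with illustrative names only:

    # For a 1-D loop over n elements covered by tiles of size b,
    # tile k begins at element k * b, so begin // b recovers k.
    def expected_tile_id(begin: int, block_size: int) -> int:
        return begin // block_size

    n, b = 64, 16
    for k in range(n // b):  # 4 tiles: k = 0, 1, 2, 3
        assert expected_tile_id(k * b, b) == k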
helion/language/tile_proxy.py

Lines changed: 10 additions & 1 deletion

@@ -34,7 +34,7 @@ class Tile(torch.Tensor):
     Tile's can be used as indices to tensors, e.g. `tensor[tile]`. Tile's
     can also be use as sizes for allocations, e.g. `torch.empty([tile])`.
     There are also properties such as `tile.index`, `tile.begin`,
-    `tile.end`, and `tile.block_size` that can be used to retrieve various
+    `tile.end`, `tile.id` and `tile.block_size` that can be used to retrieve various
     information about the tile.
 
     Masking is implicit for tiles, so if the final tile is smaller than
@@ -133,6 +133,15 @@ def block_size(self) -> int:
 
         return tile_block_size(self)
 
+    @property
+    def id(self) -> int:
+        """
+        Alias for hl.tile_id, which retrieves the id of a tile.
+        """
+        from .tile_ops import tile_id
+
+        return tile_id(self)
+
 
 class _CheckForIndexCalls:
     """

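With the property in place, `tile.id` reads naturally inside a kernel body and resolves to the same `hl.tile_id` operation. A minimal usage sketch modeled on the new tests (assumes the standard `import helion` / `import helion.language as hl` setup and a device visible to Helion):

    import torch
    import helion
    import helion.language as hl

    @helion.kernel
    def label_tiles(x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros_like(x, dtype=torch.int32)
        for tile in hl.tile(x.size(0)):
            out[tile] = tile.id  # equivalent to hl.tile_id(tile)
        return out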
test/test_indexing.py

Lines changed: 83 additions & 0 deletions

@@ -294,6 +294,89 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         )
         torch.testing.assert_close(result, expected)
 
+    def test_tile_id(self):
+        @helion.kernel
+        def test_tile_id_access(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x, dtype=torch.int32)
+            for tile in hl.tile(x.size(0)):
+                out[tile] = tile.id
+            return out
+
+        x = torch.randn([64], device=DEVICE)
+        code, result = code_and_output(
+            test_tile_id_access,
+            (x,),
+            block_size=16,
+        )
+        expected = torch.arange(4, device=DEVICE, dtype=torch.int32).repeat_interleave(
+            repeats=16
+        )
+        torch.testing.assert_close(result, expected)
+        code, result = code_and_output(
+            test_tile_id_access,
+            (x,),
+            block_size=1,
+        )
+        expected = torch.arange(64, device=DEVICE, dtype=torch.int32)
+        torch.testing.assert_close(result, expected)
+
+    def test_tile_id_1d_indexing(self):
+        @helion.kernel
+        def test_tile_id_atomic_add(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x, dtype=torch.int32)
+            for tile_m in hl.tile(x.size(0)):
+                hl.atomic_add(out, [tile_m.id], 1)
+            return out
+
+        x = torch.randn(64, device=DEVICE)
+        code, result = code_and_output(
+            test_tile_id_atomic_add,
+            (x,),
+            block_size=[
+                16,
+            ],
+        )
+
+        expected = torch.zeros(64, device=DEVICE, dtype=torch.int32)
+        expected[:4] = 1
+        torch.testing.assert_close(result, expected)
+        code, result = code_and_output(
+            test_tile_id_atomic_add,
+            (x,),
+            block_size=[
+                1,
+            ],
+        )
+        expected = torch.ones(64, device=DEVICE, dtype=torch.int32)
+        torch.testing.assert_close(result, expected)
+
+    @unittest.skip("flatten_loops config assert. issue#185")
+    def test_tile_id_2d_indexing(self):
+        @helion.kernel
+        def test_tile_id_index_st(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x, dtype=torch.int32)
+            for tile_m, tile_n in hl.tile(x.size()):
+                out[tile_m.id, tile_n.id] = 1
+            return out
+
+        x = torch.randn(64, 64, device=DEVICE)
+        code, result = code_and_output(
+            test_tile_id_index_st,
+            (x,),
+            block_size=[16, 16],
+        )
+
+        expected = torch.zeros(64, 64, device=DEVICE, dtype=torch.int32)
+        expected[:4, :4] = 1
+        torch.testing.assert_close(result, expected)
+        code, result = code_and_output(
+            test_tile_id_index_st,
+            (x,),
+            block_size=[1, 1],
+        )
+        expected = torch.ones(64, 64, device=DEVICE, dtype=torch.int32)
+        torch.testing.assert_close(result, expected)
+
     def test_atomic_add_symint(self):
         @helion.kernel(config={"block_size": 32})
         def fn(x: torch.Tensor) -> torch.Tensor:

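The expected values in these tests follow directly from the tile arithmetic: 64 elements at block size 16 produce 4 tiles, so `tile.id` ranges over 0..3 with each id covering 16 consecutive elements, and the atomic-add kernel increments only `out[0:4]`; at block size 1 every element is its own tile. The first expectation can be reproduced in plain PyTorch (illustrative check only):

    import torch

    expected = torch.arange(4, dtype=torch.int32).repeat_interleave(repeats=16)
    assert expected.tolist() == [k for k in range(4) for _ in range(16)]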