Skip to content

Commit 499bb17

Browse files
committed
Add support for hl.grid(begin, end, step)
stack-info: PR: #211, branch: jansel/stack/62
1 parent 61c7c18 commit 499bb17

File tree

5 files changed

+448
-56
lines changed

5 files changed

+448
-56
lines changed

helion/_compiler/program_id.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def combined_device_cdiv(self, state: CodegenState) -> str:
119119
return " * ".join(pid.device_cdiv(state) for pid in self.pids)
120120

121121
def combined_host_cdiv(self) -> str:
122-
return " * ".join(pid.host_cdiv() for pid in self.pids)
122+
return " * ".join(f"({pid.host_cdiv()})" for pid in self.pids)
123123

124124
def codegen(self, state: CodegenState) -> None:
125125
pid_var = self.shared_pid_var or "tl.program_id(0)"

helion/_compiler/tile_strategy.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -440,13 +440,32 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
440440
pids = self.select_pid_strategy()
441441
if isinstance(state.device_function.pid, SharedProgramID):
442442
pids.shared_pid_var = state.device_function.pid.shared_pid_var
443-
for i, (block_idx, block_size) in enumerate(
444-
reversed(self._reorder([*zip(block_ids, block_sizes, strict=True)]))
443+
444+
assert state.ast_args is None
445+
assert len(state.proxy_args) == 3
446+
if state.proxy_args[1] is None:
447+
begins = [0] * len(block_ids)
448+
else:
449+
begins = state.proxy_args[0]
450+
if not isinstance(begins, (list, tuple)):
451+
begins = [begins]
452+
assert len(begins) == len(block_ids)
453+
454+
for i, (block_idx, block_size, begin) in enumerate(
455+
reversed(self._reorder([*zip(block_ids, block_sizes, begins, strict=True)]))
445456
):
446457
numel = env.block_sizes[block_idx].numel
447458
offset_var = self.offset_var(block_idx)
448459
index_var = self.index_var(block_idx)
449460
pid_var = device_function.new_var(f"pid_{i}", dce=True)
461+
462+
begin_offset_expr = ""
463+
if begin != 0:
464+
begin_ast = self._to_ast(begin, to_dtype=dtype)
465+
begin_offset_expr = (
466+
f"{state.codegen.lift(begin_ast, dce=True, prefix='begin').id} + "
467+
)
468+
450469
if block_size != 1:
451470
block_size_var = self.block_size_var(block_idx)
452471
assert block_size_var is not None
@@ -457,14 +476,16 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
457476
f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}"
458477
)
459478
)
460-
state.add_statement(f"{offset_var} = {pid_var} * {block_size_var}")
479+
state.add_statement(
480+
f"{offset_var} = {begin_offset_expr}{pid_var} * {block_size_var}"
481+
)
461482
state.add_statement(
462483
f"{index_var} = ({offset_var} + tl.arange(0, ({block_size_var}))).to({dtype})"
463484
)
464485
else:
465486
block_size_var = "1"
466487
dtype = env.triton_index_type()
467-
state.add_statement(f"{offset_var} = {pid_var}")
488+
state.add_statement(f"{offset_var} = {begin_offset_expr}{pid_var}")
468489
state.add_statement(
469490
f"{index_var} = {offset_var} + tl.zeros([1], {dtype})"
470491
)
@@ -509,6 +530,8 @@ def _to_ast(self, x: object, to_dtype: str | None = None) -> ast.AST:
509530
from .device_function import DeviceFunction
510531

511532
return expr_from_string(DeviceFunction.current().sympy_expr(x))
533+
if isinstance(x, torch.SymInt):
534+
return self._to_ast(x._sympy_())
512535
raise NotImplementedError(f"{type(x)} is not implemented.")
513536

514537
def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:

helion/_compiler/type_propagation.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -982,29 +982,24 @@ def proxy(self) -> object:
982982

983983
@staticmethod
984984
def allocate(
985-
numel: int | torch.SymInt | AutoSize | None, origin: Origin
986-
) -> TileIndexType:
987-
env = CompileEnvironment.current()
988-
block_id = env.allocate_block_size(numel, source=LoopSpecBlockSizeSource())
989-
env.config_spec.block_sizes.append(
990-
BlockSizeSpec(
991-
block_id=block_id,
992-
size_hint=_get_hint(numel),
993-
)
994-
)
995-
return TileIndexType(origin, block_id)
996-
997-
@staticmethod
998-
def allocate_fixed(
999985
numel: int | torch.SymInt | AutoSize | None,
1000-
block_size: int | torch.SymInt,
1001986
origin: Origin,
987+
block_size: int | torch.SymInt | None = None,
1002988
) -> TileIndexType:
1003989
env = CompileEnvironment.current()
1004-
return TileIndexType(
1005-
origin,
1006-
env.allocate_block_size(numel, source=FixedBlockSizeSource(block_size)),
1007-
)
990+
if block_size is None:
991+
block_id = env.allocate_block_size(numel, source=LoopSpecBlockSizeSource())
992+
env.config_spec.block_sizes.append(
993+
BlockSizeSpec(
994+
block_id=block_id,
995+
size_hint=_get_hint(numel),
996+
)
997+
)
998+
else:
999+
block_id = env.allocate_block_size(
1000+
numel, source=FixedBlockSizeSource(block_size)
1001+
)
1002+
return TileIndexType(origin, block_id)
10081003

10091004
def merge(self, other: TypeInfo) -> TypeInfo:
10101005
if isinstance(other, TileIndexType):
@@ -1024,21 +1019,30 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
10241019
class GridIndexType(SymIntType):
10251020
block_id: int
10261021

1027-
def __init__(self, origin: Origin, sym: torch.SymInt, block_id: int) -> None:
1022+
def __init__(
1023+
self,
1024+
origin: Origin,
1025+
sym: torch.SymInt,
1026+
block_id: int,
1027+
) -> None:
10281028
super().__init__(origin, sym)
10291029
self.block_id = block_id
10301030

10311031
def __str__(self) -> str: # pragma: no cover – debug helper
10321032
return f"{type(self).__name__}({self.block_id})"
10331033

10341034
@staticmethod
1035-
def allocate(numel: int | torch.SymInt, origin: Origin) -> GridIndexType:
1035+
def allocate(
1036+
numel: int | torch.SymInt,
1037+
origin: Origin,
1038+
step: int | torch.SymInt = 1,
1039+
) -> GridIndexType:
10361040
from .._compiler.compile_environment import CompileEnvironment
10371041
from .host_function import HostFunction
10381042
from .host_function import SymbolOrigin
10391043

10401044
env = CompileEnvironment.current()
1041-
block_id = env.allocate_block_size(numel, source=FixedBlockSizeSource(1))
1045+
block_id = env.allocate_block_size(numel, source=FixedBlockSizeSource(step))
10421046
# assign this a new unbacked symbol since this should be treated like a scalar rather than a tile
10431047
sym = env.create_unbacked_symint()
10441048
HostFunction.current().expr_to_origin[sym._sympy_()] = SymbolOrigin(

helion/language/loops.py

Lines changed: 70 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,13 @@ def _(
200200
if bs is None:
201201
results.append(TileIndexType.allocate(size, origin))
202202
elif isinstance(bs, int):
203-
results.append(TileIndexType.allocate_fixed(size, bs, origin))
203+
results.append(TileIndexType.allocate(size, origin, bs))
204204
elif isinstance(bs, torch.SymInt):
205205
from helion._compiler.compile_environment import CompileEnvironment
206206

207207
index = CompileEnvironment.current().get_block_id(bs)
208208
if index is None:
209-
results.append(TileIndexType.allocate_fixed(size, bs, origin))
209+
results.append(TileIndexType.allocate(size, origin, bs))
210210
else:
211211
results.append(TileIndexType(origin=origin, block_id=index))
212212
CompileEnvironment.current().block_sizes[index].mark_alternate_size(
@@ -289,63 +289,104 @@ def _codegen_loop_helper(
289289
@_decorators.api(
290290
is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
291291
)
292-
def grid(sizes: int, /) -> Iterator[torch.SymInt]: ...
292+
def grid(
293+
begin_or_end: int | torch.Tensor,
294+
end_or_none: int | torch.Tensor | None = None,
295+
/,
296+
step: object = None,
297+
) -> Iterator[torch.SymInt]: ...
293298

294299

295300
@overload
296301
@_decorators.api(
297302
is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
298303
)
299-
def grid(sizes: Sequence[int], /) -> Iterator[Sequence[torch.SymInt]]: ...
304+
def grid(
305+
begin_or_end: Sequence[int | torch.Tensor],
306+
end_or_none: Sequence[int | torch.Tensor] | None = None,
307+
/,
308+
step: object = None,
309+
) -> Iterator[Sequence[torch.SymInt]]: ...
300310

301311

302312
@_decorators.api(
303313
is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
304314
)
305315
def grid(
306-
sizes: int | Sequence[int],
316+
begin_or_end: int | torch.Tensor | Sequence[int | torch.Tensor],
317+
end_or_none: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None,
307318
/,
319+
step: object = None,
308320
) -> Iterator[torch.SymInt] | Iterator[Sequence[torch.SymInt]]: # type: ignore[type-arg]
309-
"""Iterate over *individual* indices of the given iteration space.
321+
"""Iterate over individual indices of the given iteration space.
310322
311323
Semantics are equivalent to
312324
313-
for i in hl.tile(size, block_size=1):
325+
for i in hl.tile(...):
314326
...
315327
316328
but `i` will be a scalar (`torch.SymInt`), not a 1-element tensor.
317-
"""
318329
330+
When used at the top level of a function, this becomes the grid of the kernel.
331+
Otherwise, it becomes a loop in the output kernel.
332+
333+
Similar to `range()` there are multiple forms of this function:
334+
grid(end) iterates from 0 to `end - 1`, with step size 1.
335+
grid(begin, end) iterates from `begin` to `end - 1`, with step size 1.
336+
grid(begin, end, step) iterates from `begin` up to (but not including) `end`, advancing by the given step size.
337+
grid(end, step=step) iterates from 0 up to (but not including) `end`, advancing by the given step size.
338+
"""
319339
raise exc.NotInsideKernel
320340

321341

322342
@_decorators.type_propagation(grid)
323-
def _(sizes: TypeInfo, *, origin: Origin) -> TypeInfo:
343+
def _(
344+
begin_or_end: TypeInfo,
345+
end_or_none: TypeInfo | None = None,
346+
/,
347+
step: TypeInfo | None = None,
348+
*,
349+
origin: Origin,
350+
) -> TypeInfo:
324351
parent = ExtendedAST.current()[-2]
325352
if not isinstance(parent, ast.For):
326353
raise exc.LoopFunctionNotInFor("grid")
327-
try:
328-
proxy_sizes = sizes.proxy()
329-
if not (
330-
isinstance(proxy_sizes, (int, torch.SymInt))
331-
or (
332-
isinstance(proxy_sizes, (list, tuple))
333-
and all(isinstance(x, (int, torch.SymInt)) for x in proxy_sizes)
334-
)
335-
):
336-
raise NotImplementedError
337-
except NotImplementedError:
338-
raise exc.TypeInferenceError(
339-
f"grid() expected int or list[int], got {sizes!s}"
340-
) from None
354+
begin, end = _normalize_begin_end(begin_or_end, end_or_none, origin=origin)
355+
proxy_begin = _to_proxy(begin)
356+
proxy_end = _to_proxy(end)
357+
_check_matching(proxy_begin, proxy_end)
358+
if _not_none(step):
359+
proxy_step = Tile._tiles_to_sizes(_to_proxy(step))
360+
_check_matching(proxy_end, proxy_step)
361+
else:
362+
proxy_step = begin.tree_map(lambda n: None)
341363

342-
if isinstance(proxy_sizes, (int, torch.SymInt)):
343-
return IterType(origin, GridIndexType.allocate(proxy_sizes, origin))
364+
if unpack := not isinstance(proxy_end, (list, tuple)):
365+
proxy_begin = [proxy_begin]
366+
proxy_end = [proxy_end]
367+
proxy_step = [proxy_step]
368+
369+
results = []
370+
for begin_part, end_part, step_part in zip(
371+
proxy_begin, proxy_end, proxy_step, strict=True
372+
):
373+
size = end_part - begin_part
374+
if isinstance(size, torch.Tensor):
375+
size = None # data dependent size
376+
if step_part is None:
377+
step_part = 1
378+
results.append(GridIndexType.allocate(size, origin, step_part))
344379

345-
assert isinstance(proxy_sizes, (list, tuple))
346-
elements = [GridIndexType.allocate(s, origin) for s in proxy_sizes]
347-
_add_config_choices([x.block_id for x in elements])
348-
return IterType(origin, SequenceType(origin, elements))
380+
_add_config_choices(
381+
[x.block_id for x in results],
382+
is_tile=False,
383+
has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin),
384+
)
385+
if unpack:
386+
(result,) = results
387+
else:
388+
result = SequenceType(origin, results)
389+
return IterType(origin, result)
349390

350391

351392
@_decorators.codegen(grid)

0 commit comments

Comments
 (0)