Skip to content

Commit 030df03

Browse files
committed
Add support for hl.grid(begin, end, step)
stack-info: PR: #211, branch: jansel/stack/62
1 parent 095235b commit 030df03

File tree

5 files changed

+498
-107
lines changed

5 files changed

+498
-107
lines changed

helion/_compiler/program_id.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def combined_device_cdiv(self, state: CodegenState) -> str:
119119
return " * ".join(pid.device_cdiv(state) for pid in self.pids)
120120

121121
def combined_host_cdiv(self) -> str:
122-
return " * ".join(pid.host_cdiv() for pid in self.pids)
122+
return " * ".join(f"({pid.host_cdiv()})" for pid in self.pids)
123123

124124
def codegen(self, state: CodegenState) -> None:
125125
pid_var = self.shared_pid_var or "tl.program_id(0)"

helion/_compiler/tile_strategy.py

Lines changed: 77 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,45 @@ def codegen_preamble(self, state: CodegenState) -> None:
138138
def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]:
139139
raise NotImplementedError
140140

141+
def _create_block_id_info_dict(
142+
self, state: CodegenState, use_proxy_ends: bool = False
143+
) -> dict[int, LoopDimInfo]:
144+
"""Helper to create block_id_to_info dictionary with end bounds."""
145+
env = CompileEnvironment.current()
146+
block_id_to_info = {}
147+
148+
if use_proxy_ends:
149+
_, _, proxy_ends, _ = state.proxy_args
150+
assert isinstance(proxy_ends, list)
151+
for block_idx, end in zip(self.block_ids, proxy_ends, strict=True):
152+
if isinstance(end, (int, torch.SymInt)):
153+
end_expr = _to_sympy(end)
154+
else:
155+
end_expr = None
156+
block_id_to_info[block_idx] = LoopDimInfo(
157+
end_var_name=None, end_expr=end_expr
158+
)
159+
else:
160+
for block_id in self.block_ids:
161+
end_expr = env.block_sizes[block_id].numel
162+
end_var_name = state.sympy_expr(end_expr)
163+
block_id_to_info[block_id] = LoopDimInfo(
164+
end_var_name=end_var_name, end_expr=end_expr
165+
)
166+
167+
return block_id_to_info
168+
169+
def _setup_block_size_constexpr(
170+
self, state: CodegenState, block_size_var: str, block_size: SymIntLike
171+
) -> None:
172+
"""Helper to setup constexpr block size variable on host."""
173+
if state.device_function.constexpr_arg(block_size_var):
174+
state.codegen.host_statements.append(
175+
statement_from_string(
176+
f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}"
177+
)
178+
)
179+
141180

142181
class BlockSizeTileStrategy(TileStrategy):
143182
def __init__(
@@ -265,19 +304,15 @@ def block_size_var(self, block_idx: int) -> str:
265304
def _codegen_common(
266305
self, state: CodegenState
267306
) -> tuple[str, str, sympy.Expr, list[ast.AST]]:
307+
offsets_var = self.new_var("offsets", dce=True)
308+
block_size_var = self.block_size_var(-1)
309+
self._setup_block_size_constexpr(state, block_size_var, self.block_size)
268310
block_ids = self.block_ids
269311
env = CompileEnvironment.current()
270312
total_numel = sympy.S.One
271-
offsets_var = self.new_var("offsets", dce=True)
272-
block_size_var = self.block_size_var(-1)
273313
statements = []
274-
if state.device_function.constexpr_arg(block_size_var):
275-
block_size_str = HostFunction.current().literal_expr(self.block_size)
276-
state.codegen.host_statements.append(
277-
statement_from_string(f"{block_size_var} = {block_size_str}")
278-
)
314+
279315
for i, block_idx in enumerate(self._reorder(block_ids)):
280-
# need to get the block size
281316
numel = env.block_sizes[block_idx].numel
282317
block_index_var = self.index_var(block_idx)
283318
expr = offsets_var
@@ -316,13 +351,7 @@ def codegen_grid(self) -> ast.AST:
316351

317352
state.device_function.set_pid(TmpPid())
318353

319-
block_id_to_info = {}
320-
for block_id in self.block_ids:
321-
end_expr = env.block_sizes[block_id].numel
322-
end_var_name = state.sympy_expr(end_expr)
323-
block_id_to_info[block_id] = LoopDimInfo(
324-
end_var_name=end_var_name, end_expr=end_expr
325-
)
354+
block_id_to_info = self._create_block_id_info_dict(state)
326355
return DeviceGridState(self, block_id_to_info=block_id_to_info)
327356

328357
def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
@@ -348,18 +377,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
348377
orelse=[],
349378
type_comment=None,
350379
)
351-
# Create block_id_to_info with end bounds
352-
block_id_to_info = {}
353-
_, _, ends, _ = state.proxy_args
354-
assert isinstance(ends, list)
355-
for block_idx, end in zip(self.block_ids, ends, strict=True):
356-
if isinstance(end, (int, torch.SymInt)):
357-
end_expr = _to_sympy(end)
358-
else:
359-
end_expr = None
360-
block_id_to_info[block_idx] = LoopDimInfo(
361-
end_var_name=None, end_expr=end_expr
362-
)
380+
block_id_to_info = self._create_block_id_info_dict(state, use_proxy_ends=True)
363381

364382
return DeviceLoopState(
365383
self,
@@ -430,8 +448,6 @@ def __init__(
430448
def codegen_grid(self, state: CodegenState) -> DeviceGridState:
431449
block_ids = self.block_ids
432450
env = CompileEnvironment.current()
433-
device_function = state.device_function
434-
dtype = env.triton_index_type()
435451
block_sizes = self.block_size
436452
assert len(block_sizes) == len(block_ids)
437453
if isinstance(state.device_function.pid, SharedProgramID):
@@ -440,31 +456,47 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
440456
pids = self.select_pid_strategy()
441457
if isinstance(state.device_function.pid, SharedProgramID):
442458
pids.shared_pid_var = state.device_function.pid.shared_pid_var
443-
for i, (block_idx, block_size) in enumerate(
444-
reversed(self._reorder([*zip(block_ids, block_sizes, strict=True)]))
459+
460+
assert state.ast_args is None
461+
assert len(state.proxy_args) == 3
462+
if state.proxy_args[1] is None:
463+
begins = [0] * len(block_ids)
464+
else:
465+
begins = state.proxy_args[0]
466+
if not isinstance(begins, (list, tuple)):
467+
begins = [begins]
468+
assert len(begins) == len(block_ids)
469+
470+
for i, (block_idx, block_size, begin) in enumerate(
471+
reversed(self._reorder([*zip(block_ids, block_sizes, begins, strict=True)]))
445472
):
446473
numel = env.block_sizes[block_idx].numel
474+
device_function = state.device_function
475+
dtype = env.triton_index_type()
447476
offset_var = self.offset_var(block_idx)
448477
index_var = self.index_var(block_idx)
449478
pid_var = device_function.new_var(f"pid_{i}", dce=True)
479+
480+
begin_offset_expr = ""
481+
if begin != 0:
482+
begin_ast = self._to_ast(begin, to_dtype=dtype)
483+
begin_offset_expr = (
484+
f"{state.codegen.lift(begin_ast, dce=True, prefix='begin').id} + "
485+
)
486+
450487
if block_size != 1:
451488
block_size_var = self.block_size_var(block_idx)
452489
assert block_size_var is not None
453-
# TODO(jansel): need to check for conflict with user variable names since block_size_var is on host
454-
if state.device_function.constexpr_arg(block_size_var):
455-
state.codegen.host_statements.append(
456-
statement_from_string(
457-
f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}"
458-
)
459-
)
460-
state.add_statement(f"{offset_var} = {pid_var} * {block_size_var}")
490+
self._setup_block_size_constexpr(state, block_size_var, block_size)
491+
state.add_statement(
492+
f"{offset_var} = {begin_offset_expr}{pid_var} * {block_size_var}"
493+
)
461494
state.add_statement(
462495
f"{index_var} = ({offset_var} + tl.arange(0, ({block_size_var}))).to({dtype})"
463496
)
464497
else:
465498
block_size_var = "1"
466-
dtype = env.triton_index_type()
467-
state.add_statement(f"{offset_var} = {pid_var}")
499+
state.add_statement(f"{offset_var} = {begin_offset_expr}{pid_var}")
468500
state.add_statement(
469501
f"{index_var} = {offset_var} + tl.zeros([1], {dtype})"
470502
)
@@ -483,14 +515,7 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
483515
else:
484516
state.device_function.set_pid(pids)
485517

486-
# Extract end_var_name from end bound expressions
487-
block_id_to_info = {}
488-
for block_id in self.block_ids:
489-
end_expr = env.block_sizes[block_id].numel
490-
end_var_name = state.sympy_expr(end_expr)
491-
block_id_to_info[block_id] = LoopDimInfo(
492-
end_var_name=end_var_name, end_expr=end_expr
493-
)
518+
block_id_to_info = self._create_block_id_info_dict(state)
494519
return DeviceGridState(self, block_id_to_info=block_id_to_info)
495520

496521
def select_pid_strategy(self) -> ProgramIDs:
@@ -509,6 +534,8 @@ def _to_ast(self, x: object, to_dtype: str | None = None) -> ast.AST:
509534
from .device_function import DeviceFunction
510535

511536
return expr_from_string(DeviceFunction.current().sympy_expr(x))
537+
if isinstance(x, torch.SymInt):
538+
return self._to_ast(x._sympy_())
512539
raise NotImplementedError(f"{type(x)} is not implemented.")
513540

514541
def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
@@ -534,12 +561,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
534561
if block_size != 1:
535562
block_size_var = self.block_size_var(block_idx)
536563
assert block_size_var is not None
537-
if state.device_function.constexpr_arg(block_size_var):
538-
state.codegen.host_statements.append(
539-
statement_from_string(
540-
f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}"
541-
)
542-
)
564+
self._setup_block_size_constexpr(state, block_size_var, block_size)
543565
else:
544566
block_size_var = "1"
545567
end_var_name = state.codegen.lift(

helion/_compiler/type_propagation.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -982,29 +982,24 @@ def proxy(self) -> object:
982982

983983
@staticmethod
984984
def allocate(
985-
numel: int | torch.SymInt | AutoSize | None, origin: Origin
986-
) -> TileIndexType:
987-
env = CompileEnvironment.current()
988-
block_id = env.allocate_block_size(numel, source=LoopSpecBlockSizeSource())
989-
env.config_spec.block_sizes.append(
990-
BlockSizeSpec(
991-
block_id=block_id,
992-
size_hint=_get_hint(numel),
993-
)
994-
)
995-
return TileIndexType(origin, block_id)
996-
997-
@staticmethod
998-
def allocate_fixed(
999985
numel: int | torch.SymInt | AutoSize | None,
1000-
block_size: int | torch.SymInt,
1001986
origin: Origin,
987+
block_size: int | torch.SymInt | None = None,
1002988
) -> TileIndexType:
1003989
env = CompileEnvironment.current()
1004-
return TileIndexType(
1005-
origin,
1006-
env.allocate_block_size(numel, source=FixedBlockSizeSource(block_size)),
1007-
)
990+
if block_size is None:
991+
block_id = env.allocate_block_size(numel, source=LoopSpecBlockSizeSource())
992+
env.config_spec.block_sizes.append(
993+
BlockSizeSpec(
994+
block_id=block_id,
995+
size_hint=_get_hint(numel),
996+
)
997+
)
998+
else:
999+
block_id = env.allocate_block_size(
1000+
numel, source=FixedBlockSizeSource(block_size)
1001+
)
1002+
return TileIndexType(origin, block_id)
10081003

10091004
def merge(self, other: TypeInfo) -> TypeInfo:
10101005
if isinstance(other, TileIndexType):
@@ -1024,21 +1019,30 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
10241019
class GridIndexType(SymIntType):
10251020
block_id: int
10261021

1027-
def __init__(self, origin: Origin, sym: torch.SymInt, block_id: int) -> None:
1022+
def __init__(
1023+
self,
1024+
origin: Origin,
1025+
sym: torch.SymInt,
1026+
block_id: int,
1027+
) -> None:
10281028
super().__init__(origin, sym)
10291029
self.block_id = block_id
10301030

10311031
def __str__(self) -> str: # pragma: no cover – debug helper
10321032
return f"{type(self).__name__}({self.block_id})"
10331033

10341034
@staticmethod
1035-
def allocate(numel: int | torch.SymInt, origin: Origin) -> GridIndexType:
1035+
def allocate(
1036+
numel: int | torch.SymInt,
1037+
origin: Origin,
1038+
step: int | torch.SymInt = 1,
1039+
) -> GridIndexType:
10361040
from .._compiler.compile_environment import CompileEnvironment
10371041
from .host_function import HostFunction
10381042
from .host_function import SymbolOrigin
10391043

10401044
env = CompileEnvironment.current()
1041-
block_id = env.allocate_block_size(numel, source=FixedBlockSizeSource(1))
1045+
block_id = env.allocate_block_size(numel, source=FixedBlockSizeSource(step))
10421046
# assign this a new unbacked symbol since this should be treated like a scalar rather than a tile
10431047
sym = env.create_unbacked_symint()
10441048
HostFunction.current().expr_to_origin[sym._sympy_()] = SymbolOrigin(

0 commit comments

Comments
 (0)