Add support for hl.grid(begin, end, step) #211

Merged · 1 commit · Jun 21, 2025
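
For reference, a minimal sketch of the API this PR enables. The kernel below is illustrative, not taken from this diff (the tensor names and kernel body are assumptions); what it shows is that `hl.grid`, which previously took a single iteration count, now accepts `range`-style `(begin, end, step)` arguments:

```python
import torch

import helion
import helion.language as hl


@helion.kernel
def double_every_other_row(x: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(x)
    # One program per row index 1, 3, 5, ... (range-style semantics).
    for i in hl.grid(1, x.size(0), 2):
        out[i, :] = x[i, :] * 2.0
    return out
```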
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion helion/_compiler/program_id.py
@@ -119,7 +119,7 @@ def combined_device_cdiv(self, state: CodegenState) -> str:
         return " * ".join(pid.device_cdiv(state) for pid in self.pids)

     def combined_host_cdiv(self) -> str:
-        return " * ".join(pid.host_cdiv() for pid in self.pids)
+        return " * ".join(f"({pid.host_cdiv()})" for pid in self.pids)

     def codegen(self, state: CodegenState) -> None:
         pid_var = self.shared_pid_var or "tl.program_id(0)"
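
Why the added parentheses matter: `host_cdiv()` can return a compound expression, and Python's `//` and `*` share precedence and associate left-to-right, so joining the pieces bare with `" * "` can evaluate in the wrong order. A self-contained illustration (the cdiv strings here are made up, not actual compiler output):

```python
# Hypothetical host-side cdiv expressions for two grid dimensions.
pids = ["(m + 63) // 64", "(n + 127) // 128"]

unwrapped = " * ".join(pids)                  # old: precedence leaks across terms
wrapped = " * ".join(f"({p})" for p in pids)  # new: each cdiv evaluated first

m, n = 100, 200
print(eval(unwrapped))  # 5, parsed as (((m + 63) // 64) * (n + 127)) // 128
print(eval(wrapped))    # 4, parsed as ((m + 63) // 64) * ((n + 127) // 128)
```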
132 changes: 77 additions & 55 deletions helion/_compiler/tile_strategy.py
@@ -138,6 +138,45 @@ def codegen_preamble(self, state: CodegenState) -> None:
     def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]:
         raise NotImplementedError

+    def _create_block_id_info_dict(
+        self, state: CodegenState, use_proxy_ends: bool = False
+    ) -> dict[int, LoopDimInfo]:
+        """Helper to create block_id_to_info dictionary with end bounds."""
+        env = CompileEnvironment.current()
+        block_id_to_info = {}
+
+        if use_proxy_ends:
+            _, _, proxy_ends, _ = state.proxy_args
+            assert isinstance(proxy_ends, list)
+            for block_idx, end in zip(self.block_ids, proxy_ends, strict=True):
+                if isinstance(end, (int, torch.SymInt)):
+                    end_expr = _to_sympy(end)
+                else:
+                    end_expr = None
+                block_id_to_info[block_idx] = LoopDimInfo(
+                    end_var_name=None, end_expr=end_expr
+                )
+        else:
+            for block_id in self.block_ids:
+                end_expr = env.block_sizes[block_id].numel
+                end_var_name = state.sympy_expr(end_expr)
+                block_id_to_info[block_id] = LoopDimInfo(
+                    end_var_name=end_var_name, end_expr=end_expr
+                )
+
+        return block_id_to_info
+
+    def _setup_block_size_constexpr(
+        self, state: CodegenState, block_size_var: str, block_size: SymIntLike
+    ) -> None:
+        """Helper to setup constexpr block size variable on host."""
+        if state.device_function.constexpr_arg(block_size_var):
+            state.codegen.host_statements.append(
+                statement_from_string(
+                    f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}"
+                )
+            )
+

 class BlockSizeTileStrategy(TileStrategy):
     def __init__(
@@ -265,19 +304,15 @@ def block_size_var(self, block_idx: int) -> str:
     def _codegen_common(
         self, state: CodegenState
     ) -> tuple[str, str, sympy.Expr, list[ast.AST]]:
+        offsets_var = self.new_var("offsets", dce=True)
+        block_size_var = self.block_size_var(-1)
+        self._setup_block_size_constexpr(state, block_size_var, self.block_size)
         block_ids = self.block_ids
         env = CompileEnvironment.current()
         total_numel = sympy.S.One
-        offsets_var = self.new_var("offsets", dce=True)
-        block_size_var = self.block_size_var(-1)
         statements = []
-        if state.device_function.constexpr_arg(block_size_var):
-            block_size_str = HostFunction.current().literal_expr(self.block_size)
-            state.codegen.host_statements.append(
-                statement_from_string(f"{block_size_var} = {block_size_str}")
-            )
+
         for i, block_idx in enumerate(self._reorder(block_ids)):
-            # need to get the block size
             numel = env.block_sizes[block_idx].numel
             block_index_var = self.index_var(block_idx)
             expr = offsets_var
Expand Down Expand Up @@ -316,13 +351,7 @@ def codegen_grid(self) -> ast.AST:

state.device_function.set_pid(TmpPid())

block_id_to_info = {}
for block_id in self.block_ids:
end_expr = env.block_sizes[block_id].numel
end_var_name = state.sympy_expr(end_expr)
block_id_to_info[block_id] = LoopDimInfo(
end_var_name=end_var_name, end_expr=end_expr
)
block_id_to_info = self._create_block_id_info_dict(state)
return DeviceGridState(self, block_id_to_info=block_id_to_info)

def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
@@ -348,18 +377,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
             orelse=[],
             type_comment=None,
         )
-        # Create block_id_to_info with end bounds
-        block_id_to_info = {}
-        _, _, ends, _ = state.proxy_args
-        assert isinstance(ends, list)
-        for block_idx, end in zip(self.block_ids, ends, strict=True):
-            if isinstance(end, (int, torch.SymInt)):
-                end_expr = _to_sympy(end)
-            else:
-                end_expr = None
-            block_id_to_info[block_idx] = LoopDimInfo(
-                end_var_name=None, end_expr=end_expr
-            )
+        block_id_to_info = self._create_block_id_info_dict(state, use_proxy_ends=True)

         return DeviceLoopState(
             self,
@@ -430,8 +448,6 @@ def __init__(
     def codegen_grid(self, state: CodegenState) -> DeviceGridState:
         block_ids = self.block_ids
         env = CompileEnvironment.current()
-        device_function = state.device_function
-        dtype = env.triton_index_type()
         block_sizes = self.block_size
         assert len(block_sizes) == len(block_ids)
         if isinstance(state.device_function.pid, SharedProgramID):
@@ -440,31 +456,47 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
         pids = self.select_pid_strategy()
         if isinstance(state.device_function.pid, SharedProgramID):
             pids.shared_pid_var = state.device_function.pid.shared_pid_var
-        for i, (block_idx, block_size) in enumerate(
-            reversed(self._reorder([*zip(block_ids, block_sizes, strict=True)]))
+
+        assert state.ast_args is None
+        assert len(state.proxy_args) == 3
+        if state.proxy_args[1] is None:
+            begins = [0] * len(block_ids)
+        else:
+            begins = state.proxy_args[0]
+            if not isinstance(begins, (list, tuple)):
+                begins = [begins]
+            assert len(begins) == len(block_ids)
+
+        for i, (block_idx, block_size, begin) in enumerate(
+            reversed(self._reorder([*zip(block_ids, block_sizes, begins, strict=True)]))
         ):
             numel = env.block_sizes[block_idx].numel
+            device_function = state.device_function
+            dtype = env.triton_index_type()
             offset_var = self.offset_var(block_idx)
             index_var = self.index_var(block_idx)
             pid_var = device_function.new_var(f"pid_{i}", dce=True)
+
+            begin_offset_expr = ""
+            if begin != 0:
+                begin_ast = self._to_ast(begin, to_dtype=dtype)
+                begin_offset_expr = (
+                    f"{state.codegen.lift(begin_ast, dce=True, prefix='begin').id} + "
+                )
+
             if block_size != 1:
                 block_size_var = self.block_size_var(block_idx)
                 assert block_size_var is not None
-                # TODO(jansel): need to check for conflict with user variable names since block_size_var is on host
-                if state.device_function.constexpr_arg(block_size_var):
-                    state.codegen.host_statements.append(
-                        statement_from_string(
-                            f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}"
-                        )
-                    )
-                state.add_statement(f"{offset_var} = {pid_var} * {block_size_var}")
+                self._setup_block_size_constexpr(state, block_size_var, block_size)
+                state.add_statement(
+                    f"{offset_var} = {begin_offset_expr}{pid_var} * {block_size_var}"
+                )
                 state.add_statement(
                     f"{index_var} = ({offset_var} + tl.arange(0, ({block_size_var}))).to({dtype})"
                 )
             else:
                 block_size_var = "1"
-                dtype = env.triton_index_type()
-                state.add_statement(f"{offset_var} = {pid_var}")
+                state.add_statement(f"{offset_var} = {begin_offset_expr}{pid_var}")
                 state.add_statement(
                     f"{index_var} = {offset_var} + tl.zeros([1], {dtype})"
                 )
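
The new `begin` handling above is plain string splicing: a nonzero `begin` is lifted into a device-visible variable whose name is prefixed onto the offset expression. A runnable sketch of the resulting statement (the names `begin_0`, `pid_0`, and `_BLOCK_SIZE_0` mimic the `prefix='begin'` lift and the patterns above, but are illustrative rather than verbatim compiler output):

```python
# Mirror of the f-string logic above, outside the compiler.
begin, pid_var, block_size_var = 2, "pid_0", "_BLOCK_SIZE_0"

begin_offset_expr = ""
if begin != 0:
    # state.codegen.lift(...) would emit `begin_0 = 2` and hand back the name.
    begin_offset_expr = "begin_0 + "

offset_stmt = f"offset_0 = {begin_offset_expr}{pid_var} * {block_size_var}"
print(offset_stmt)  # offset_0 = begin_0 + pid_0 * _BLOCK_SIZE_0
```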
@@ -483,14 +515,7 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
         else:
             state.device_function.set_pid(pids)

-        # Extract end_var_name from end bound expressions
-        block_id_to_info = {}
-        for block_id in self.block_ids:
-            end_expr = env.block_sizes[block_id].numel
-            end_var_name = state.sympy_expr(end_expr)
-            block_id_to_info[block_id] = LoopDimInfo(
-                end_var_name=end_var_name, end_expr=end_expr
-            )
+        block_id_to_info = self._create_block_id_info_dict(state)
         return DeviceGridState(self, block_id_to_info=block_id_to_info)

     def select_pid_strategy(self) -> ProgramIDs:
@@ -509,6 +534,8 @@ def _to_ast(self, x: object, to_dtype: str | None = None) -> ast.AST:
             from .device_function import DeviceFunction

             return expr_from_string(DeviceFunction.current().sympy_expr(x))
+        if isinstance(x, torch.SymInt):
+            return self._to_ast(x._sympy_())
         raise NotImplementedError(f"{type(x)} is not implemented.")

     def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
@@ -534,12 +561,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
             if block_size != 1:
                 block_size_var = self.block_size_var(block_idx)
                 assert block_size_var is not None
-                if state.device_function.constexpr_arg(block_size_var):
-                    state.codegen.host_statements.append(
-                        statement_from_string(
-                            f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}"
-                        )
-                    )
+                self._setup_block_size_constexpr(state, block_size_var, block_size)
             else:
                 block_size_var = "1"
             end_var_name = state.codegen.lift(
48 changes: 26 additions & 22 deletions helion/_compiler/type_propagation.py
@@ -982,29 +982,24 @@ def proxy(self) -> object:

     @staticmethod
     def allocate(
-        numel: int | torch.SymInt | AutoSize | None, origin: Origin
-    ) -> TileIndexType:
-        env = CompileEnvironment.current()
-        block_id = env.allocate_block_size(numel, source=LoopSpecBlockSizeSource())
-        env.config_spec.block_sizes.append(
-            BlockSizeSpec(
-                block_id=block_id,
-                size_hint=_get_hint(numel),
-            )
-        )
-        return TileIndexType(origin, block_id)
-
-    @staticmethod
-    def allocate_fixed(
         numel: int | torch.SymInt | AutoSize | None,
-        block_size: int | torch.SymInt,
         origin: Origin,
+        block_size: int | torch.SymInt | None = None,
     ) -> TileIndexType:
         env = CompileEnvironment.current()
-        return TileIndexType(
-            origin,
-            env.allocate_block_size(numel, source=FixedBlockSizeSource(block_size)),
-        )
+        if block_size is None:
+            block_id = env.allocate_block_size(numel, source=LoopSpecBlockSizeSource())
+            env.config_spec.block_sizes.append(
+                BlockSizeSpec(
+                    block_id=block_id,
+                    size_hint=_get_hint(numel),
+                )
+            )
+        else:
+            block_id = env.allocate_block_size(
+                numel, source=FixedBlockSizeSource(block_size)
+            )
+        return TileIndexType(origin, block_id)

     def merge(self, other: TypeInfo) -> TypeInfo:
         if isinstance(other, TileIndexType):
@@ -1024,21 +1019,30 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
 class GridIndexType(SymIntType):
     block_id: int

-    def __init__(self, origin: Origin, sym: torch.SymInt, block_id: int) -> None:
+    def __init__(
+        self,
+        origin: Origin,
+        sym: torch.SymInt,
+        block_id: int,
+    ) -> None:
         super().__init__(origin, sym)
         self.block_id = block_id

     def __str__(self) -> str:  # pragma: no cover – debug helper
         return f"{type(self).__name__}({self.block_id})"

     @staticmethod
-    def allocate(numel: int | torch.SymInt, origin: Origin) -> GridIndexType:
+    def allocate(
+        numel: int | torch.SymInt,
+        origin: Origin,
+        step: int | torch.SymInt = 1,
+    ) -> GridIndexType:
         from .._compiler.compile_environment import CompileEnvironment
         from .host_function import HostFunction
         from .host_function import SymbolOrigin

         env = CompileEnvironment.current()
-        block_id = env.allocate_block_size(numel, source=FixedBlockSizeSource(1))
+        block_id = env.allocate_block_size(numel, source=FixedBlockSizeSource(step))
         # assign this a new unbacked symbol since this should be treated like a scalar rather than a tile
         sym = env.create_unbacked_symint()
         HostFunction.current().expr_to_origin[sym._sympy_()] = SymbolOrigin(
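
Semantically, allocating the grid's block size from `FixedBlockSizeSource(step)` makes program `pid` responsible for index `begin + pid * step`, so `hl.grid(begin, end, step)` visits the same indices as Python's `range`. A small model of that contract (plain Python under stated assumptions; the compiler expresses this symbolically rather than with concrete ints):

```python
def grid_indices(begin: int, end: int, step: int = 1) -> list[int]:
    """Model of hl.grid(begin, end, step): one program per index."""
    num_programs = -(-(end - begin) // step)  # ceil-division, like cdiv
    return [begin + pid * step for pid in range(num_programs)]


assert grid_indices(2, 10, 3) == list(range(2, 10, 3)) == [2, 5, 8]
assert grid_indices(0, 8) == list(range(8))
```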