Simplify codegen for hl.grid #210

Merged
merged 1 commit on Jun 21, 2025
17 changes: 0 additions & 17 deletions helion/_compiler/compile_environment.py
@@ -422,9 +422,6 @@ def is_flattened(self, config: Config) -> bool:
spec = CompileEnvironment.current().config_spec
return spec.flatten_loops.config_get(config.flatten_loops, self.block_id, False)

def is_grid(self) -> bool:
return self.block_size_source.is_grid()

def update_min_block(self, value: int, *, allow_flattened: bool = True) -> None:
spec = CompileEnvironment.current().config_spec
if not allow_flattened:
@@ -437,9 +434,6 @@ class BlockSizeSource:
def from_config(self, config: Config, block_id: int) -> int | torch.SymInt | None:
raise NotImplementedError

def is_grid(self) -> bool:
return False

def l2_grouping(self, config: Config) -> int:
return 1

@@ -452,17 +446,6 @@ def from_config(self, config: Config, block_id: int) -> int | torch.SymInt:
return self.value


@dataclasses.dataclass
class GridBlockSizeSource(BlockSizeSource):
"""Block size source for grid indices - always has block_size=1 but marks as grid for special indexing"""

def from_config(self, config: Config, block_id: int) -> int:
return 1

def is_grid(self) -> bool:
return True


@dataclasses.dataclass
class LoopSpecBlockSizeSource(BlockSizeSource):
def from_config(self, config: Config, block_id: int) -> int:
7 changes: 6 additions & 1 deletion helion/_compiler/device_function.py
@@ -31,11 +31,13 @@
from .host_function import NoCurrentFunction
from .output_header import reserved_names
from .variable_origin import BlockSizeOrigin
from .variable_origin import GridOrigin
from .variable_origin import Origin
from .variable_origin import TensorSizeOrigin

if TYPE_CHECKING:
from ..runtime.config import Config
from .generate_ast import GenerateAST
from .program_id import ProgramIDs
from .program_id import SharedProgramID

@@ -136,10 +138,11 @@ def __init__(self, val: int) -> None:


class DeviceFunction:
def __init__(self, name: str, config: Config) -> None:
def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
super().__init__()
self.name = name
self.config = config
self.codegen = codegen
self.arguments: list[Argument] = []
self.body: list[ast.AST] = []
self._tensor_args: dict[torch.Tensor, TensorArg] = {}
@@ -219,6 +222,8 @@ def _lift_sympy_arg(self, expr: sympy.Expr) -> str:
result = self.block_size_var(origin.origin.block_id)
assert result is not None
return result
if isinstance(origin.origin, GridOrigin):
return self.codegen.offset_var(origin.origin.block_id)
return self.expr_arg(expr, origin.origin).name

def user_sympy_expr(self, expr: sympy.Expr) -> str:
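With this wiring, a sympy expression whose origin is a GridOrigin lowers directly to the codegen's per-program offset variable instead of a dedicated index variable. A minimal sketch of the style of Triton this produces for an hl.grid dimension, assuming names like offset_0 that mirror the updated test expectations below (the kernel itself is illustrative, not literal Helion output):

import triton
import triton.language as tl

@triton.jit
def _demo_grid_kernel(x_ptr, out_ptr, x_stride_0, out_stride_0):
    # One program per grid index, as emitted for hl.grid loops.
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    # The grid index stays a scalar offset and is materialized as a
    # 1-element block only at the point of use, rather than keeping a
    # persistent indices_0 tensor alive across the kernel body.
    idx = tl.full([1], offset_0, tl.int32)
    val = tl.load(x_ptr + idx * x_stride_0, None)
    tl.store(out_ptr + idx * out_stride_0, val * 2.0, None)
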
3 changes: 3 additions & 0 deletions helion/_compiler/device_ir.py
@@ -40,6 +40,7 @@
from .ast_read_writes import ReadWrites
from .compile_environment import CompileEnvironment
from .host_function import HostFunction
from .inductor_lowering import APIFuncLowering
from .inductor_lowering import CodegenState
from .inductor_lowering import codegen_call_with_graph
from .inductor_lowering import prepare_graph_lowerings
@@ -98,6 +99,7 @@ def _get_proxy_slot(
name=origin.suggest_var_name(),
)
proxy.node.meta["val"] = obj
proxy.node.meta["lowering"] = APIFuncLowering(_tracing_ops._host_tensor)
return transform(tracker[obj])
if isinstance(obj, proxy_tensor.py_sym_types):
tracker = tracer.symnode_tracker
@@ -111,6 +113,7 @@
name=debug_name if debug_name.isidentifier() else "symnode",
)
proxy.node.meta["val"] = obj
proxy.node.meta["lowering"] = APIFuncLowering(_tracing_ops._get_symnode)
proxy.force = lambda: proxy
return transform(tracker[obj])
return get_proxy_slot(obj, tracer, default, transform)
7 changes: 5 additions & 2 deletions helion/_compiler/generate_ast.py
@@ -39,7 +39,7 @@ def __init__(self, func: HostFunction, config: Config) -> None:
self.host_statements: list[ast.AST] = []
self.statements_stack: list[list[ast.AST]] = [self.host_statements]
self.on_device = False
self.device_function = DeviceFunction(f"_{func.name}_kernel", config)
self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self)
self.active_device_loops: dict[int, list[DeviceLoopOrGridState]] = (
collections.defaultdict(list)
)
@@ -404,7 +404,7 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
precompile_def = codegen_precompile_def(
host_def, codegen.device_function.name
)
return ast.Module(
result = ast.Module(
[
*func.codegen_imports(),
kernel_def,
@@ -413,3 +413,6 @@
],
[],
)
# break circular reference for better GC
del codegen.device_function.codegen
return result
7 changes: 1 addition & 6 deletions helion/_compiler/indexing_strategy.py
@@ -218,9 +218,7 @@ def compute_shape(
if isinstance(symbol, sympy.Symbol):
origin = HostFunction.current().expr_to_origin.get(symbol)
if origin and isinstance(origin.origin, BlockSizeOrigin):
if env.block_sizes[origin.origin.block_id].is_grid():
pass
elif tensor.size(tensor.ndim - len(input_size) - 1) != 1:
if tensor.size(tensor.ndim - len(input_size) - 1) != 1:
output_size.append(k)
else:
output_size.append(1)
@@ -267,9 +265,6 @@ def create(
origin = HostFunction.current().expr_to_origin.get(symbol)
if origin and isinstance(origin.origin, BlockSizeOrigin):
index_var = state.codegen.index_var(origin.origin.block_id)
if env.block_sizes[origin.origin.block_id].is_grid():
index_values.append(index_var)
continue
expand = tile_strategy.expand_str(output_size, output_idx)
i = len(index_values)
index_values.append(f"({index_var}){expand}")
20 changes: 12 additions & 8 deletions helion/_compiler/type_propagation.py
@@ -45,6 +45,7 @@
from .variable_origin import DeviceOrigin
from .variable_origin import GetItemOrigin
from .variable_origin import GlobalOrigin
from .variable_origin import GridOrigin
from .variable_origin import Origin
from .variable_origin import SourceOrigin
from .variable_origin import TensorSizeOrigin
@@ -1023,11 +1024,8 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
class GridIndexType(SymIntType):
block_id: int

def __init__(self, origin: Origin, block_id: int) -> None:
from .._compiler.compile_environment import CompileEnvironment

env = CompileEnvironment.current()
super().__init__(origin, env.block_sizes[block_id].var)
def __init__(self, origin: Origin, sym: torch.SymInt, block_id: int) -> None:
super().__init__(origin, sym)
self.block_id = block_id

def __str__(self) -> str: # pragma: no cover – debug helper
@@ -1036,11 +1034,17 @@ def __str__(self) -> str: # pragma: no cover – debug helper
@staticmethod
def allocate(numel: int | torch.SymInt, origin: Origin) -> GridIndexType:
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.compile_environment import GridBlockSizeSource
from .host_function import HostFunction
from .host_function import SymbolOrigin

env = CompileEnvironment.current()
block_idx = env.allocate_block_size(numel, source=GridBlockSizeSource())
return GridIndexType(origin, block_idx)
block_id = env.allocate_block_size(numel, source=FixedBlockSizeSource(1))
# assign this a new unbacked symbol since this should be treated like a scalar rather than a tile
sym = env.create_unbacked_symint()
HostFunction.current().expr_to_origin[sym._sympy_()] = SymbolOrigin(
origin=GridOrigin(block_id),
)
return GridIndexType(origin, sym, block_id)

def merge(self, other: TypeInfo) -> TypeInfo: # type: ignore[override]
if isinstance(other, GridIndexType):
10 changes: 10 additions & 0 deletions helion/_compiler/variable_origin.py
@@ -247,3 +247,13 @@ class ReductionDimensionOrigin(Origin):

def host_str(self) -> str:
raise NotImplementedError


@dataclasses.dataclass
class GridOrigin(Origin):
"""Note this represents the tile_begin() of the grid, not the block size (which is always 1)"""

block_id: int

def host_str(self) -> str:
raise NotImplementedError
33 changes: 13 additions & 20 deletions helion/language/loops.py
@@ -257,22 +257,29 @@ def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool:

@_decorators.codegen(tile)
def _(state: CodegenState) -> ast.AST:
return _codegen_loop_helper(state)


def _codegen_loop_helper(
state: CodegenState,
) -> ast.AST:
"""Helper method for codegen of tile and grid decorators."""
for_loop = ExtendedAST.current()[-2]
loop_type = for_loop._loop_type
type_info = ExtendedAST.current()[-1]._type_info
assert isinstance(for_loop, ast.For)
assert isinstance(type_info, IterType)

if isinstance(type_info.inner, SequenceType):
tile_indices = type_info.inner.unpack()
indices = type_info.inner.unpack()
else:
tile_indices = [type_info.inner]
assert all(isinstance(t, TileIndexType) for t in tile_indices)
indices = [type_info.inner]
assert all(isinstance(t, (TileIndexType, GridIndexType)) for t in indices)

if loop_type == LoopType.GRID:
env = CompileEnvironment.current()
env.loop_dependency_checker.register_loop(for_loop)

block_ids = [t.block_id for t in tile_indices]
block_ids = [t.block_id for t in indices]
state.tile_strategy.codegen_grid(state, block_ids)
return expr_from_string("None")
raise AssertionError(f"Expected loop type: {loop_type}")
@@ -343,18 +350,4 @@ def _(sizes: TypeInfo, *, origin: Origin) -> TypeInfo:

@_decorators.codegen(grid)
def _(state: CodegenState) -> ast.AST:
for_loop = ExtendedAST.current()[-2]
loop_type = for_loop._loop_type
type_info = ExtendedAST.current()[-1]._type_info
assert isinstance(for_loop, ast.For)
assert isinstance(type_info, IterType)
if isinstance(type_info.inner, SequenceType):
grid_indices = type_info.inner.unpack()
else:
grid_indices = [type_info.inner]
assert all(isinstance(t, GridIndexType) for t in grid_indices)
if loop_type == LoopType.GRID:
block_ids = [t.block_id for t in grid_indices]
state.tile_strategy.codegen_grid(state, block_ids)
return expr_from_string("None")
raise AssertionError(f"Expected loop type: {loop_type}")
return _codegen_loop_helper(state)
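At the language level, hl.tile and hl.grid now share the _codegen_loop_helper path above. A minimal usage sketch, assuming the public @helion.kernel decorator and hl.grid iteration API (this example kernel is illustrative and not part of the PR):

import torch
import helion
import helion.language as hl

@helion.kernel
def double_rows(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    # hl.grid yields one scalar index per program (block size 1);
    # an hl.tile loop here would yield tiles instead -- both forms
    # are handled by the same codegen helper.
    for i in hl.grid(x.size(0)):
        out[i] = x[i] * 2.0
    return out
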
13 changes: 6 additions & 7 deletions test/test_control_flow.py
@@ -123,21 +123,20 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def _fn_kernel(x, y, output, output_stride_0, x_stride_0, y_stride_0):
pid_0 = tl.program_id(0)
offset_0 = pid_0
indices_0 = offset_0 + tl.zeros([1], tl.int32)
load = tl.load(y + indices_0 * y_stride_0, None)
load = tl.load(y + tl.full([1], offset_0, tl.int32) * y_stride_0, None)
v_0 = tl.full([], 0, tl.int32)
v_1 = load != v_0
if tl.sum(v_1):
load_1 = tl.load(x + indices_0 * x_stride_0, None)
load_1 = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
v_2 = 2.0
v_3 = load_1 * v_2
tl.store(output + indices_0 * output_stride_0, v_3, None)
load_2 = tl.load(y + indices_0 * y_stride_0, None)
tl.store(output + tl.full([1], offset_0, tl.int32) * output_stride_0, v_3, None)
load_2 = tl.load(y + tl.full([1], offset_0, tl.int32) * y_stride_0, None)
v_4 = tl.full([], 0, tl.int32)
v_5 = load_2 == v_4
if tl.sum(v_5):
load_3 = tl.load(x + indices_0 * x_stride_0, None)
tl.store(output + indices_0 * output_stride_0, load_3, None)
load_3 = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
tl.store(output + tl.full([1], offset_0, tl.int32) * output_stride_0, load_3, None)

def fn(x: torch.Tensor, y: torch.Tensor):
output = torch.zeros_like(x)
7 changes: 3 additions & 4 deletions test/test_examples.py
@@ -1687,9 +1687,8 @@ def test_moe_matmul_ogs(self):
def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_orig_token_idx, A, W, C, A_stride_0, A_stride_1, C_stride_0, C_stride_1, W_stride_0, W_stride_1, W_stride_2, expert_token_counts_stride_0, expert_token_offsets_stride_0, sorted_to_orig_token_idx_stride_0, max_T_per_expert, N, K, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0
indices_0 = offset_0 + tl.zeros([1], tl.int32)
start = tl.load(expert_token_offsets + indices_0 * expert_token_offsets_stride_0, None)
num_tokens = tl.load(expert_token_counts + indices_0 * expert_token_counts_stride_0, None)
start = tl.load(expert_token_offsets + tl.full([1], offset_0, tl.int32) * expert_token_offsets_stride_0, None)
num_tokens = tl.load(expert_token_counts + tl.full([1], offset_0, tl.int32) * expert_token_counts_stride_0, None)
v_0 = tl.full([], 0, tl.int32)
v_1 = num_tokens != v_0
if tl.sum(v_1):
@@ -1725,7 +1724,7 @@ def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_
expert_orig_token_indices_copy_0 = expert_orig_token_indices_copy
acc_copy_0 = acc_copy
A_frag = tl.load(A + (expert_orig_token_indices_copy_0[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0)
W_frag = tl.load(W + (indices_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
W_frag = tl.load(W + (tl.full([1], offset_0, tl.int32)[:, None] * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
acc = tl.dot(A_frag, W_frag, acc=acc_copy_0, input_precision='tf32')
existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
view = tl.reshape(v_3, [_BLOCK_SIZE_1, 1])