Commit ca8ee70

Simplify codegen for hl.grid
This makes codegen for grids go down the standard SymInt codegen path.
1 parent 9f44b61 commit ca8ee70

10 files changed: +56 -60 lines
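
For orientation: grids are the `hl.grid(...)` loops in Helion kernels, where each iteration maps to one Triton program and the loop variable is a scalar index rather than a tile. A minimal sketch of such a kernel (hypothetical example for context, not code from this commit):

    import torch
    import helion
    import helion.language as hl

    @helion.kernel()
    def double_rows(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        # one program per row; `i` is a scalar grid index, not a tile
        for i in hl.grid(x.size(0)):
            out[i] = x[i] * 2.0
        return out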

helion/_compiler/compile_environment.py

Lines changed: 0 additions & 17 deletions
@@ -422,9 +422,6 @@ def is_flattened(self, config: Config) -> bool:
         spec = CompileEnvironment.current().config_spec
         return spec.flatten_loops.config_get(config.flatten_loops, self.block_id, False)
 
-    def is_grid(self) -> bool:
-        return self.block_size_source.is_grid()
-
     def update_min_block(self, value: int, *, allow_flattened: bool = True) -> None:
         spec = CompileEnvironment.current().config_spec
         if not allow_flattened:
@@ -437,9 +434,6 @@ class BlockSizeSource:
     def from_config(self, config: Config, block_id: int) -> int | torch.SymInt | None:
         raise NotImplementedError
 
-    def is_grid(self) -> bool:
-        return False
-
     def l2_grouping(self, config: Config) -> int:
         return 1
 
@@ -452,17 +446,6 @@ def from_config(self, config: Config, block_id: int) -> int | torch.SymInt:
         return self.value
 
 
-@dataclasses.dataclass
-class GridBlockSizeSource(BlockSizeSource):
-    """Block size source for grid indices - always has block_size=1 but marks as grid for special indexing"""
-
-    def from_config(self, config: Config, block_id: int) -> int:
-        return 1
-
-    def is_grid(self) -> bool:
-        return True
-
-
 @dataclasses.dataclass
 class LoopSpecBlockSizeSource(BlockSizeSource):
     def from_config(self, config: Config, block_id: int) -> int:

helion/_compiler/device_function.py

Lines changed: 6 additions & 1 deletion
@@ -31,11 +31,13 @@
 from .host_function import NoCurrentFunction
 from .output_header import reserved_names
 from .variable_origin import BlockSizeOrigin
+from .variable_origin import GridOrigin
 from .variable_origin import Origin
 from .variable_origin import TensorSizeOrigin
 
 if TYPE_CHECKING:
     from ..runtime.config import Config
+    from .generate_ast import GenerateAST
     from .program_id import ProgramIDs
     from .program_id import SharedProgramID
 
@@ -136,10 +138,11 @@ def __init__(self, val: int) -> None:
 
 
 class DeviceFunction:
-    def __init__(self, name: str, config: Config) -> None:
+    def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
         super().__init__()
         self.name = name
         self.config = config
+        self.codegen = codegen
         self.arguments: list[Argument] = []
         self.body: list[ast.AST] = []
         self._tensor_args: dict[torch.Tensor, TensorArg] = {}
@@ -219,6 +222,8 @@ def _lift_sympy_arg(self, expr: sympy.Expr) -> str:
             result = self.block_size_var(origin.origin.block_id)
             assert result is not None
             return result
+        if isinstance(origin.origin, GridOrigin):
+            return self.codegen.offset_var(origin.origin.block_id)
         return self.expr_arg(expr, origin.origin).name
 
     def user_sympy_expr(self, expr: sympy.Expr) -> str:
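
The net effect of the `_lift_sympy_arg` hunk: a symbol whose origin is a `GridOrigin` no longer gets lifted into a new kernel argument; it resolves to the offset variable the AST generator already emits for that block (the `offset_0 = pid_0` scalar visible in the test expectations further down). A toy stand-in for that lookup (hypothetical names, not Helion's API):

    # Hypothetical sketch: resolving a grid symbol is a name lookup into
    # variables codegen already owns, rather than adding a new argument.
    offset_vars = {0: "offset_0", 1: "offset_1"}  # block_id -> emitted var

    def offset_var(block_id: int) -> str:
        return offset_vars[block_id]

    assert offset_var(0) == "offset_0"  # the scalar set from tl.program_id(0)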

helion/_compiler/device_ir.py

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,7 @@
 from .ast_read_writes import ReadWrites
 from .compile_environment import CompileEnvironment
 from .host_function import HostFunction
+from .inductor_lowering import APIFuncLowering
 from .inductor_lowering import CodegenState
 from .inductor_lowering import codegen_call_with_graph
 from .inductor_lowering import prepare_graph_lowerings
@@ -98,6 +99,7 @@ def _get_proxy_slot(
                 name=origin.suggest_var_name(),
             )
             proxy.node.meta["val"] = obj
+            proxy.node.meta["lowering"] = APIFuncLowering(_tracing_ops._host_tensor)
             return transform(tracker[obj])
         if isinstance(obj, proxy_tensor.py_sym_types):
             tracker = tracer.symnode_tracker
@@ -111,6 +113,7 @@ def _get_proxy_slot(
                 name=debug_name if debug_name.isidentifier() else "symnode",
             )
             proxy.node.meta["val"] = obj
+            proxy.node.meta["lowering"] = APIFuncLowering(_tracing_ops._get_symnode)
             proxy.force = lambda: proxy
             return transform(tracker[obj])
     return get_proxy_slot(obj, tracer, default, transform)
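
Both hunks stash an `APIFuncLowering` on proxies created outside the normal lowering pass, so downstream codegen finds a `lowering` entry on every node it visits. A generic `torch.fx` illustration of the underlying mechanism (not Helion code; `node.meta` is a free-form per-node dict):

    import torch.fx as fx

    def f(x):
        return x + 1

    gm = fx.symbolic_trace(f)
    for node in gm.graph.nodes:
        # passes commonly stash per-node data here for later stages
        node.meta["lowering"] = f"handled-by-{node.op}"
    print({n.name: n.meta["lowering"] for n in gm.graph.nodes})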

helion/_compiler/generate_ast.py

Lines changed: 5 additions & 2 deletions
@@ -39,7 +39,7 @@ def __init__(self, func: HostFunction, config: Config) -> None:
         self.host_statements: list[ast.AST] = []
         self.statements_stack: list[list[ast.AST]] = [self.host_statements]
         self.on_device = False
-        self.device_function = DeviceFunction(f"_{func.name}_kernel", config)
+        self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self)
         self.active_device_loops: dict[int, list[DeviceLoopOrGridState]] = (
             collections.defaultdict(list)
         )
@@ -404,7 +404,7 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
     precompile_def = codegen_precompile_def(
         host_def, codegen.device_function.name
     )
-    return ast.Module(
+    result = ast.Module(
         [
             *func.codegen_imports(),
             kernel_def,
@@ -413,3 +413,6 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
         ],
         [],
     )
+    # break circular reference for better GC
+    del codegen.device_function.codegen
+    return result
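
`GenerateAST` now holds the `DeviceFunction`, and the `DeviceFunction` holds a back-pointer to `GenerateAST`, so the final `del` snaps the cycle once codegen finishes. A self-contained illustration of why that matters (generic Python, not Helion code):

    # With a cycle, objects wait for the cyclic garbage collector; with
    # the cycle broken, plain reference counting frees them immediately.
    class Codegen:
        def __init__(self) -> None:
            self.device_function = DeviceFn(self)

    class DeviceFn:
        def __init__(self, codegen: "Codegen") -> None:
            self.codegen = codegen  # back-reference forms the cycle

    cg = Codegen()
    del cg.device_function.codegen  # break the cycle when codegen is done
    del cg                          # now both objects are reclaimed promptly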

helion/_compiler/indexing_strategy.py

Lines changed: 1 addition & 6 deletions
@@ -218,9 +218,7 @@ def compute_shape(
         if isinstance(symbol, sympy.Symbol):
             origin = HostFunction.current().expr_to_origin.get(symbol)
             if origin and isinstance(origin.origin, BlockSizeOrigin):
-                if env.block_sizes[origin.origin.block_id].is_grid():
-                    pass
-                elif tensor.size(tensor.ndim - len(input_size) - 1) != 1:
+                if tensor.size(tensor.ndim - len(input_size) - 1) != 1:
                     output_size.append(k)
                 else:
                     output_size.append(1)
@@ -267,9 +265,6 @@ def create(
             origin = HostFunction.current().expr_to_origin.get(symbol)
             if origin and isinstance(origin.origin, BlockSizeOrigin):
                 index_var = state.codegen.index_var(origin.origin.block_id)
-                if env.block_sizes[origin.origin.block_id].is_grid():
-                    index_values.append(index_var)
-                    continue
                 expand = tile_strategy.expand_str(output_size, output_idx)
                 i = len(index_values)
                 index_values.append(f"({index_var}){expand}")

helion/_compiler/type_propagation.py

Lines changed: 12 additions & 8 deletions
@@ -45,6 +45,7 @@
 from .variable_origin import DeviceOrigin
 from .variable_origin import GetItemOrigin
 from .variable_origin import GlobalOrigin
+from .variable_origin import GridOrigin
 from .variable_origin import Origin
 from .variable_origin import SourceOrigin
 from .variable_origin import TensorSizeOrigin
@@ -1023,11 +1024,8 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
 class GridIndexType(SymIntType):
     block_id: int
 
-    def __init__(self, origin: Origin, block_id: int) -> None:
-        from .._compiler.compile_environment import CompileEnvironment
-
-        env = CompileEnvironment.current()
-        super().__init__(origin, env.block_sizes[block_id].var)
+    def __init__(self, origin: Origin, sym: torch.SymInt, block_id: int) -> None:
+        super().__init__(origin, sym)
         self.block_id = block_id
 
     def __str__(self) -> str:  # pragma: no cover – debug helper
@@ -1036,11 +1034,17 @@ def __str__(self) -> str:  # pragma: no cover – debug helper
     @staticmethod
     def allocate(numel: int | torch.SymInt, origin: Origin) -> GridIndexType:
         from .._compiler.compile_environment import CompileEnvironment
-        from .._compiler.compile_environment import GridBlockSizeSource
+        from .host_function import HostFunction
+        from .host_function import SymbolOrigin
 
         env = CompileEnvironment.current()
-        block_idx = env.allocate_block_size(numel, source=GridBlockSizeSource())
-        return GridIndexType(origin, block_idx)
+        block_id = env.allocate_block_size(numel, source=FixedBlockSizeSource(1))
+        # assign this a new unbacked symbol since this should be treated like a scalar rather than a tile
+        sym = env.create_unbacked_symint()
+        HostFunction.current().expr_to_origin[sym._sympy_()] = SymbolOrigin(
+            origin=GridOrigin(block_id),
+        )
+        return GridIndexType(origin, sym, block_id)
 
     def merge(self, other: TypeInfo) -> TypeInfo:  # type: ignore[override]
         if isinstance(other, GridIndexType):
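
`allocate` still reserves a block size (fixed at 1) but now hands the type a fresh unbacked SymInt, so the grid index flows through the compiler as an opaque scalar that nothing can specialize on. For intuition, this is the underlying PyTorch mechanism (`env.create_unbacked_symint()` is presumably a thin wrapper over the ShapeEnv call):

    # PyTorch-level sketch: an unbacked SymInt carries no concrete hint,
    # so the compiler must treat its value as fully symbolic.
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    shape_env = ShapeEnv()
    sym = shape_env.create_unbacked_symint()
    print(sym)            # e.g. u0
    print(sym._sympy_())  # the sympy symbol used as the expr_to_origin key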

helion/_compiler/variable_origin.py

Lines changed: 10 additions & 0 deletions
@@ -247,3 +247,13 @@ class ReductionDimensionOrigin(Origin):
 
     def host_str(self) -> str:
         raise NotImplementedError
+
+
+@dataclasses.dataclass
+class GridOrigin(Origin):
+    """Note this represents the tile_begin() of the grid, not the block size (which is always 1)"""
+
+    block_id: int
+
+    def host_str(self) -> str:
+        raise NotImplementedError

test/test_control_flow.py

Lines changed: 6 additions & 7 deletions
@@ -123,21 +123,20 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 def _fn_kernel(x, y, output, output_stride_0, x_stride_0, y_stride_0):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0
-    indices_0 = offset_0 + tl.zeros([1], tl.int32)
-    load = tl.load(y + indices_0 * y_stride_0, None)
+    load = tl.load(y + tl.full([1], offset_0, tl.int32) * y_stride_0, None)
     v_0 = tl.full([], 0, tl.int32)
     v_1 = load != v_0
     if tl.sum(v_1):
-        load_1 = tl.load(x + indices_0 * x_stride_0, None)
+        load_1 = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
         v_2 = 2.0
         v_3 = load_1 * v_2
-        tl.store(output + indices_0 * output_stride_0, v_3, None)
-        load_2 = tl.load(y + indices_0 * y_stride_0, None)
+        tl.store(output + tl.full([1], offset_0, tl.int32) * output_stride_0, v_3, None)
+        load_2 = tl.load(y + tl.full([1], offset_0, tl.int32) * y_stride_0, None)
     v_4 = tl.full([], 0, tl.int32)
     v_5 = load_2 == v_4
     if tl.sum(v_5):
-        load_3 = tl.load(x + indices_0 * x_stride_0, None)
-        tl.store(output + indices_0 * output_stride_0, load_3, None)
+        load_3 = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
+        tl.store(output + tl.full([1], offset_0, tl.int32) * output_stride_0, load_3, None)
 
 def fn(x: torch.Tensor, y: torch.Tensor):
     output = torch.zeros_like(x)
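
The expected Triton changes because the dedicated `indices_0` temporary is gone: each use site now materializes the scalar offset inline with `tl.full`, which is equivalent to the old `offset + tl.zeros(...)` broadcast. A minimal runnable check of that equivalence (assumes a CUDA device; not part of the test suite):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _demo(out_ptr):
        offset_0 = tl.program_id(0)
        old = offset_0 + tl.zeros([1], tl.int32)  # previous codegen pattern
        new = tl.full([1], offset_0, tl.int32)    # pattern after this commit
        # store the difference; every element should be zero
        tl.store(out_ptr + offset_0 + tl.arange(0, 1), old - new)

    out = torch.empty(4, dtype=torch.int32, device="cuda")
    _demo[(4,)](out)
    assert (out == 0).all()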

test/test_examples.py

Lines changed: 3 additions & 4 deletions
@@ -1687,9 +1687,8 @@ def test_moe_matmul_ogs(self):
 def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_orig_token_idx, A, W, C, A_stride_0, A_stride_1, C_stride_0, C_stride_1, W_stride_0, W_stride_1, W_stride_2, expert_token_counts_stride_0, expert_token_offsets_stride_0, sorted_to_orig_token_idx_stride_0, max_T_per_expert, N, K, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0
-    indices_0 = offset_0 + tl.zeros([1], tl.int32)
-    start = tl.load(expert_token_offsets + indices_0 * expert_token_offsets_stride_0, None)
-    num_tokens = tl.load(expert_token_counts + indices_0 * expert_token_counts_stride_0, None)
+    start = tl.load(expert_token_offsets + tl.full([1], offset_0, tl.int32) * expert_token_offsets_stride_0, None)
+    num_tokens = tl.load(expert_token_counts + tl.full([1], offset_0, tl.int32) * expert_token_counts_stride_0, None)
     v_0 = tl.full([], 0, tl.int32)
     v_1 = num_tokens != v_0
     if tl.sum(v_1):
@@ -1725,7 +1724,7 @@ def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_
             expert_orig_token_indices_copy_0 = expert_orig_token_indices_copy
             acc_copy_0 = acc_copy
             A_frag = tl.load(A + (expert_orig_token_indices_copy_0[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0)
-            W_frag = tl.load(W + (indices_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
+            W_frag = tl.load(W + (tl.full([1], offset_0, tl.int32)[:, None] * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
             acc = tl.dot(A_frag, W_frag, acc=acc_copy_0, input_precision='tf32')
             existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
             view = tl.reshape(v_3, [_BLOCK_SIZE_1, 1])
