Skip to content

Commit 9247821

Browse files
committed
Move yz_grid disabling logic to ConfigSpec
This generates a smaller search space which is easier to autotune. stack-info: PR: #213, branch: jansel/stack/64
1 parent 3e7f664 commit 9247821

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

helion/_compiler/device_ir.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,9 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
895895
remove_unnecessary_tile_index(graph.graph)
896896
remove_unnecessary_masking(graph.graph)
897897
device_ir.build_rolled_reductions()
898+
if len(device_ir.root_ids) > 1:
899+
# yz_grid not supported with shared program IDs
900+
CompileEnvironment.current().config_spec.allow_use_yz_grid = False
898901
return device_ir
899902

900903

helion/_compiler/tile_strategy.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,9 +450,6 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
450450
env = CompileEnvironment.current()
451451
block_sizes = self.block_size
452452
assert len(block_sizes) == len(block_ids)
453-
if isinstance(state.device_function.pid, SharedProgramID):
454-
# Disable for shared pid
455-
self.fn.config.config["use_yz_grid"] = False
456453
pids = self.select_pid_strategy()
457454
if isinstance(state.device_function.pid, SharedProgramID):
458455
pids.shared_pid_var = state.device_function.pid.shared_pid_var

0 commit comments

Comments
 (0)