diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index eb59971f..b59eec64 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -895,6 +895,9 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR: remove_unnecessary_tile_index(graph.graph) remove_unnecessary_masking(graph.graph) device_ir.build_rolled_reductions() + if len(device_ir.root_ids) > 1: + # yz_grid not supported with shared program IDs + CompileEnvironment.current().config_spec.allow_use_yz_grid = False return device_ir diff --git a/helion/_compiler/tile_strategy.py b/helion/_compiler/tile_strategy.py index c41db5d7..2709ba4f 100644 --- a/helion/_compiler/tile_strategy.py +++ b/helion/_compiler/tile_strategy.py @@ -450,9 +450,6 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState: env = CompileEnvironment.current() block_sizes = self.block_size assert len(block_sizes) == len(block_ids) - if isinstance(state.device_function.pid, SharedProgramID): - # Disable for shared pid - self.fn.config.config["use_yz_grid"] = False pids = self.select_pid_strategy() if isinstance(state.device_function.pid, SharedProgramID): pids.shared_pid_var = state.device_function.pid.shared_pid_var