Skip to content

Commit 09e550f

Browse files
committed
Remove NDGridTileStrategy
This can use the standard `hl.tile()` codegen. stack-info: PR: #209, branch: jansel/stack/60
1 parent 3275bcf commit 09e550f

File tree

3 files changed

+4
-34
lines changed

3 files changed

+4
-34
lines changed

helion/_compiler/compile_environment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,10 @@ def from_config(self, config: Config, block_id: int) -> int | torch.SymInt:
454454

455455
@dataclasses.dataclass
456456
class GridBlockSizeSource(BlockSizeSource):
457+
"""Block size source for grid indices - always has block_size=1 but marks as grid for special indexing"""
458+
457459
def from_config(self, config: Config, block_id: int) -> int:
458-
raise NotImplementedError
460+
return 1
459461

460462
def is_grid(self) -> bool:
461463
return True

helion/_compiler/tile_dispatch.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from helion._compiler.tile_strategy import CompactedShape
1616
from helion._compiler.tile_strategy import DeviceLoopState
1717
from helion._compiler.tile_strategy import FlattenedTileStrategy
18-
from helion._compiler.tile_strategy import NDGridTileStrategy
1918
from helion._compiler.tile_strategy import NDTileStrategy
2019
from helion._compiler.tile_strategy import TileStrategy
2120

@@ -67,13 +66,7 @@ def _add_loop_strategy(
6766
config.l2_groupings, block_ids[0], 1
6867
)
6968

70-
if block_size_infos[0].is_grid():
71-
strategy: TileStrategy = NDGridTileStrategy(
72-
fn,
73-
block_ids,
74-
loop_order=loop_order,
75-
)
76-
elif block_size_infos[0].is_flattened(config):
69+
if block_size_infos[0].is_flattened(config):
7770
block_size = functools.reduce(
7871
operator.mul, [bs.from_config_assert(config) for bs in block_size_infos]
7972
)

helion/_compiler/tile_strategy.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -633,31 +633,6 @@ def select_pid_strategy(self) -> ProgramIDs:
633633
return super().select_pid_strategy()
634634

635635

636-
class NDGridTileStrategy(_BaseNDTileStrategy):
637-
def __init__(
638-
self,
639-
fn: DeviceFunction,
640-
block_ids: list[int],
641-
loop_order: list[int],
642-
) -> None:
643-
super().__init__(
644-
fn=fn,
645-
block_ids=block_ids,
646-
block_size=[1] * len(block_ids), # pyre-ignore[6]
647-
loop_order=loop_order,
648-
)
649-
650-
def mask_var(self, block_idx: int) -> str | None:
651-
return None
652-
653-
def _setup_mask(
654-
self,
655-
*args: object,
656-
**kwargs: object,
657-
) -> None:
658-
return None
659-
660-
661636
class CompactedShape(NamedTuple):
662637
size_str: str
663638
user_indices: list[int]

0 commit comments

Comments
 (0)