diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py
index dfe9c600..723f4476 100644
--- a/helion/_compiler/compile_environment.py
+++ b/helion/_compiler/compile_environment.py
@@ -422,9 +422,6 @@ def is_flattened(self, config: Config) -> bool:
         spec = CompileEnvironment.current().config_spec
         return spec.flatten_loops.config_get(config.flatten_loops, self.block_id, False)
 
-    def is_grid(self) -> bool:
-        return self.block_size_source.is_grid()
-
     def update_min_block(self, value: int, *, allow_flattened: bool = True) -> None:
         spec = CompileEnvironment.current().config_spec
         if not allow_flattened:
@@ -437,9 +434,6 @@ class BlockSizeSource:
     def from_config(self, config: Config, block_id: int) -> int | torch.SymInt | None:
         raise NotImplementedError
 
-    def is_grid(self) -> bool:
-        return False
-
     def l2_grouping(self, config: Config) -> int:
         return 1
 
@@ -452,17 +446,6 @@ def from_config(self, config: Config, block_id: int) -> int | torch.SymInt:
         return self.value
 
 
-@dataclasses.dataclass
-class GridBlockSizeSource(BlockSizeSource):
-    """Block size source for grid indices - always has block_size=1 but marks as grid for special indexing"""
-
-    def from_config(self, config: Config, block_id: int) -> int:
-        return 1
-
-    def is_grid(self) -> bool:
-        return True
-
-
 @dataclasses.dataclass
 class LoopSpecBlockSizeSource(BlockSizeSource):
     def from_config(self, config: Config, block_id: int) -> int:
diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py
index ba5aa914..6d669d2c 100644
--- a/helion/_compiler/device_function.py
+++ b/helion/_compiler/device_function.py
@@ -31,11 +31,13 @@
 from .host_function import NoCurrentFunction
 from .output_header import reserved_names
 from .variable_origin import BlockSizeOrigin
+from .variable_origin import GridOrigin
 from .variable_origin import Origin
 from .variable_origin import TensorSizeOrigin
 
 if TYPE_CHECKING:
     from ..runtime.config import Config
+    from .generate_ast import GenerateAST
     from .program_id import ProgramIDs
     from .program_id import SharedProgramID
 
@@ -136,10 +138,11 @@ def __init__(self, val: int) -> None:
 
 
 class DeviceFunction:
-    def __init__(self, name: str, config: Config) -> None:
+    def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
         super().__init__()
         self.name = name
         self.config = config
+        self.codegen = codegen
         self.arguments: list[Argument] = []
         self.body: list[ast.AST] = []
         self._tensor_args: dict[torch.Tensor, TensorArg] = {}
@@ -219,6 +222,8 @@ def _lift_sympy_arg(self, expr: sympy.Expr) -> str:
             result = self.block_size_var(origin.origin.block_id)
             assert result is not None
             return result
+        if isinstance(origin.origin, GridOrigin):
+            return self.codegen.offset_var(origin.origin.block_id)
         return self.expr_arg(expr, origin.origin).name
 
     def user_sympy_expr(self, expr: sympy.Expr) -> str:
diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py
index 41a78517..24c7b4c8 100644
--- a/helion/_compiler/device_ir.py
+++ b/helion/_compiler/device_ir.py
@@ -40,6 +40,7 @@
 from .ast_read_writes import ReadWrites
 from .compile_environment import CompileEnvironment
 from .host_function import HostFunction
+from .inductor_lowering import APIFuncLowering
 from .inductor_lowering import CodegenState
 from .inductor_lowering import codegen_call_with_graph
 from .inductor_lowering import prepare_graph_lowerings
@@ -98,6 +99,7 @@ def _get_proxy_slot(
                 name=origin.suggest_var_name(),
             )
             proxy.node.meta["val"] = obj
+            proxy.node.meta["lowering"] = APIFuncLowering(_tracing_ops._host_tensor)
         return transform(tracker[obj])
     if isinstance(obj, proxy_tensor.py_sym_types):
         tracker = tracer.symnode_tracker
@@ -111,6 +113,7 @@ def _get_proxy_slot(
                 name=debug_name if debug_name.isidentifier() else "symnode",
            )
             proxy.node.meta["val"] = obj
+            proxy.node.meta["lowering"] = APIFuncLowering(_tracing_ops._get_symnode)
             proxy.force = lambda: proxy
         return transform(tracker[obj])
     return get_proxy_slot(obj, tracer, default, transform)
diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py
index e50e2f2c..7fd45f82 100644
--- a/helion/_compiler/generate_ast.py
+++ b/helion/_compiler/generate_ast.py
@@ -39,7 +39,7 @@ def __init__(self, func: HostFunction, config: Config) -> None:
         self.host_statements: list[ast.AST] = []
         self.statements_stack: list[list[ast.AST]] = [self.host_statements]
         self.on_device = False
-        self.device_function = DeviceFunction(f"_{func.name}_kernel", config)
+        self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self)
         self.active_device_loops: dict[int, list[DeviceLoopOrGridState]] = (
             collections.defaultdict(list)
         )
@@ -404,7 +404,7 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
     precompile_def = codegen_precompile_def(
         host_def, codegen.device_function.name
     )
-    return ast.Module(
+    result = ast.Module(
         [
             *func.codegen_imports(),
             kernel_def,
@@ -413,3 +413,6 @@
         ],
         [],
     )
+    # break circular reference for better GC
+    del codegen.device_function.codegen
+    return result
diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py
index a40fafc0..82ff66ba 100644
--- a/helion/_compiler/indexing_strategy.py
+++ b/helion/_compiler/indexing_strategy.py
@@ -218,9 +218,7 @@ def compute_shape(
                 if isinstance(symbol, sympy.Symbol):
                     origin = HostFunction.current().expr_to_origin.get(symbol)
                 if origin and isinstance(origin.origin, BlockSizeOrigin):
-                    if env.block_sizes[origin.origin.block_id].is_grid():
-                        pass
-                    elif tensor.size(tensor.ndim - len(input_size) - 1) != 1:
+                    if tensor.size(tensor.ndim - len(input_size) - 1) != 1:
                         output_size.append(k)
                     else:
                         output_size.append(1)
@@ -267,9 +265,6 @@ def create(
                 origin = HostFunction.current().expr_to_origin.get(symbol)
                 if origin and isinstance(origin.origin, BlockSizeOrigin):
                     index_var = state.codegen.index_var(origin.origin.block_id)
-                    if env.block_sizes[origin.origin.block_id].is_grid():
-                        index_values.append(index_var)
-                        continue
                     expand = tile_strategy.expand_str(output_size, output_idx)
                     i = len(index_values)
                     index_values.append(f"({index_var}){expand}")
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index abef7d77..476e7bc5 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -45,6 +45,7 @@
 from .variable_origin import DeviceOrigin
 from .variable_origin import GetItemOrigin
 from .variable_origin import GlobalOrigin
+from .variable_origin import GridOrigin
 from .variable_origin import Origin
 from .variable_origin import SourceOrigin
 from .variable_origin import TensorSizeOrigin
@@ -1023,11 +1024,8 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
 class GridIndexType(SymIntType):
     block_id: int
 
-    def __init__(self, origin: Origin, block_id: int) -> None:
-        from .._compiler.compile_environment import CompileEnvironment
-
-        env = CompileEnvironment.current()
-        super().__init__(origin, env.block_sizes[block_id].var)
+    def __init__(self, origin: Origin, sym: torch.SymInt, block_id: int) -> None:
+        super().__init__(origin, sym)
         self.block_id = block_id
 
     def __str__(self) -> str:  # pragma: no cover – debug helper
@@ -1036,11 +1034,17 @@ def __str__(self) -> str:  # pragma: no cover – debug helper
     @staticmethod
     def allocate(numel: int | torch.SymInt, origin: Origin) -> GridIndexType:
         from .._compiler.compile_environment import CompileEnvironment
-        from .._compiler.compile_environment import GridBlockSizeSource
+        from .host_function import HostFunction
+        from .host_function import SymbolOrigin
 
         env = CompileEnvironment.current()
-        block_idx = env.allocate_block_size(numel, source=GridBlockSizeSource())
-        return GridIndexType(origin, block_idx)
+        block_id = env.allocate_block_size(numel, source=FixedBlockSizeSource(1))
+        # assign this a new unbacked symbol since this should be treated like a scalar rather than a tile
+        sym = env.create_unbacked_symint()
+        HostFunction.current().expr_to_origin[sym._sympy_()] = SymbolOrigin(
+            origin=GridOrigin(block_id),
+        )
+        return GridIndexType(origin, sym, block_id)
 
     def merge(self, other: TypeInfo) -> TypeInfo:  # type: ignore[override]
         if isinstance(other, GridIndexType):
diff --git a/helion/_compiler/variable_origin.py b/helion/_compiler/variable_origin.py
index 4c8cea19..014e2683 100644
--- a/helion/_compiler/variable_origin.py
+++ b/helion/_compiler/variable_origin.py
@@ -247,3 +247,13 @@ class ReductionDimensionOrigin(Origin):
 
     def host_str(self) -> str:
         raise NotImplementedError
+
+
+@dataclasses.dataclass
+class GridOrigin(Origin):
+    """Note this represents the tile_begin() of the grid, not the block size (which is always 1)"""
+
+    block_id: int
+
+    def host_str(self) -> str:
+        raise NotImplementedError
diff --git a/helion/language/loops.py b/helion/language/loops.py
index 897a36b7..821a524c 100644
--- a/helion/language/loops.py
+++ b/helion/language/loops.py
@@ -257,22 +257,29 @@ def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool:
 
 
 @_decorators.codegen(tile)
 def _(state: CodegenState) -> ast.AST:
+    return _codegen_loop_helper(state)
+
+
+def _codegen_loop_helper(
+    state: CodegenState,
+) -> ast.AST:
+    """Helper method for codegen of tile and grid decorators."""
     for_loop = ExtendedAST.current()[-2]
     loop_type = for_loop._loop_type
     type_info = ExtendedAST.current()[-1]._type_info
     assert isinstance(for_loop, ast.For)
     assert isinstance(type_info, IterType)
+
     if isinstance(type_info.inner, SequenceType):
-        tile_indices = type_info.inner.unpack()
+        indices = type_info.inner.unpack()
     else:
-        tile_indices = [type_info.inner]
-    assert all(isinstance(t, TileIndexType) for t in tile_indices)
+        indices = [type_info.inner]
+    assert all(isinstance(t, (TileIndexType, GridIndexType)) for t in indices)
     if loop_type == LoopType.GRID:
         env = CompileEnvironment.current()
         env.loop_dependency_checker.register_loop(for_loop)
-
-        block_ids = [t.block_id for t in tile_indices]
+        block_ids = [t.block_id for t in indices]
         state.tile_strategy.codegen_grid(state, block_ids)
         return expr_from_string("None")
     raise AssertionError(f"Expected loop type: {loop_type}")
@@ -343,18 +350,4 @@ def _(sizes: TypeInfo, *, origin: Origin) -> TypeInfo:
 
 @_decorators.codegen(grid)
 def _(state: CodegenState) -> ast.AST:
-    for_loop = ExtendedAST.current()[-2]
-    loop_type = for_loop._loop_type
-    type_info = ExtendedAST.current()[-1]._type_info
-    assert isinstance(for_loop, ast.For)
-    assert isinstance(type_info, IterType)
-    if isinstance(type_info.inner, SequenceType):
-        grid_indices = type_info.inner.unpack()
-    else:
-        grid_indices = [type_info.inner]
-    assert all(isinstance(t, GridIndexType) for t in grid_indices)
-    if loop_type == LoopType.GRID:
-        block_ids = [t.block_id for t in grid_indices]
-        state.tile_strategy.codegen_grid(state, block_ids)
-        return expr_from_string("None")
-    raise AssertionError(f"Expected loop type: {loop_type}")
+    return _codegen_loop_helper(state)
diff --git a/test/test_control_flow.py b/test/test_control_flow.py
index 09b9510a..a19b9c64 100644
--- a/test/test_control_flow.py
+++ b/test/test_control_flow.py
@@ -123,21 +123,20 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 def _fn_kernel(x, y, output, output_stride_0, x_stride_0, y_stride_0):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0
-    indices_0 = offset_0 + tl.zeros([1], tl.int32)
-    load = tl.load(y + indices_0 * y_stride_0, None)
+    load = tl.load(y + tl.full([1], offset_0, tl.int32) * y_stride_0, None)
     v_0 = tl.full([], 0, tl.int32)
     v_1 = load != v_0
     if tl.sum(v_1):
-        load_1 = tl.load(x + indices_0 * x_stride_0, None)
+        load_1 = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
         v_2 = 2.0
         v_3 = load_1 * v_2
-        tl.store(output + indices_0 * output_stride_0, v_3, None)
-    load_2 = tl.load(y + indices_0 * y_stride_0, None)
+        tl.store(output + tl.full([1], offset_0, tl.int32) * output_stride_0, v_3, None)
+    load_2 = tl.load(y + tl.full([1], offset_0, tl.int32) * y_stride_0, None)
     v_4 = tl.full([], 0, tl.int32)
     v_5 = load_2 == v_4
     if tl.sum(v_5):
-        load_3 = tl.load(x + indices_0 * x_stride_0, None)
-        tl.store(output + indices_0 * output_stride_0, load_3, None)
+        load_3 = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
+        tl.store(output + tl.full([1], offset_0, tl.int32) * output_stride_0, load_3, None)
 
 def fn(x: torch.Tensor, y: torch.Tensor):
     output = torch.zeros_like(x)
diff --git a/test/test_examples.py b/test/test_examples.py
index 074e93be..0bd6578c 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -1687,9 +1687,8 @@ def test_moe_matmul_ogs(self):
 def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_orig_token_idx, A, W, C, A_stride_0, A_stride_1, C_stride_0, C_stride_1, W_stride_0, W_stride_1, W_stride_2, expert_token_counts_stride_0, expert_token_offsets_stride_0, sorted_to_orig_token_idx_stride_0, max_T_per_expert, N, K, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0
-    indices_0 = offset_0 + tl.zeros([1], tl.int32)
-    start = tl.load(expert_token_offsets + indices_0 * expert_token_offsets_stride_0, None)
-    num_tokens = tl.load(expert_token_counts + indices_0 * expert_token_counts_stride_0, None)
+    start = tl.load(expert_token_offsets + tl.full([1], offset_0, tl.int32) * expert_token_offsets_stride_0, None)
+    num_tokens = tl.load(expert_token_counts + tl.full([1], offset_0, tl.int32) * expert_token_counts_stride_0, None)
     v_0 = tl.full([], 0, tl.int32)
     v_1 = num_tokens != v_0
     if tl.sum(v_1):
@@ -1725,7 +1724,7 @@ def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_
                 expert_orig_token_indices_copy_0 = expert_orig_token_indices_copy
                 acc_copy_0 = acc_copy
                 A_frag = tl.load(A + (expert_orig_token_indices_copy_0[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0)
-                W_frag = tl.load(W + (indices_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
+                W_frag = tl.load(W + (tl.full([1], offset_0, tl.int32)[:, None] * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
                 acc = tl.dot(A_frag, W_frag, acc=acc_copy_0, input_precision='tf32')
             existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
             view = tl.reshape(v_3, [_BLOCK_SIZE_1, 1])
diff --git a/test/test_grid.py b/test/test_grid.py
index af5d239c..0ade4047 100644
--- a/test/test_grid.py
+++ b/test/test_grid.py
@@ -77,7 +77,6 @@ def grid_1d_pytorch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0
-    indices_0 = offset_0 + tl.zeros([1], tl.int32)
     for offset_1 in range(0, 16, _BLOCK_SIZE_1):
         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         for offset_2 in range(0, 4, _BLOCK_SIZE_2):
@@ -88,11 +87,11 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co
                 indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
                 acc_copy = acc
                 acc_copy_0 = acc_copy
-                load = tl.load(x + (indices_0 * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
+                load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
                 load_1 = tl.load(y + (indices_3[:, None] * 4 + indices_2[None, :] * 1), mask_2[None, :], other=0)
                 acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
             v_0 = acc.to(tl.float16)
-            tl.store(out + (indices_0 * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
+            tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
 
 def grid_1d(x: torch.Tensor, y: torch.Tensor):
     b, m, k = x.size()
@@ -141,11 +140,11 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co
             for offset_3 in range(0, 32, _BLOCK_SIZE_3):
                 acc_copy = acc
                 acc_copy_0 = acc_copy
-                load = tl.reshape(tl.load(tl.make_block_ptr(x, [8, 16, 32], [512, 32, 1], [offset_0, offset_1, offset_3], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero'), [_BLOCK_SIZE_1, _BLOCK_SIZE_3])
+                load = tl.reshape(tl.load(tl.make_block_ptr(x, [8, 16, 32], [512, 32, 1], [offset_0, offset_1, offset_3], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3], [2, 1, 0]), boundary_check=[1, 2], padding_option='zero'), [_BLOCK_SIZE_1, _BLOCK_SIZE_3])
                 load_1 = tl.load(tl.make_block_ptr(y, [32, 4], [4, 1], [offset_3, offset_2], [_BLOCK_SIZE_3, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
                 acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
             v_0 = acc.to(tl.float16)
-            tl.store(tl.make_block_ptr(out, [8, 16, 4], [64, 4, 1], [offset_0, offset_1, offset_2], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2], [2, 1, 0]), tl.reshape(v_0, [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2]), boundary_check=[0, 1, 2])
+            tl.store(tl.make_block_ptr(out, [8, 16, 4], [64, 4, 1], [offset_0, offset_1, offset_2], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2], [2, 1, 0]), tl.reshape(v_0, [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2]), boundary_check=[1, 2])
 
 def grid_1d(x: torch.Tensor, y: torch.Tensor):
     b, m, k = x.size()
@@ -216,9 +215,7 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0
     offset_0 = pid_0
-    indices_0 = offset_0 + tl.zeros([1], tl.int32)
     offset_1 = pid_1
-    indices_1 = offset_1 + tl.zeros([1], tl.int32)
     for offset_2 in range(0, 64, _BLOCK_SIZE_2):
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
         for offset_3 in range(0, 16, _BLOCK_SIZE_3):
@@ -228,11 +225,11 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
                 indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
                 acc_copy = acc
                 acc_copy_0 = acc_copy
-                load = tl.load(x + (indices_0 * 8192 + indices_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
+                load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
                 load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
                 acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
            v_0 = acc.to(tl.float16)
-            tl.store(out + (indices_0 * 4096 + indices_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
+            tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
 
 def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor):
     bi, bj, m, k = x.size()
@@ -283,11 +280,11 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
             for offset_4 in range(0, 32, _BLOCK_SIZE_4):
                 acc_copy = acc
                 acc_copy_0 = acc_copy
-                load = tl.reshape(tl.load(tl.make_block_ptr(x, [3, 4, 64, 32], [8192, 2048, 32, 1], [offset_0, offset_1, offset_2, offset_4], [1, 1, _BLOCK_SIZE_2, _BLOCK_SIZE_4], [3, 2, 1, 0]), boundary_check=[0, 1, 2, 3], padding_option='zero'), [_BLOCK_SIZE_2, _BLOCK_SIZE_4])
+                load = tl.reshape(tl.load(tl.make_block_ptr(x, [3, 4, 64, 32], [8192, 2048, 32, 1], [offset_0, offset_1, offset_2, offset_4], [1, 1, _BLOCK_SIZE_2, _BLOCK_SIZE_4], [3, 2, 1, 0]), boundary_check=[2, 3], padding_option='zero'), [_BLOCK_SIZE_2, _BLOCK_SIZE_4])
                 load_1 = tl.load(tl.make_block_ptr(y, [32, 16], [16, 1], [offset_4, offset_3], [_BLOCK_SIZE_4, _BLOCK_SIZE_3], [1, 0]), boundary_check=[0, 1], padding_option='zero')
                 acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
            v_0 = acc.to(tl.float16)
-            tl.store(tl.make_block_ptr(out, [3, 4, 64, 16], [4096, 1024, 16, 1], [offset_0, offset_1, offset_2, offset_3], [1, 1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], [3, 2, 1, 0]), tl.reshape(v_0, [1, 1, _BLOCK_SIZE_2, _BLOCK_SIZE_3]), boundary_check=[0, 1, 2, 3])
+            tl.store(tl.make_block_ptr(out, [3, 4, 64, 16], [4096, 1024, 16, 1], [offset_0, offset_1, offset_2, offset_3], [1, 1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], [3, 2, 1, 0]), tl.reshape(v_0, [1, 1, _BLOCK_SIZE_2, _BLOCK_SIZE_3]), boundary_check=[2, 3])
 
 def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor):
     bi, bj, m, k = x.size()
@@ -356,9 +353,7 @@ def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 def _grid_2d_idx_nested_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0
-    indices_0 = offset_0 + tl.zeros([1], tl.int32)
     for offset_1 in range(0, 4, 1):
-        indices_1 = offset_1 + tl.arange(0, 1).to(tl.int32)
         for offset_2 in range(0, 64, _BLOCK_SIZE_2):
             indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
             for offset_3 in range(0, 16, _BLOCK_SIZE_3):
@@ -368,11 +363,11 @@ def _grid_2d_idx_nested_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SI
                     indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
                     acc_copy = acc
                     acc_copy_0 = acc_copy
-                    load = tl.load(x + (indices_0 * 8192 + indices_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
+                    load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
                     load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
                     acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
                 v_0 = acc.to(tl.float16)
-                tl.store(out + (indices_0 * 4096 + indices_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
+                tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
 
 def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor):
     bi, bj, m, k = x.size()