@@ -138,6 +138,45 @@ def codegen_preamble(self, state: CodegenState) -> None:
138
138
def compact_shape (self , shapes : list [CompactedShape ]) -> list [CompactedShape ]:
139
139
raise NotImplementedError
140
140
141
def _create_block_id_info_dict(
    self, state: CodegenState, use_proxy_ends: bool = False
) -> dict[int, LoopDimInfo]:
    """Build the ``block_id -> LoopDimInfo`` mapping for ``self.block_ids``.

    Args:
        state: Codegen state. ``state.proxy_args`` is read when
            ``use_proxy_ends`` is true; ``state.sympy_expr`` is called
            otherwise.
        use_proxy_ends: Selects where each end bound comes from.
            When false (the default), the end expression is
            ``env.block_sizes[block_id].numel`` and ``end_var_name`` is the
            result of ``state.sympy_expr`` on that expression.
            When true, the end bounds are the third element of
            ``state.proxy_args``; ``end_var_name`` is ``None`` and
            ``end_expr`` is ``_to_sympy(end)`` for ``int``/``torch.SymInt``
            values and ``None`` for anything else.

    Returns:
        A dict with one ``LoopDimInfo`` entry per id in ``self.block_ids``.

    Raises:
        AssertionError: if the proxy end bounds are not a ``list``.
        ValueError: if ``state.proxy_args`` does not have exactly four
            elements, or the proxy end bounds and ``self.block_ids`` differ
            in length (``zip(..., strict=True)``).
    """
    env = CompileEnvironment.current()
    infos: dict[int, LoopDimInfo] = {}

    if not use_proxy_ends:
        # Default path: bounds come from the environment's block sizes.
        for bid in self.block_ids:
            numel = env.block_sizes[bid].numel
            infos[bid] = LoopDimInfo(
                end_var_name=state.sympy_expr(numel), end_expr=numel
            )
        return infos

    # Proxy path: only the third of the four proxy args is used here.
    _arg0, _arg1, ends, _arg3 = state.proxy_args
    assert isinstance(ends, list)
    for bid, bound in zip(self.block_ids, ends, strict=True):
        # Only integer-like bounds get a sympy expression; others stay None.
        bound_expr = (
            _to_sympy(bound) if isinstance(bound, (int, torch.SymInt)) else None
        )
        infos[bid] = LoopDimInfo(end_var_name=None, end_expr=bound_expr)
    return infos
168
+
169
def _setup_block_size_constexpr(
    self, state: CodegenState, block_size_var: str, block_size: SymIntLike
) -> None:
    """Emit a host-side assignment for a constexpr block-size variable.

    When ``state.device_function.constexpr_arg(block_size_var)`` returns a
    truthy value, a statement of the form ``<block_size_var> = <literal>`` is
    appended to ``state.codegen.host_statements``, where ``<literal>`` is
    ``HostFunction.current().literal_expr(block_size)``. Otherwise nothing is
    emitted.

    Args:
        state: Codegen state whose device function and host statement list
            are used.
        block_size_var: Name of the block-size variable to assign on the host.
        block_size: Value rendered into the assignment's right-hand side.
    """
    if not state.device_function.constexpr_arg(block_size_var):
        return

    host_statements = state.codegen.host_statements
    literal = HostFunction.current().literal_expr(block_size)
    host_statements.append(statement_from_string(f"{block_size_var} = {literal}"))
179
+
141
180
142
181
class BlockSizeTileStrategy (TileStrategy ):
143
182
def __init__ (
@@ -265,19 +304,15 @@ def block_size_var(self, block_idx: int) -> str:
265
304
def _codegen_common (
266
305
self , state : CodegenState
267
306
) -> tuple [str , str , sympy .Expr , list [ast .AST ]]:
307
+ offsets_var = self .new_var ("offsets" , dce = True )
308
+ block_size_var = self .block_size_var (- 1 )
309
+ self ._setup_block_size_constexpr (state , block_size_var , self .block_size )
268
310
block_ids = self .block_ids
269
311
env = CompileEnvironment .current ()
270
312
total_numel = sympy .S .One
271
- offsets_var = self .new_var ("offsets" , dce = True )
272
- block_size_var = self .block_size_var (- 1 )
273
313
statements = []
274
- if state .device_function .constexpr_arg (block_size_var ):
275
- block_size_str = HostFunction .current ().literal_expr (self .block_size )
276
- state .codegen .host_statements .append (
277
- statement_from_string (f"{ block_size_var } = { block_size_str } " )
278
- )
314
+
279
315
for i , block_idx in enumerate (self ._reorder (block_ids )):
280
- # need to get the block size
281
316
numel = env .block_sizes [block_idx ].numel
282
317
block_index_var = self .index_var (block_idx )
283
318
expr = offsets_var
@@ -316,13 +351,7 @@ def codegen_grid(self) -> ast.AST:
316
351
317
352
state .device_function .set_pid (TmpPid ())
318
353
319
- block_id_to_info = {}
320
- for block_id in self .block_ids :
321
- end_expr = env .block_sizes [block_id ].numel
322
- end_var_name = state .sympy_expr (end_expr )
323
- block_id_to_info [block_id ] = LoopDimInfo (
324
- end_var_name = end_var_name , end_expr = end_expr
325
- )
354
+ block_id_to_info = self ._create_block_id_info_dict (state )
326
355
return DeviceGridState (self , block_id_to_info = block_id_to_info )
327
356
328
357
def codegen_device_loop (self , state : CodegenState ) -> DeviceLoopState :
@@ -348,18 +377,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
348
377
orelse = [],
349
378
type_comment = None ,
350
379
)
351
- # Create block_id_to_info with end bounds
352
- block_id_to_info = {}
353
- _ , _ , ends , _ = state .proxy_args
354
- assert isinstance (ends , list )
355
- for block_idx , end in zip (self .block_ids , ends , strict = True ):
356
- if isinstance (end , (int , torch .SymInt )):
357
- end_expr = _to_sympy (end )
358
- else :
359
- end_expr = None
360
- block_id_to_info [block_idx ] = LoopDimInfo (
361
- end_var_name = None , end_expr = end_expr
362
- )
380
+ block_id_to_info = self ._create_block_id_info_dict (state , use_proxy_ends = True )
363
381
364
382
return DeviceLoopState (
365
383
self ,
@@ -430,8 +448,6 @@ def __init__(
430
448
def codegen_grid (self , state : CodegenState ) -> DeviceGridState :
431
449
block_ids = self .block_ids
432
450
env = CompileEnvironment .current ()
433
- device_function = state .device_function
434
- dtype = env .triton_index_type ()
435
451
block_sizes = self .block_size
436
452
assert len (block_sizes ) == len (block_ids )
437
453
if isinstance (state .device_function .pid , SharedProgramID ):
@@ -440,31 +456,47 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
440
456
pids = self .select_pid_strategy ()
441
457
if isinstance (state .device_function .pid , SharedProgramID ):
442
458
pids .shared_pid_var = state .device_function .pid .shared_pid_var
443
- for i , (block_idx , block_size ) in enumerate (
444
- reversed (self ._reorder ([* zip (block_ids , block_sizes , strict = True )]))
459
+
460
+ assert state .ast_args is None
461
+ assert len (state .proxy_args ) == 3
462
+ if state .proxy_args [1 ] is None :
463
+ begins = [0 ] * len (block_ids )
464
+ else :
465
+ begins = state .proxy_args [0 ]
466
+ if not isinstance (begins , (list , tuple )):
467
+ begins = [begins ]
468
+ assert len (begins ) == len (block_ids )
469
+
470
+ for i , (block_idx , block_size , begin ) in enumerate (
471
+ reversed (self ._reorder ([* zip (block_ids , block_sizes , begins , strict = True )]))
445
472
):
446
473
numel = env .block_sizes [block_idx ].numel
474
+ device_function = state .device_function
475
+ dtype = env .triton_index_type ()
447
476
offset_var = self .offset_var (block_idx )
448
477
index_var = self .index_var (block_idx )
449
478
pid_var = device_function .new_var (f"pid_{ i } " , dce = True )
479
+
480
+ begin_offset_expr = ""
481
+ if begin != 0 :
482
+ begin_ast = self ._to_ast (begin , to_dtype = dtype )
483
+ begin_offset_expr = (
484
+ f"{ state .codegen .lift (begin_ast , dce = True , prefix = 'begin' ).id } + "
485
+ )
486
+
450
487
if block_size != 1 :
451
488
block_size_var = self .block_size_var (block_idx )
452
489
assert block_size_var is not None
453
- # TODO(jansel): need to check for conflict with user variable names since block_size_var is on host
454
- if state .device_function .constexpr_arg (block_size_var ):
455
- state .codegen .host_statements .append (
456
- statement_from_string (
457
- f"{ block_size_var } = { HostFunction .current ().literal_expr (block_size )} "
458
- )
459
- )
460
- state .add_statement (f"{ offset_var } = { pid_var } * { block_size_var } " )
490
+ self ._setup_block_size_constexpr (state , block_size_var , block_size )
491
+ state .add_statement (
492
+ f"{ offset_var } = { begin_offset_expr } { pid_var } * { block_size_var } "
493
+ )
461
494
state .add_statement (
462
495
f"{ index_var } = ({ offset_var } + tl.arange(0, ({ block_size_var } ))).to({ dtype } )"
463
496
)
464
497
else :
465
498
block_size_var = "1"
466
- dtype = env .triton_index_type ()
467
- state .add_statement (f"{ offset_var } = { pid_var } " )
499
+ state .add_statement (f"{ offset_var } = { begin_offset_expr } { pid_var } " )
468
500
state .add_statement (
469
501
f"{ index_var } = { offset_var } + tl.zeros([1], { dtype } )"
470
502
)
@@ -483,14 +515,7 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState:
483
515
else :
484
516
state .device_function .set_pid (pids )
485
517
486
- # Extract end_var_name from end bound expressions
487
- block_id_to_info = {}
488
- for block_id in self .block_ids :
489
- end_expr = env .block_sizes [block_id ].numel
490
- end_var_name = state .sympy_expr (end_expr )
491
- block_id_to_info [block_id ] = LoopDimInfo (
492
- end_var_name = end_var_name , end_expr = end_expr
493
- )
518
+ block_id_to_info = self ._create_block_id_info_dict (state )
494
519
return DeviceGridState (self , block_id_to_info = block_id_to_info )
495
520
496
521
def select_pid_strategy (self ) -> ProgramIDs :
@@ -509,6 +534,8 @@ def _to_ast(self, x: object, to_dtype: str | None = None) -> ast.AST:
509
534
from .device_function import DeviceFunction
510
535
511
536
return expr_from_string (DeviceFunction .current ().sympy_expr (x ))
537
+ if isinstance (x , torch .SymInt ):
538
+ return self ._to_ast (x ._sympy_ ())
512
539
raise NotImplementedError (f"{ type (x )} is not implemented." )
513
540
514
541
def codegen_device_loop (self , state : CodegenState ) -> DeviceLoopState :
@@ -534,12 +561,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
534
561
if block_size != 1 :
535
562
block_size_var = self .block_size_var (block_idx )
536
563
assert block_size_var is not None
537
- if state .device_function .constexpr_arg (block_size_var ):
538
- state .codegen .host_statements .append (
539
- statement_from_string (
540
- f"{ block_size_var } = { HostFunction .current ().literal_expr (block_size )} "
541
- )
542
- )
564
+ self ._setup_block_size_constexpr (state , block_size_var , block_size )
543
565
else :
544
566
block_size_var = "1"
545
567
end_var_name = state .codegen .lift (
0 commit comments