@@ -200,13 +200,13 @@ def _(
         if bs is None:
             results.append(TileIndexType.allocate(size, origin))
         elif isinstance(bs, int):
-            results.append(TileIndexType.allocate_fixed(size, bs, origin))
+            results.append(TileIndexType.allocate(size, origin, bs))
         elif isinstance(bs, torch.SymInt):
             from helion._compiler.compile_environment import CompileEnvironment

             index = CompileEnvironment.current().get_block_id(bs)
             if index is None:
-                results.append(TileIndexType.allocate_fixed(size, bs, origin))
+                results.append(TileIndexType.allocate(size, origin, bs))
             else:
                 results.append(TileIndexType(origin=origin, block_id=index))
                 CompileEnvironment.current().block_sizes[index].mark_alternate_size(
@@ -289,63 +289,104 @@ def _codegen_loop_helper(
 @_decorators.api(
     is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
 )
-def grid(sizes: int, /) -> Iterator[torch.SymInt]: ...
+def grid(
+    begin_or_end: int | torch.Tensor,
+    end_or_none: int | torch.Tensor | None = None,
+    /,
+    step: object = None,
+) -> Iterator[torch.SymInt]: ...


 @overload
 @_decorators.api(
     is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
 )
-def grid(sizes: Sequence[int], /) -> Iterator[Sequence[torch.SymInt]]: ...
+def grid(
+    begin_or_end: Sequence[int | torch.Tensor],
+    end_or_none: Sequence[int | torch.Tensor] | None = None,
+    /,
+    step: object = None,
+) -> Iterator[Sequence[torch.SymInt]]: ...


 @_decorators.api(
     is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True
 )
 def grid(
-    sizes: int | Sequence[int],
+    begin_or_end: int | torch.Tensor | Sequence[int | torch.Tensor],
+    end_or_none: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None,
     /,
+    step: object = None,
 ) -> Iterator[torch.SymInt] | Iterator[Sequence[torch.SymInt]]:  # type: ignore[type-arg]
-    """Iterate over *individual* indices of the given iteration space.
+    """Iterate over individual indices of the given iteration space.

     Semantics are equivalent to

-        for i in hl.tile(size, block_size=1):
+        for i in hl.tile(...):
             ...

     but `i` will be a scalar (`torch.SymInt`), not a 1-element tensor.
-    """

+    When used at the top level of a function, this becomes the grid of the kernel.
+    Otherwise, it becomes a loop in the output kernel.
+
+    Similar to `range()` there are multiple forms of this function:
+        grid(end) iterates from 0 to `end - 1`, with step size 1.
+        grid(begin, end) iterates from `begin` to `end - 1`, with step size 1.
+        grid(begin, end, step) iterates from `begin` to `end - 1`, with the given step size.
+        grid(end, step=step) iterates from 0 to `end - 1`, with the given step size.
+    """
     raise exc.NotInsideKernel


 @_decorators.type_propagation(grid)
-def _(sizes: TypeInfo, *, origin: Origin) -> TypeInfo:
+def _(
+    begin_or_end: TypeInfo,
+    end_or_none: TypeInfo | None = None,
+    /,
+    step: TypeInfo | None = None,
+    *,
+    origin: Origin,
+) -> TypeInfo:
     parent = ExtendedAST.current()[-2]
     if not isinstance(parent, ast.For):
         raise exc.LoopFunctionNotInFor("grid")
-    try:
-        proxy_sizes = sizes.proxy()
-        if not (
-            isinstance(proxy_sizes, (int, torch.SymInt))
-            or (
-                isinstance(proxy_sizes, (list, tuple))
-                and all(isinstance(x, (int, torch.SymInt)) for x in proxy_sizes)
-            )
-        ):
-            raise NotImplementedError
-    except NotImplementedError:
-        raise exc.TypeInferenceError(
-            f"grid() expected int or list[int], got {sizes!s}"
-        ) from None
+    begin, end = _normalize_begin_end(begin_or_end, end_or_none, origin=origin)
+    proxy_begin = _to_proxy(begin)
+    proxy_end = _to_proxy(end)
+    _check_matching(proxy_begin, proxy_end)
+    if _not_none(step):
+        proxy_step = Tile._tiles_to_sizes(_to_proxy(step))
+        _check_matching(proxy_end, proxy_step)
+    else:
+        proxy_step = begin.tree_map(lambda n: None)

-    if isinstance(proxy_sizes, (int, torch.SymInt)):
-        return IterType(origin, GridIndexType.allocate(proxy_sizes, origin))
+    if unpack := not isinstance(proxy_end, (list, tuple)):
+        proxy_begin = [proxy_begin]
+        proxy_end = [proxy_end]
+        proxy_step = [proxy_step]
+
+    results = []
+    for begin_part, end_part, step_part in zip(
+        proxy_begin, proxy_end, proxy_step, strict=True
+    ):
+        size = end_part - begin_part
+        if isinstance(size, torch.Tensor):
+            size = None  # data dependent size
+        if step_part is None:
+            step_part = 1
+        results.append(GridIndexType.allocate(size, origin, step_part))

-    assert isinstance(proxy_sizes, (list, tuple))
-    elements = [GridIndexType.allocate(s, origin) for s in proxy_sizes]
-    _add_config_choices([x.block_id for x in elements])
-    return IterType(origin, SequenceType(origin, elements))
+    _add_config_choices(
+        [x.block_id for x in results],
+        is_tile=False,
+        has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin),
+    )
+    if unpack:
+        (result,) = results
+    else:
+        result = SequenceType(origin, results)
+    return IterType(origin, result)


 @_decorators.codegen(grid)
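For reference, a minimal usage sketch of the range-style forms documented in the new `grid()` docstring above. This is not part of the commit: the kernel, tensor shapes, and device are illustrative, and it assumes the usual `helion.kernel` and `helion.language as hl` entry points.

```python
import torch
import helion
import helion.language as hl


# Illustrative only: scale every other row, using the new begin/end/step
# form of hl.grid(). Per the docstring, hl.grid(0, m, 2) yields scalar
# indices 0, 2, 4, ... just like range(0, m, 2).
@helion.kernel()
def scale_even_rows(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.zeros_like(x)
    for i in hl.grid(0, m, 2):  # i is a scalar torch.SymInt, not a tile
        for tile_n in hl.tile(n):  # inner tiled loop over columns
            out[i, tile_n] = x[i, tile_n] * 2.0
    return out


# Hypothetical invocation (requires a GPU backend supported by Helion):
# y = scale_even_rows(torch.randn(128, 64, device="cuda"))
```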