@@ -716,6 +716,127 @@ def _tile_begin_end_make_precompiler(x: torch.Tensor):
716
716
return make_precompiler(_tile_begin_end_kernel)(x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""" ,
717
717
)
718
718
719
    def test_range_as_grid_basic(self):
        """Test that range() works as an alias for hl.grid() in device code."""

        @helion.kernel(use_default_config=True)
        def range_kernel(x: torch.Tensor) -> torch.Tensor:
            batch = x.size(0)
            out = x.new_zeros(batch)
            # Outer tiled loop over the batch dimension; the inner plain
            # range() loop is the feature under test.
            for tile_batch in hl.tile(batch):
                for i in range(10):  # This should work now as alias for hl.grid(10)
                    out[tile_batch] += x[tile_batch] + i
            return out

        x = torch.randn(35, device=DEVICE)

        # Reference: sum over i of (x + i) = 10*x + sum(0..9) = 10*x + 45
        expected = 10 * x + 45

        # `code` (the generated Triton source) is intentionally unused here;
        # this test only checks numerics, not codegen.
        code, result = code_and_output(range_kernel, (x,))
        torch.testing.assert_close(result, expected)
739
    def test_range_with_begin_end(self):
        """Test that range(begin, end) works as alias for hl.grid(begin, end)."""

        @helion.kernel(use_default_config=True)
        def range_begin_end_kernel(x: torch.Tensor) -> torch.Tensor:
            batch = x.size(0)
            out = x.new_zeros(batch)
            for tile_batch in hl.tile(batch):
                for i in range(2, 7):  # range(begin, end)
                    out[tile_batch] += x[tile_batch] * i
            return out

        x = torch.randn(20, device=DEVICE)

        # Reference: x * sum(range(2, 7)) = x * sum(2,3,4,5,6) = x * 20
        expected = x * 20

        # Only the numeric result is validated; generated code is discarded.
        code, result = code_and_output(range_begin_end_kernel, (x,))
        torch.testing.assert_close(result, expected)
758
+
759
    def test_range_with_step(self):
        """Test that range(begin, end, step) works as alias for hl.grid(begin, end, step)."""

        @helion.kernel(use_default_config=True)
        def range_step_kernel(x: torch.Tensor) -> torch.Tensor:
            batch = x.size(0)
            out = x.new_zeros(batch)
            for tile_batch in hl.tile(batch):
                for i in range(1, 10, 2):  # range(begin, end, step)
                    out[tile_batch] += x[tile_batch] / i
            return out

        x = torch.randn(6, device=DEVICE)

        # Reference: x * sum(1/i for i in range(1, 10, 2)) = x * sum(1/1, 1/3, 1/5, 1/7, 1/9)
        # = x * (1 + 1/3 + 1/5 + 1/7 + 1/9) = x * sum([1, 1/3, 1/5, 1/7, 1/9])
        reciprocal_sum = sum(1.0 / i for i in range(1, 10, 2))
        expected = x * reciprocal_sum

        code, result = code_and_output(range_step_kernel, (x,))
        torch.testing.assert_close(result, expected)
        # Golden-output check: the range step must appear as _BLOCK_SIZE_1 in
        # the generated Triton loop. The expected string below is compared
        # verbatim by assertExpectedInline — do not reformat it.
        self.assertExpectedInline(
            code,
            """\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _range_step_kernel_kernel(out, x, out_stride_0, x_stride_0, batch, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    mask_0 = indices_0 < batch
    for offset_1 in range(1, 10, _BLOCK_SIZE_1):
        load = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
        load_1 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
        v_0 = offset_1.to(tl.float32)
        v_1 = load_1 / v_0
        v_2 = load + v_1
        tl.store(out + indices_0 * out_stride_0, v_2, mask_0)

def range_step_kernel(x: torch.Tensor):
    batch = x.size(0)
    out = x.new_zeros(batch)
    _BLOCK_SIZE_0 = 8
    _BLOCK_SIZE_1 = 2
    _range_step_kernel_kernel[triton.cdiv(batch, _BLOCK_SIZE_0),](out, x, out.stride(0), x.stride(0), batch, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out

def _range_step_kernel_make_precompiler(x: torch.Tensor):
    batch = x.size(0)
    out = x.new_zeros(batch)
    _BLOCK_SIZE_0 = 8
    _BLOCK_SIZE_1 = 2
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_range_step_kernel_kernel)(out, x, out.stride(0), x.stride(0), batch, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
        )
819
+
820
    def test_range_with_tensor_size(self):
        """Test that range(tensor.size(dim)) works with dynamic tensor dimensions."""

        @helion.kernel(use_default_config=True)
        def range_tensor_size_kernel(x: torch.Tensor) -> torch.Tensor:
            batch = x.size(0)
            out = x.new_zeros(batch)
            for tile_batch in hl.tile(batch):
                # Trip count comes from a tensor dimension rather than a
                # Python-literal bound.
                for _ in range(x.size(1)):  # Use tensor dimension in range
                    out[tile_batch] += x[tile_batch, 0]  # Just use first column
            return out

        x = torch.randn(8, 5, device=DEVICE)  # 8 rows, 5 columns

        # Reference: Each row adds x[row, 0] for x.size(1) times = x[:, 0] * x.size(1)
        expected = x[:, 0] * x.size(1)

        # Only the numeric result is validated; generated code is discarded.
        code, result = code_and_output(range_tensor_size_kernel, (x,))
        torch.testing.assert_close(result, expected)
839
+
# Allow running this test file directly (e.g. `python test_file.py`).
if __name__ == "__main__":
    unittest.main()
0 commit comments