Skip to content

Commit 3af7fb5

Browse files
committed
Add InvalidDeviceForLoop exception type
stack-info: PR: #205, branch: jansel/stack/57
1 parent 5d1a3ac commit 3af7fb5

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

helion/_compiler/device_ir.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,8 @@ def visit_For(self, node: ast.For) -> None:
507507
assert not node.orelse
508508
assert isinstance(node.iter, ExtendedAST)
509509
iter_type = node.iter._type_info
510-
assert isinstance(iter_type, IterType)
510+
if not isinstance(iter_type, IterType):
511+
raise exc.InvalidDeviceForLoop(iter_type)
511512
inner_type: TypeInfo = iter_type.inner
512513
if node._loop_type == LoopType.GRID:
513514
self._assign(node.target, inner_type.proxy())

helion/exc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,10 @@ class UndefinedVariable(BaseError):
222222
message = "{} is not defined."
223223

224224

225+
class InvalidDeviceForLoop(BaseError):
226+
message = "For loops on device must use `hl.tile` or `hl.grid`, got {0!s}."
227+
228+
225229
class StarredArgsNotSupportedOnDevice(BaseError):
226230
message = "*/** args are not supported inside the `hl.tile` or `hl.grid` loop."
227231

test/test_errors.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,18 @@ def fn(x: torch.Tensor) -> torch.Tensor:
109109
r"Expected ndim=1, but got ndim=2.*You have too many indices",
110110
):
111111
code_and_output(fn, (torch.randn(8, device=DEVICE),))
112+
113+
def test_invalid_device_for_loop(self):
114+
"""Test that InvalidDeviceForLoop is raised for invalid for loops on device."""
115+
116+
@helion.kernel()
117+
def fn(x: torch.Tensor) -> torch.Tensor:
118+
batch = x.size(0)
119+
out = x.new_empty(batch)
120+
for tile_batch in hl.tile(batch):
121+
for i in range(10):
122+
out[tile_batch] = x[tile_batch] + i
123+
return out
124+
125+
with self.assertRaises(helion.exc.InvalidDeviceForLoop):
126+
code_and_output(fn, (torch.randn(8, device=DEVICE),))

0 commit comments

Comments
 (0)