
Commit 9a78513

desertfire authored and pytorchmergebot committed
[AOTI] Update test runner to use the new APIs (pytorch#147105)

Summary: Switch to the newer aoti_compile_and_package APIs. Some tests still use the legacy APIs; a follow-up will refactor the internal tests.

Differential Revision: [D69609685](https://our.internmc.facebook.com/intern/diff/D69609685)
Pull Request resolved: pytorch#147105
Approved by: https://github.com/jingsh
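For context, the newer flow that this commit migrates toward exports a module, compiles it into a single self-contained .pt2 package, and loads that package back as a callable. A minimal sketch of that flow, assuming a CUDA device; the module and shapes here are invented for illustration and do not come from the tests:

import torch
from torch._inductor import aoti_compile_and_package, aoti_load_package

class M(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1

model = M().cuda()
example_inputs = (torch.randn(8, 8, device="cuda"),)

# Export the module, then compile and package it into one .pt2 artifact.
ep = torch.export.export(model, example_inputs)
package_path = aoti_compile_and_package(ep)

# Load the package back and invoke it like a regular callable.
compiled = aoti_load_package(package_path)
out = compiled(*example_inputs)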
1 parent: b52a8be

12 files changed (+153, -109 lines)

test/distributed/test_c10d_functional_native.py (+8, -8)
@@ -747,7 +747,7 @@ def func(arg: torch.Tensor) -> torch.Tensor:
         assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
 
         # Test aoti
-        AOTIRunnerUtil.run("cuda", func, (arg,))
+        AOTIRunnerUtil.run(func, (arg,))
         torch.cuda.synchronize()
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@@ -793,7 +793,7 @@ def func(args: list[torch.Tensor]) -> torch.Tensor:
         assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
 
         # Test aoti
-        out = AOTIRunnerUtil.run("cuda", func, (args,))  # noqa: F841
+        out = AOTIRunnerUtil.run(func, (args,))  # noqa: F841
         torch.cuda.synchronize()
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@@ -905,7 +905,7 @@ def func(arg: torch.Tensor) -> torch.Tensor:
         assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
 
         # Test aoti
-        AOTIRunnerUtil.run("cuda", func, (arg,))
+        AOTIRunnerUtil.run(func, (arg,))
         torch.cuda.synchronize()
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@@ -939,7 +939,7 @@ def func(args: list[torch.Tensor]) -> torch.Tensor:
         )
 
         # Test aoti
-        out = AOTIRunnerUtil.run("cuda", func, (args,))  # noqa: F841
+        out = AOTIRunnerUtil.run(func, (args,))  # noqa: F841
         torch.cuda.synchronize()
 
     @unittest.skipIf(not HAS_GPU, "This is a GPU test!")
@@ -961,7 +961,7 @@ def func(arg: torch.Tensor) -> torch.Tensor:
         )
 
         # Test aoti
-        AOTIRunnerUtil.run("cuda", func, (arg,))
+        AOTIRunnerUtil.run(func, (arg,))
         torch.cuda.synchronize()
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@@ -987,7 +987,7 @@ def func(arg: torch.Tensor) -> torch.Tensor:
         )
 
         # Test aoti
-        AOTIRunnerUtil.run("cuda", func, (arg,))
+        AOTIRunnerUtil.run(func, (arg,))
         torch.cuda.synchronize()
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@@ -1023,7 +1023,7 @@ def func(args: list[torch.Tensor]) -> torch.Tensor:
         )
 
         # Test aoti
-        AOTIRunnerUtil.run("cuda", func, (args,))
+        AOTIRunnerUtil.run(func, (args,))
         torch.cuda.synchronize()
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@@ -1108,7 +1108,7 @@ def func(arg: torch.Tensor) -> torch.Tensor:
         )
 
         # Test aoti
-        AOTIRunnerUtil.run("cuda", func, (arg,))
+        AOTIRunnerUtil.run(func, (arg,))
         torch.cuda.synchronize()
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
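Every hunk in the diff makes the same change: the leading "cuda" device string is dropped from AOTIRunnerUtil.run. A plausible shape for the updated helper, assuming it now routes through aoti_compile_and_package and lets the device follow the example inputs; this is a hypothetical paraphrase for illustration, not the actual utility in the test tree:

import torch
from torch._inductor import aoti_compile_and_package, aoti_load_package

class _FnModule(torch.nn.Module):
    # Wraps a plain function so it can go through torch.export.
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args):
        return self.fn(*args)

class AOTIRunnerUtil:  # hypothetical sketch of the new-style helper
    @staticmethod
    def run(fn, example_inputs):
        # No device argument: the compiled artifact runs on whatever
        # device the example inputs already live on.
        ep = torch.export.export(_FnModule(fn), example_inputs)
        compiled = aoti_load_package(aoti_compile_and_package(ep))
        return compiled(*example_inputs)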
