diff --git a/multipy/runtime/test_compat.py b/multipy/runtime/test_compat.py index 8a73db66..bbf234c7 100644 --- a/multipy/runtime/test_compat.py +++ b/multipy/runtime/test_compat.py @@ -7,6 +7,7 @@ import unittest import torch +import torch._dynamo class TestCompat(unittest.TestCase): @@ -22,31 +23,41 @@ def test_pytorch3d(self): def test_hf_tokenizers(self): import tokenizers # noqa: F401 - @unittest.skip("torch.Library is not supported") def test_torchdynamo_eager(self): - import torch._dynamo as torchdynamo - @torchdynamo.optimize("eager") + torch._dynamo.reset() + def fn(x, y): a = torch.cos(x) b = torch.sin(y) return a + b - fn(torch.randn(10), torch.randn(10)) + c_fn = torch.compile(fn, backend="eager") + c_fn(torch.randn(10), torch.randn(10)) - @unittest.skip("torch.Library is not supported") def test_torchdynamo_ofi(self): - import torch._dynamo as torchdynamo - torchdynamo.reset() + torch._dynamo.reset() + + def fn(x, y): + a = torch.cos(x) + b = torch.sin(y) + return a + b + + c_fn = torch.compile(fn, backend="ofi") + c_fn(torch.randn(10), torch.randn(10)) + + def test_torchdynamo_inductor(self): + + torch._dynamo.reset() - @torchdynamo.optimize("ofi") def fn(x, y): a = torch.cos(x) b = torch.sin(y) return a + b - fn(torch.randn(10), torch.randn(10)) + c_fn = torch.compile(fn) + c_fn(torch.randn(10), torch.randn(10)) if __name__ == "__main__":