Skip to content

Commit 474c802

Browse files
committed
addressing tempfile problem on windows
1 parent e796093 commit 474c802

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

tests/test_datamodule.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717
import torch
1818
import pandas as pd
1919
import tempfile
20-
import pytest
2120

22-
import graphium
2321
from graphium.utils.fs import rm, exists, get_size
2422
from graphium.data import GraphOGBDataModule, MultitaskFromSmilesDataModule
2523

@@ -29,10 +27,6 @@
2927

3028
class test_DataModule(ut.TestCase):
3129

32-
@pytest.fixture
33-
def _setup_tmp_path(self, tmp_path):
34-
self.tmp_path = tmp_path
35-
3630
def test_ogb_datamodule(self):
3731
# other datasets are too large to be tested
3832
dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"]
@@ -386,7 +380,6 @@ def test_datamodule_multiple_data_files(self):
386380

387381
self.assertEqual(len(ds.train_ds), 20)
388382

389-
@pytest.mark.usefixtures("_setup_tmp_path")
390383
def test_splits_file(self):
391384
# Test single CSV files
392385
csv_file = "tests/data/micro_ZINC_shard_1.csv"
@@ -432,7 +425,7 @@ def test_splits_file(self):
432425

433426
try:
434427
# Create a TemporaryFile to save the splits, and test the datamodule
435-
temp_file = tempfile.NamedTemporaryFile(suffix=".pt", dir=self.tmp_path)
428+
temp_file = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
436429

437430
# Save the splits
438431
torch.save(splits, temp_file)
@@ -479,7 +472,8 @@ def test_splits_file(self):
479472
np.testing.assert_array_equal(ds.test_ds.smiles_offsets_tensor, ds2.test_ds.smiles_offsets_tensor)
480473

481474
finally:
482-
temp_file.close()
475+
temp_file.close()
476+
os.unlink(temp_file.name)
483477

484478

485479
if __name__ == "__main__":

0 commit comments

Comments
 (0)