Skip to content

Commit d2ebb48

Browse files
committed
windows specific bug with temp files
1 parent 62065fb commit d2ebb48

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

tests/test_datamodule.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
--------------------------------------------------------------------------------
1212
"""
1313

14+
import os
1415
import unittest as ut
1516
import numpy as np
1617
import torch
@@ -429,15 +430,17 @@ def test_splits_file(self):
429430
self.assertEqual(len(ds.val_ds), len(split_val))
430431
self.assertEqual(len(ds.test_ds), len(split_test))
431432

432-
# Create a TemporaryFile to save the splits, and test the datamodule
433-
with tempfile.NamedTemporaryFile(suffix=".pt", dir=self.tmp_path) as temp:
433+
try:
434+
# Create a TemporaryFile to save the splits, and test the datamodule
435+
temp_file = tempfile.NamedTemporaryFile(suffix=".pt", dir=self.tmp_path)
436+
434437
# Save the splits
435-
torch.save(splits, temp)
438+
torch.save(splits, temp_file)
436439

437440
# Test the datamodule
438441
task_kwargs = {
439442
"df_path": csv_file,
440-
"splits_path": temp.name,
443+
"splits_path": temp_file.name,
441444
"split_val": 0.0,
442445
"split_test": 0.0,
443446
}
@@ -474,6 +477,10 @@ def test_splits_file(self):
474477
)
475478
np.testing.assert_array_equal(ds.val_ds.smiles_offsets_tensor, ds2.val_ds.smiles_offsets_tensor)
476479
np.testing.assert_array_equal(ds.test_ds.smiles_offsets_tensor, ds2.test_ds.smiles_offsets_tensor)
480+
481+
finally:
482+
temp_file.close()
483+
os.unlink(temp_file.name)
477484

478485

479486
if __name__ == "__main__":

0 commit comments

Comments
 (0)