File tree 1 file changed +11
-4
lines changed
1 file changed +11
-4
lines changed Original file line number Diff line number Diff line change 11
11
--------------------------------------------------------------------------------
12
12
"""
13
13
14
+ import os
14
15
import unittest as ut
15
16
import numpy as np
16
17
import torch
@@ -429,15 +430,17 @@ def test_splits_file(self):
429
430
self .assertEqual (len (ds .val_ds ), len (split_val ))
430
431
self .assertEqual (len (ds .test_ds ), len (split_test ))
431
432
432
- # Create a TemporaryFile to save the splits, and test the datamodule
433
- with tempfile .NamedTemporaryFile (suffix = ".pt" , dir = self .tmp_path ) as temp :
433
+ try :
434
+ # Create a TemporaryFile to save the splits, and test the datamodule
435
+ temp_file = tempfile .NamedTemporaryFile (suffix = ".pt" , dir = self .tmp_path )
436
+
434
437
# Save the splits
435
- torch .save (splits , temp )
438
+ torch .save (splits , temp_file )
436
439
437
440
# Test the datamodule
438
441
task_kwargs = {
439
442
"df_path" : csv_file ,
440
- "splits_path" : temp .name ,
443
+ "splits_path" : temp_file .name ,
441
444
"split_val" : 0.0 ,
442
445
"split_test" : 0.0 ,
443
446
}
@@ -474,6 +477,10 @@ def test_splits_file(self):
474
477
)
475
478
np .testing .assert_array_equal (ds .val_ds .smiles_offsets_tensor , ds2 .val_ds .smiles_offsets_tensor )
476
479
np .testing .assert_array_equal (ds .test_ds .smiles_offsets_tensor , ds2 .test_ds .smiles_offsets_tensor )
480
+
481
+ finally :
482
+ temp_file .close ()
483
+ os .unlink (temp_file .name )
477
484
478
485
479
486
if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments