|
11 | 11 | --------------------------------------------------------------------------------
|
12 | 12 | """
|
13 | 13 |
|
| 14 | +import os |
14 | 15 | import unittest as ut
|
15 | 16 | import numpy as np
|
16 | 17 | import torch
|
17 | 18 | import pandas as pd
|
18 | 19 | import tempfile
|
19 | 20 |
|
20 |
| -import graphium |
21 | 21 | from graphium.utils.fs import rm, exists, get_size
|
22 | 22 | from graphium.data import GraphOGBDataModule, MultitaskFromSmilesDataModule
|
23 | 23 |
|
24 | 24 | import graphium_cpp
|
25 | 25 |
|
26 | 26 | TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000"
|
27 | 27 |
|
28 |
| - |
29 | 28 | class test_DataModule(ut.TestCase):
|
| 29 | + |
30 | 30 | def test_ogb_datamodule(self):
|
31 | 31 | # other datasets are too large to be tested
|
32 | 32 | dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"]
|
@@ -380,7 +380,7 @@ def test_datamodule_multiple_data_files(self):
|
380 | 380 |
|
381 | 381 | self.assertEqual(len(ds.train_ds), 20)
|
382 | 382 |
|
383 |
| - def test_splits_file(self, tmp_path): |
| 383 | + def test_splits_file(self): |
384 | 384 | # Test single CSV files
|
385 | 385 | csv_file = "tests/data/micro_ZINC_shard_1.csv"
|
386 | 386 | df = pd.read_csv(csv_file)
|
@@ -423,15 +423,17 @@ def test_splits_file(self, tmp_path):
|
423 | 423 | self.assertEqual(len(ds.val_ds), len(split_val))
|
424 | 424 | self.assertEqual(len(ds.test_ds), len(split_test))
|
425 | 425 |
|
426 |
| - # Create a TemporaryFile to save the splits, and test the datamodule |
427 |
| - with tempfile.NamedTemporaryFile(suffix=".pt", dir=tmp_path) as temp: |
| 426 | + try: |
| 427 | + # Create a TemporaryFile to save the splits, and test the datamodule |
| 428 | + temp_file = tempfile.NamedTemporaryFile(suffix=".pt", delete=False) |
| 429 | + |
428 | 430 | # Save the splits
|
429 |
| - torch.save(splits, temp) |
| 431 | + torch.save(splits, temp_file) |
430 | 432 |
|
431 | 433 | # Test the datamodule
|
432 | 434 | task_kwargs = {
|
433 | 435 | "df_path": csv_file,
|
434 |
| - "splits_path": temp.name, |
| 436 | + "splits_path": temp_file.name, |
435 | 437 | "split_val": 0.0,
|
436 | 438 | "split_test": 0.0,
|
437 | 439 | }
|
@@ -468,6 +470,10 @@ def test_splits_file(self, tmp_path):
|
468 | 470 | )
|
469 | 471 | np.testing.assert_array_equal(ds.val_ds.smiles_offsets_tensor, ds2.val_ds.smiles_offsets_tensor)
|
470 | 472 | np.testing.assert_array_equal(ds.test_ds.smiles_offsets_tensor, ds2.test_ds.smiles_offsets_tensor)
|
| 473 | + |
| 474 | + finally: |
| 475 | + temp_file.close() |
| 476 | + os.unlink(temp_file.name) |
471 | 477 |
|
472 | 478 |
|
473 | 479 | if __name__ == "__main__":
|
|
0 commit comments