Skip to content

Commit 3a10b45

Browse files
committed
fixing CI on Windows & re-enabling other OS'
1 parent 17346f8 commit 3a10b45

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

.github/workflows/test.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ jobs:
1919
strategy:
2020
fail-fast: false
2121
matrix:
22-
os: [windows-latest] # ubuntu-latest, macos-latest
23-
python-version: ["3.11"] # -> Will re-enable support for py312 once pyg is released, "3.10",
22+
os: [ubuntu-latest, macos-latest, windows-latest]
23+
python-version: ["3.10", "3.11"] # -> Will re-enable support for py312 once pyg is released, "3.10",
2424

2525
runs-on: ${{ matrix.os }}
2626
timeout-minutes: 30

tests/test_datamodule.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,22 @@
1111
--------------------------------------------------------------------------------
1212
"""
1313

14+
import os
1415
import unittest as ut
1516
import numpy as np
1617
import torch
1718
import pandas as pd
1819
import tempfile
1920

20-
import graphium
2121
from graphium.utils.fs import rm, exists, get_size
2222
from graphium.data import GraphOGBDataModule, MultitaskFromSmilesDataModule
2323

2424
import graphium_cpp
2525

2626
TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000"
2727

28-
2928
class test_DataModule(ut.TestCase):
29+
3030
def test_ogb_datamodule(self):
3131
# other datasets are too large to be tested
3232
dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"]
@@ -380,7 +380,7 @@ def test_datamodule_multiple_data_files(self):
380380

381381
self.assertEqual(len(ds.train_ds), 20)
382382

383-
def test_splits_file(self, tmp_path):
383+
def test_splits_file(self):
384384
# Test single CSV files
385385
csv_file = "tests/data/micro_ZINC_shard_1.csv"
386386
df = pd.read_csv(csv_file)
@@ -423,15 +423,17 @@ def test_splits_file(self, tmp_path):
423423
self.assertEqual(len(ds.val_ds), len(split_val))
424424
self.assertEqual(len(ds.test_ds), len(split_test))
425425

426-
# Create a TemporaryFile to save the splits, and test the datamodule
427-
with tempfile.NamedTemporaryFile(suffix=".pt", dir=tmp_path) as temp:
426+
try:
427+
# Create a TemporaryFile to save the splits, and test the datamodule
428+
temp_file = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
429+
428430
# Save the splits
429-
torch.save(splits, temp)
431+
torch.save(splits, temp_file)
430432

431433
# Test the datamodule
432434
task_kwargs = {
433435
"df_path": csv_file,
434-
"splits_path": temp.name,
436+
"splits_path": temp_file.name,
435437
"split_val": 0.0,
436438
"split_test": 0.0,
437439
}
@@ -468,6 +470,10 @@ def test_splits_file(self, tmp_path):
468470
)
469471
np.testing.assert_array_equal(ds.val_ds.smiles_offsets_tensor, ds2.val_ds.smiles_offsets_tensor)
470472
np.testing.assert_array_equal(ds.test_ds.smiles_offsets_tensor, ds2.test_ds.smiles_offsets_tensor)
473+
474+
finally:
475+
temp_file.close()
476+
os.unlink(temp_file.name)
471477

472478

473479
if __name__ == "__main__":

0 commit comments

Comments
 (0)