diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 64846bd91..f84debd95 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,8 +16,13 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] - pytorch-version: ["2.0"] + include: + - python-version: "3.10" + pytorch-version: "2.0" + - python-version: "3.11" + pytorch-version: "2.0" + - python-version: "3.12" + pytorch-version: "2.3" runs-on: "ubuntu-latest" timeout-minutes: 30 @@ -49,8 +54,11 @@ jobs: - name: Install library run: python -m pip install --no-deps -e . # `-e` required for correct `coverage` run. - - name: Run tests - run: pytest -m 'not ipu' + - name: Install test dependencies + run: micromamba install -c conda-forge pytdc # Required to run the `test_finetuning.py` + + - name: Install C++ library + run: cd graphium/graphium_cpp && git clone https://github.com/pybind/pybind11.git && export PYTHONPATH=$PYTHONPATH:./pybind11 && python -m pip install . && cd ../.. - name: Test CLI run: graphium --help diff --git a/.github/workflows/test_ipu.yml b/.github/workflows/test_ipu.yml deleted file mode 100644 index 886c4c2b7..000000000 --- a/.github/workflows/test_ipu.yml +++ /dev/null @@ -1,69 +0,0 @@ -name: test-ipu - -on: - push: - branches: ["main"] - tags: ["*"] - pull_request: - branches: - - "*" - - "!gh-pages" - schedule: - - cron: "0 4 * * *" - -jobs: - test-ipu: - strategy: - fail-fast: false - matrix: - python-version: ["3.8"] - pytorch-version: ["2.0"] - - runs-on: "ubuntu-20.04" - timeout-minutes: 30 - - defaults: - run: - shell: bash -l {0} - - name: | - poptorch_env - - python=${{ matrix.python-version }} - - pytorch=${{ matrix.pytorch-version }} - - steps: - - name: Checkout the code - uses: actions/checkout@v3 - - - name: Activate SDK + Install Requirements - run: | - python3 -m pip install --upgrade pip - wget -q -O 'poplar_sdk-ubuntu_20_04-3.3.0-208993bbb7.tar.gz' 'https://downloads.graphcore.ai/direct?package=poplar-poplar_sdk_ubuntu_20_04_3.3.0_208993bbb7-3.3.0&file=poplar_sdk-ubuntu_20_04-3.3.0-208993bbb7.tar.gz' - tar -xzf poplar_sdk-ubuntu_20_04-3.3.0-208993bbb7.tar.gz - python3 -m pip install poplar_sdk-ubuntu_20_04-3.3.0+1403-208993bbb7/poptorch-3.3.0+113432_960e9c294b_ubuntu_20_04-cp38-cp38-linux_x86_64.whl - # Enable Poplar SDK (including Poplar and PopART) - source poplar_sdk-ubuntu_20_04-3.3.0+1403-208993bbb7/enable - - python -c "import poptorch" - - # Download the datafiles (Total ~ 10Mb - nothing compared to the libraries) - wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz - wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz - wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9.csv.gz - - - # Install the IPU specific and graphium requirements - pip install -r requirements_ipu.txt - # Install Graphium in dev mode - python -m pip install --no-deps -e . - python3 -m pytest -m 'not skip_ipu' - - - name: Codecov Upload - uses: codecov/codecov-action@v3 - with: - files: ./coverage.xml - flags: unittests - name: codecov-umbrella - fail_ci_if_error: false - verbose: false - env_vars: ${{ matrix.python-version }},${{ matrix.pytorch-version }} diff --git a/.gitignore b/.gitignore index 289f10a4d..751e6aeb9 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ draft/ scripts-expts/ sweeps/ mup/ +loc-* # Data and predictions graphium/data/ZINC_bench_gnn/ @@ -38,6 +39,7 @@ graphium/data/cache/ graphium/data/b3lyp/ graphium/data/PCQM4Mv2/ graphium/data/PCQM4M/ +graphium/data/largemix/ graphium/data/neurips2023/small-dataset/ graphium/data/neurips2023/large-dataset/ graphium/data/neurips2023/dummy-dataset/ @@ -53,15 +55,6 @@ debug/ change_commits.sh graphium/features/test_new_pes.ipynb -# IPU related ignores and profiler outputs -*.a -*.cbor -*.capnp -*.pop -*.popart -*.pop_cache -*.popef -*.pvti* ############ END graphium Custom GitIgnore ############## diff --git a/LICENSE b/LICENSE index 4cef7c9e1..cbca6ebfd 100644 --- a/LICENSE +++ b/LICENSE @@ -189,6 +189,7 @@ Copyright 2023 Valence Labs Copyright 2023 Recursion Pharmaceuticals Copyright 2023 Graphcore Limited + Copyright 2024 NVIDIA CORPORATION & AFFILIATES Various Academic groups have also contributed to this software under the given license. These include, but are not limited, to the following diff --git a/README.md b/README.md index 53a7172bb..e76e993ee 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,6 @@ [![GitHub Repo stars](https://img.shields.io/github/stars/datamol-io/graphium)](https://github.com/datamol-io/graphium/stargazers) [![GitHub Repo stars](https://img.shields.io/github/forks/datamol-io/graphium)](https://github.com/datamol-io/graphium/network/members) [![test](https://github.com/datamol-io/graphium/actions/workflows/test.yml/badge.svg)](https://github.com/datamol-io/graphium/actions/workflows/test.yml) -[![test-ipu](https://github.com/datamol-io/graphium/actions/workflows/test_ipu.yml/badge.svg)](https://github.com/datamol-io/graphium/actions/workflows/test_ipu.yml) [![release](https://github.com/datamol-io/graphium/actions/workflows/release.yml/badge.svg)](https://github.com/datamol-io/graphium/actions/workflows/release.yml) [![code-check](https://github.com/datamol-io/graphium/actions/workflows/code-check.yml/badge.svg)](https://github.com/datamol-io/graphium/actions/workflows/code-check.yml) [![doc](https://github.com/datamol-io/graphium/actions/workflows/doc.yml/badge.svg)](https://github.com/datamol-io/graphium/actions/workflows/doc.yml) @@ -35,8 +34,6 @@ Visit https://graphium-docs.datamol.io/. ## Installation for developers -### For CPU and GPU developers - Use [`mamba`](https://github.com/mamba-org/mamba), a faster and better alternative to `conda`. If you are using a GPU, we recommend enforcing the CUDA version that you need with `CONDA_OVERRIDE_CUDA=XX.X`. @@ -53,25 +50,67 @@ mamba activate graphium pip install --no-deps -e . ``` -### For IPU developers +## Training a model + +To learn how to train a model, we invite you to look at the documentation, or the jupyter notebooks available [here](https://github.com/datamol-io/graphium/tree/master/docs/tutorials/model_training). + +If you are not familiar with [PyTorch](https://pytorch.org/docs) or [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/), we highly recommend going through their tutorial first. + +## Running an experiment + +### Datasets + +Graphium provides configs for 2 datasets: `toymix` and `largemix`. +`Toymix` uses 3 datasets, which are referenced in datamodule [here](https://github.com/datamol-io/graphium/blob/d12df7e06828fa7d7f8792141d058a60b2b2d258/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml#L59-L102). Its datasets and their splits files can be downloaded from here: + ```bash -# Install Graphcore's SDK and Graphium dependencies in a new environment called `.graphium_ipu` -./install_ipu.sh .graphium_ipu +# Change or make the directory to where the dataset is to be downloaded +cd expts/data/neurips2023/small-dataset + +# QM9 +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9.csv.gz +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt + +# Tox21 +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21_random_splits.p + +# Zinc +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt ``` -The above step needs to be done once. After that, enable the SDK and the environment as follows: +`Largemix` uses datasets referenced in datamodule [here](https://github.com/datamol-io/graphium/blob/e887176f71ee95c3b82f8f6b56c706eaa9765bf1/expts/hydra-configs/tasks/loss_metrics_datamodule/largemix.yaml#L82C1-L155C37). Its datasets and their splits files can be downloaded from here: + ```bash -source enable_ipu.sh .graphium_ipu -``` +# Change or make the directory to where the dataset is to be downloaded +cd ../data/graphium/large-dataset/ -## Training a model +# L1000_VCAP +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt -To learn how to train a model, we invite you to look at the documentation, or the jupyter notebooks available [here](https://github.com/datamol-io/graphium/tree/master/docs/tutorials/model_training). +# L1000_MCF7 +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt -If you are not familiar with [PyTorch](https://pytorch.org/docs) or [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/), we highly recommend going through their tutorial first. +# PCBA_1328 +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt + +# PCQM4M_G25 +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt + +#PCQM4M_N4 +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet +wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt +``` +These datasets can be used further for pretraining. + +### Pretraining -## Running an experiment We have setup Graphium with `hydra` for managing config files. To run an experiment go to the `expts/` folder. For example, to benchmark a GCN on the ToyMix dataset run ```bash graphium-train architecture=toymix tasks=toymix training=toymix model=gcn @@ -86,10 +125,7 @@ Integrating `hydra` also allows you to quickly switch between accelerators. E.g. graphium-train architecture=toymix tasks=toymix training=toymix model=gcn accelerator=gpu ``` automatically selects the correct configs to run the experiment on GPU. -Finally, you can also run a fine-tuning loop: -```bash -graphium-train +finetuning=admet -``` +To use Largemix dataset instead, replace `toymix` to `largemix` in the above commmands. To use a config file you built from scratch you can run ```bash @@ -97,23 +133,38 @@ graphium-train --config-path [PATH] --config-name [CONFIG] ``` Thanks to the modular nature of `hydra` you can reuse many of our config settings for your own experiments with Graphium. -## Preparing the data in advance -The data preparation including the featurization (e.g., of molecules from smiles to pyg-compatible format) is embedded in the pipeline and will be performed when executing `graphium-train [...]`. +### Finetuning -However, when working with larger datasets, it is recommended to perform data preparation in advance using a machine with sufficient allocated memory (e.g., ~400GB in the case of `LargeMix`). Preparing data in advance is also beneficial when running lots of concurrent jobs with identical molecular featurization, so that resources aren't wasted and processes don't conflict reading/writing in the same directory. +After pretraining a model and saving a model checkpoint, the model can be finetuned to a new task -The following command-line will prepare the data and cache it, then use it to train a model. ```bash -# First prepare the data and cache it in `path_to_cached_data` -graphium data prepare ++datamodule.args.processed_graph_data_path=[path_to_cached_data] +graphium-train +finetuning [example-custom OR example-tdc] finetuning.pretrained_model=[model_identifier] +``` -# Then train the model on the prepared data -graphium-train [...] datamodule.args.processed_graph_data_path=[path_to_cached_data] +The `[model_identifier]` serves to identify the pretrained model among those maintained in the `GRAPHIUM_PRETRAINED_MODELS_DICT` in `graphium/utils/spaces.py`, where the `[model_identifier]` maps to the location of the checkpoint of the pretrained model. + +We have provided two example yaml configs under `expts/hydra-configs/finetuning` for finetuning on a custom dataset (`example-custom.yaml`) or for a task from the TDC benchmark collection (`example-tdc.yaml`). + +When using `example-custom.yaml`, to finetune on a custom dataset, we nee to provide the location of the data (`constants.data_path=[path_to_data]`) and the type of task (`constants.task_type=[cls OR reg]`). + +When using `example-tdc.yaml`, to finetune on a TDC task, we only need to provide the task name (`constants.task=[task_name]`) and the task type is inferred automatically. + +Custom datasets to finetune from consist of two files `raw.csv` and `split.csv`. The `raw.csv` contains two columns, namely `smiles` with the smiles strings, and `target` with the corresponding targets. In `split.csv`, three columns `train`, `val`, `test` contain the indices of the rows in `raw.csv`. Examples can be found under `expts/data/finetuning_example-reg` (regression) and `expts/data/finetuning_example-cls` (binary classification). + +### Fingerprinting + +Alternatively, we can also obtain molecular embeddings (fingerprints) from a pretrained model: +```bash +graphium fps create [example-custom OR example-tdc] pretrained.model=[model_identifier] pretrained.layers=[layer_identifiers] ``` -**Note** that `datamodule.args.processed_graph_data_path` can also be specified at `expts/hydra_configs/`. +We have provided two example yaml configs under `expts/hydra-configs/fingerprinting` for extracting fingerprints for a custom dataset (`example-custom.yaml`) or for a dataset from the TDC benchmark collection (`expample-tdc.yaml`). + +After specifiying the `[model_identifier]`, we need to provide a list of layers from that model where we want to read out embeddings via `[layer_identifiers]` (which requires knowledge of the architecture of the pretrained model). + +When using `example-custom.yaml`, the location of the smiles to be embedded needs to be passed via `datamodule.df_path=[path_to_data]`. The data can be passed as a csv/parquet file with a column `smiles`, similar to `expts/data/finetuning_example-reg/raw.csv`. -**Note** that, every time the configs of `datamodule.args.featurization` changes, you will need to run a new data preparation, which will automatically be saved in a separate directory that uses a hash unique to the configs. +When extracting fingerprints for a TDC task using `expample-tdc.yaml`, we need to specify `datamodule.benchmark` and `datamodule.task` instead of `datamodule.df_path`. ## License diff --git a/codecov.yml b/codecov.yml index 60a2c37d6..94e8bd149 100644 --- a/codecov.yml +++ b/codecov.yml @@ -16,8 +16,3 @@ component_management: target: auto branches: - "!main" - individual_components: - - component_id: ipu # this is an identifier that should not be changed - name: ipu # this is a display name, and can be changed freely - paths: - - graphium/ipu/** diff --git a/docs/api/graphium.features.md b/docs/api/graphium.features.md index 758d14135..fa9080700 100644 --- a/docs/api/graphium.features.md +++ b/docs/api/graphium.features.md @@ -5,37 +5,8 @@ Feature extraction and manipulation === "Contents" * [Featurizer](#featurizer) - * [Positional Encoding](#positional-encoding) - * [Properties](#properties) - * [Spectral PE](#spectral-pe) - * [Random Walk PE](#random-walk-pe) - * [NMP](#nmp) ## Featurizer ------------ ::: graphium.features.featurizer - -## Positional Encoding ------------- -::: graphium.features.positional_encoding - - -## Properties ------------- -::: graphium.features.properties - - -## Spectral PE ------------- -::: graphium.features.spectral - - -## Random Walk PE ------------- -::: graphium.features.rw - - -## NMP ------------- -::: graphium.features.nmp diff --git a/docs/api/graphium.finetuning.md b/docs/api/graphium.finetuning.md index 7e2b7f444..fb8c5e418 100644 --- a/docs/api/graphium.finetuning.md +++ b/docs/api/graphium.finetuning.md @@ -10,4 +10,4 @@ Module for finetuning models and doing linear probing (fingerprinting). ::: graphium.finetuning.finetuning_architecture.FinetuningHead -::: graphium.finetuning.fingerprinting.Fingerprinter +::: graphium.fingerprinting.fingerprinter.Fingerprinter diff --git a/docs/api/graphium.ipu.md b/docs/api/graphium.ipu.md deleted file mode 100644 index 2fdf82416..000000000 --- a/docs/api/graphium.ipu.md +++ /dev/null @@ -1,48 +0,0 @@ -graphium.ipu -==================== -Code for adapting to run on IPU - -=== "Contents" - - * [IPU Dataloader](#ipu-dataloader) - * [IPU Losses](#ipu-losses) - * [IPU Metrics](#ipu-metrics) - * [IPU Simple Lightning](#ipu-simple-lightning) - * [IPU Utils](#ipu-utils) - * [IPU Wrapper](#ipu-wrapper) - * [To Dense Batch](#to-dense-batch) - -## IPU Dataloader ------------- -::: graphium.ipu.ipu_dataloader - - -## IPU Losses ------------- -::: graphium.ipu.ipu_losses - - -## IPU Metrics ------------- -::: graphium.ipu.ipu_metrics - - -## IPU Simple Lightning ------------- -::: graphium.ipu.ipu_simple_lightning - - -## IPU Utils ------------- -::: graphium.ipu.ipu_utils - - -## IPU Wrapper ------------- -::: graphium.ipu.ipu_wrapper - - -## To Dense Batch ------------- -::: graphium.ipu.to_dense_batch - diff --git a/docs/api/graphium.utils.md b/docs/api/graphium.utils.md index 5804a060e..632c6ea06 100644 --- a/docs/api/graphium.utils.md +++ b/docs/api/graphium.utils.md @@ -46,10 +46,6 @@ module for utility functions ::: graphium.utils.mup -## Read File ----------------- -::: graphium.utils.read_file - ## Safe Run ---------------- ::: graphium.utils.safe_run diff --git a/docs/cli/graphium-train.md b/docs/cli/graphium-train.md index 0b421be67..b51f7e50a 100644 --- a/docs/cli/graphium-train.md +++ b/docs/cli/graphium-train.md @@ -24,7 +24,7 @@ graphium-train architecture=toymix tasks=toymix training=toymix model=gcn accele automatically selects the correct configs to run the experiment on GPU. Finally, you can also run a fine-tuning loop: ```bash -graphium-train +finetuning=admet +graphium-train +finetuning=example-tdc ``` To use a config file you built from scratch you can run diff --git a/docs/cli/graphium.md b/docs/cli/graphium.md index d90aa8aad..b2d816fad 100644 --- a/docs/cli/graphium.md +++ b/docs/cli/graphium.md @@ -103,7 +103,7 @@ $ graphium finetune [OPTIONS] COMMAND [ARGS]... **Commands**: -* `admet`: Utility CLI to easily fine-tune a model on... +* `tdc`: Utility CLI to easily fine-tune a model on... * `fingerprint`: Endpoint for getting fingerprints from a... ### `graphium finetune admet` @@ -135,7 +135,7 @@ Endpoint for getting fingerprints from a pretrained model. The pretrained model should be a `.ckpt` path or pre-specified, named model within Graphium. The fingerprint layer specification should be of the format `module:layer`. If specified as a list, the fingerprints from all the specified layers will be concatenated. -See the docs of the `graphium.finetuning.fingerprinting.Fingerprinter` class for more info. +See the docs of the `graphium.fingerprinting.fingerprinter.Fingerprinter` class for more info. **Usage**: diff --git a/docs/contribute.md b/docs/contribute.md index b4fef7ce0..4f9f71fae 100644 --- a/docs/contribute.md +++ b/docs/contribute.md @@ -18,21 +18,6 @@ mamba activate graphium pip install --no-deps -e . ``` -### For IPU developers - -Download the SDK and use pypi to create your environment: - -```bash -# Install Graphcore's SDK and Graphium dependencies in a new environment called `.graphium_ipu` -./install_ipu.sh .graphium_ipu -``` - -The above step needs to be done once. After that, enable the SDK and the environment as follows: - -```bash -source enable_ipu.sh .graphium_ipu -``` - ## Build the documentation You can build and serve the documentation locally with: diff --git a/docs/design.md b/docs/design.md index 380ac28e4..43594fc98 100644 --- a/docs/design.md +++ b/docs/design.md @@ -42,7 +42,6 @@ Below are a list of directory and their respective documentations: - [data](https://github.com/datamol-io/graphium/blob/main/graphium/data/README.md) - [features](https://github.com/datamol-io/graphium/tree/main/graphium/features/README.md) - finetuning -- [ipu](https://github.com/datamol-io/graphium/tree/main/graphium/ipu/README.md) - [nn](https://github.com/datamol-io/graphium/tree/main/graphium/nn/README.md) - [trainer](https://github.com/datamol-io/graphium/tree/main/graphium/trainer/README.md) - [utils](https://github.com/datamol-io/graphium/tree/main/graphium/features/README.md) @@ -56,7 +55,7 @@ Hence, we use [hydra](https://hydra.cc/docs/intro/) to enable splitting the conf Examples of possibilities include: -- Switching between accelerators (CPU, GPU and IPU) +- Switching between accelerators (CPU, GPU) - Benchmarking different models on the same dataset - Fine-tuning a pre-trained model on a new dataset diff --git a/docs/index.md b/docs/index.md index ef01e54be..32170ee29 100644 --- a/docs/index.md +++ b/docs/index.md @@ -25,17 +25,6 @@ or pip: pip install graphium ``` -### For IPU -```bash -# Install Graphcore's SDK and Graphium dependencies in a new environment called `.graphium_ipu` -./install_ipu.sh .graphium_ipu -``` - -The above step needs to be done once. After that, enable the SDK and the environment as follows: - -```bash -source enable_ipu.sh .graphium_ipu -``` Finally, you will need to install graphium with pip ```bash diff --git a/docs/tutorials/feature_processing/timing_parallel.ipynb b/docs/tutorials/feature_processing/timing_parallel.ipynb index 477251e71..22e3552e5 100644 --- a/docs/tutorials/feature_processing/timing_parallel.ipynb +++ b/docs/tutorials/feature_processing/timing_parallel.ipynb @@ -476,7 +476,7 @@ ], "metadata": { "kernelspec": { - "display_name": "graphium_ipu", + "display_name": "graphium", "language": "python", "name": "python3" }, diff --git a/docs/tutorials/model_training/simple-molecular-model.ipynb b/docs/tutorials/model_training/simple-molecular-model.ipynb index 26a45cfa0..545109ebc 100644 --- a/docs/tutorials/model_training/simple-molecular-model.ipynb +++ b/docs/tutorials/model_training/simple-molecular-model.ipynb @@ -405,9 +405,9 @@ " zinc:\n", " - mae\n", " loss_fun:\n", - " qm9: mae_ipu\n", - " tox21: bce_logits_ipu\n", - " zinc: mae_ipu\n", + " qm9: mae\n", + " tox21: bce_logits\n", + " zinc: mae\n", " random_seed: ${constants.seed}\n", " optim_kwargs:\n", " lr: 4.0e-05\n", @@ -451,28 +451,28 @@ "metrics:\n", " qm9:\n", " - name: mae\n", - " metric: mae_ipu\n", + " metric: mae\n", " target_nan_mask: null\n", " multitask_handling: flatten\n", " threshold_kwargs: null\n", " - name: pearsonr\n", - " metric: pearsonr_ipu\n", + " metric: pearsonr\n", " threshold_kwargs: null\n", " target_nan_mask: null\n", " multitask_handling: mean-per-label\n", " - name: r2_score\n", - " metric: r2_score_ipu\n", + " metric: r2_score\n", " target_nan_mask: null\n", " multitask_handling: mean-per-label\n", " threshold_kwargs: null\n", " tox21:\n", " - name: auroc\n", - " metric: auroc_ipu\n", + " metric: auroc\n", " task: binary\n", " multitask_handling: mean-per-label\n", " threshold_kwargs: null\n", " - name: avpr\n", - " metric: average_precision_ipu\n", + " metric: average_precision\n", " task: binary\n", " multitask_handling: mean-per-label\n", " threshold_kwargs: null\n", @@ -498,17 +498,17 @@ " th_on_target: true\n", " zinc:\n", " - name: mae\n", - " metric: mae_ipu\n", + " metric: mae\n", " target_nan_mask: null\n", " multitask_handling: flatten\n", " threshold_kwargs: null\n", " - name: pearsonr\n", - " metric: pearsonr_ipu\n", + " metric: pearsonr\n", " threshold_kwargs: null\n", " target_nan_mask: null\n", " multitask_handling: mean-per-label\n", " - name: r2_score\n", - " metric: r2_score_ipu\n", + " metric: r2_score\n", " target_nan_mask: null\n", " multitask_handling: mean-per-label\n", " threshold_kwargs: null\n", diff --git a/enable_ipu.sh b/enable_ipu.sh deleted file mode 100755 index 63dd34987..000000000 --- a/enable_ipu.sh +++ /dev/null @@ -1,29 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Graphcore Limited. -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Graphcore Limited is not liable -for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -#!/bin/bash - -# Default location for the virtual environment -default_venv_name=".graphium_ipu" - -# Allow the user to specify the location of their virtual environment -# If not specified, use the default location -venv_name=${1:-$default_venv_name} - -# Constants -sdk_path="${venv_name}/poplar_sdk-ubuntu_20_04-3.3.0+1403-208993bbb7" - -# Source the virtual environment -source ${venv_name}/bin/activate -source ${sdk_path}/enable \ No newline at end of file diff --git a/env.yml b/env.yml index fa4e89136..f73a5d219 100644 --- a/env.yml +++ b/env.yml @@ -12,7 +12,7 @@ dependencies: - platformdirs # scientific - - numpy + - numpy == 1.26.4 - scipy >=1.4 - pandas >=1.0 - scikit-learn @@ -31,7 +31,7 @@ dependencies: - cuda-version # works also with CPU-only system. - pytorch >=1.12 - lightning >=2.0 - - torchmetrics >=0.7.0,<0.11 + - torchmetrics - ogb - pytorch_geometric >=2.0 # Use `pyg` for Windows instead of `pytorch_geometric` - wandb @@ -41,8 +41,9 @@ dependencies: - pytorch_scatter >=2.0 # chemistry - - rdkit + - rdkit == 2024.03.4 - datamol >=0.10 + - boost # needed by rdkit # Optional deps - sympy @@ -69,7 +70,16 @@ dependencies: - markdown-include - mike >=1.0.0 + # C++ dependencies + - gcc_linux-64 # Sometimes I find that I need to enforce `gcc_linux-64`, but that won't work with Mac, Windows, or Arm-Linux + - gxx_linux-64 # Sometimes I find that I need to enforce `gxx_linux-64`, but that won't work with Mac, Windows, or Arm-Linux + - libgcc + - pybind11 + - boost + + # Optional + - pytdc + - pip: - - lightning-graphcore # optional, for using IPUs only - hydra-core>=1.3.2 - hydra-optuna-sweeper diff --git a/expts/configs/config_gps_10M_pcqm4m.yaml b/expts/configs/config_gps_10M_pcqm4m.yaml index 10faa3b1e..8fc38812e 100644 --- a/expts/configs/config_gps_10M_pcqm4m.yaml +++ b/expts/configs/config_gps_10M_pcqm4m.yaml @@ -1,22 +1,15 @@ -# Testing the mpnn only model with the PCQMv2 dataset on IPU. +# Testing the mpnn only model with the PCQMv2 dataset. constants: name: &name pcqm4mv2_mpnn_4layer seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 20 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 60 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 16 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 120 + # Data handling-related batch_size_training: 64 batch_size_inference: 16 @@ -28,38 +21,8 @@ accelerator: precision: 16 accumulate_grad_batches: 4 - ipu_config: - - deviceIterations(20) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - - ipu_inference_config: # Optional. If not provided, same as `ipu_config` - - deviceIterations(80) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 256 -# batch_size_inference: 64 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" homolumo: @@ -76,10 +39,6 @@ datamodule: split_test: 0.1 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', @@ -115,7 +74,6 @@ datamodule: num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -218,7 +176,7 @@ predictor: metrics_on_progress_bar: homolumo: ["mae", "pearsonr"] loss_fun: - homolumo: mse_ipu + homolumo: mse random_seed: *seed optim_kwargs: lr: 4.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -241,12 +199,12 @@ predictor: metrics: homolumo: - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/configs/config_gps_10M_pcqm4m_mod.yaml b/expts/configs/config_gps_10M_pcqm4m_mod.yaml index e2cdb44c2..1587d2d06 100644 --- a/expts/configs/config_gps_10M_pcqm4m_mod.yaml +++ b/expts/configs/config_gps_10M_pcqm4m_mod.yaml @@ -1,14 +1,13 @@ -# Testing the mpnn only model with the PCQMv2 dataset on IPU. +# Testing the mpnn only model with the PCQMv2 dataset. constants: name: &name pcqm4mv2_mpnn_4layer seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: gpu # cpu or ipu or gpu + type: gpu # cpu or gpu datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" homolumo: @@ -25,10 +24,6 @@ datamodule: split_test: 0.1 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', @@ -84,19 +79,6 @@ datamodule: num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" - - # ipu_dataloader_training_opts: - # mode: async - # max_num_nodes_per_graph: 20 # train max nodes: 20, max_edges: 54 - # max_num_edges_per_graph: 60 - - # ipu_dataloader_inference_opts: - # mode: async - # max_num_nodes_per_graph: 20 # valid max nodes: 51, max_edges: 118 - # max_num_edges_per_graph: 120 - # # test-dev max nodes: 50, max_edges: 116 - # # test-challenge max nodes: 51, max_edges: 106 architecture: model_type: FullGraphMultiTaskNetwork @@ -229,7 +211,7 @@ predictor: metrics_on_progress_bar: homolumo: ["mae", "pearsonr"] loss_fun: - homolumo: mse_ipu + homolumo: mse random_seed: *seed optim_kwargs: lr: 4.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -252,12 +234,12 @@ predictor: metrics: homolumo: - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/configs/config_mpnn_10M_b3lyp.yaml b/expts/configs/config_mpnn_10M_b3lyp.yaml index c385d7689..8403a1ba1 100644 --- a/expts/configs/config_mpnn_10M_b3lyp.yaml +++ b/expts/configs/config_mpnn_10M_b3lyp.yaml @@ -1,22 +1,15 @@ -# Testing the mpnn only model with the b3lyp dataset on IPU. +# Testing the mpnn only model with the b3lyp dataset. constants: name: &name b3lyp_mpnn_4layer seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 20 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 60 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 16 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 120 + # Data handling-related batch_size_training: 64 batch_size_inference: 16 @@ -28,39 +21,8 @@ accelerator: precision: 16 accumulate_grad_batches: 4 - ipu_config: - - deviceIterations(20) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - - ipu_inference_config: # Optional. If not provided, same as `ipu_config` - - deviceIterations(80) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 256 -# batch_size_inference: 64 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" betagap: @@ -88,12 +50,7 @@ datamodule: split_test: 0.1 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/b3lyp/" - dataloading_from: ram featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', @@ -127,7 +84,6 @@ datamodule: num_workers: 0 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -251,8 +207,8 @@ predictor: alphagap: ["mae", "pearsonr"] betagap: ["mae", "pearsonr"] loss_fun: - alphagap: mse_ipu - betagap: mse_ipu + alphagap: mse + betagap: mse random_seed: *seed optim_kwargs: lr: 4.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -275,12 +231,12 @@ predictor: metrics: alphagap: &alpha_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/configs/config_mpnn_pcqm4m.yaml b/expts/configs/config_mpnn_pcqm4m.yaml index 9735f9555..0ba8d3bd5 100644 --- a/expts/configs/config_mpnn_pcqm4m.yaml +++ b/expts/configs/config_mpnn_pcqm4m.yaml @@ -1,14 +1,13 @@ -# Testing the mpnn only model with the PCQMv2 dataset on IPU. +# Testing the mpnn only model with the PCQMv2 dataset. constants: name: &name pcqm4mv2_mpnn_4layer seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: cpu # cpu or ipu or gpu + type: cpu # cpu or gpu datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" homolumo: @@ -26,12 +25,7 @@ datamodule: split_names: ["train", "valid", "test-dev"] # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 20 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "graphium/data/PCQM4Mv2/" - dataloading_from: ram featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', @@ -61,19 +55,6 @@ datamodule: num_workers: 40 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" - - # ipu_dataloader_training_opts: - # mode: async - # max_num_nodes_per_graph: 20 # train max nodes: 20, max_edges: 54 - # max_num_edges_per_graph: 60 - - # ipu_dataloader_inference_opts: - # mode: async - # max_num_nodes_per_graph: 20 # valid max nodes: 51, max_edges: 118 - # max_num_edges_per_graph: 120 - # # test-dev max nodes: 50, max_edges: 116 - # # test-challenge max nodes: 51, max_edges: 106 architecture: model_type: FullGraphMultiTaskNetwork @@ -173,7 +154,7 @@ predictor: metrics_on_progress_bar: homolumo: ["mae", "pearsonr"] loss_fun: - homolumo: mse_ipu + homolumo: mse random_seed: *seed optim_kwargs: lr: 4.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -196,12 +177,12 @@ predictor: metrics: homolumo: - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/data/finetuning_example-cls/raw.csv b/expts/data/finetuning_example-cls/raw.csv new file mode 100644 index 000000000..e5df1c58c --- /dev/null +++ b/expts/data/finetuning_example-cls/raw.csv @@ -0,0 +1,201 @@ +,Drug_ID,smiles,target +0,644675,CC(=O)N(c1ccc2oc(=O)sc2c1)S(=O)(=O)c1cccs1,0 +1,644890,COc1ccccc1C(c1nnnn1C(C)(C)C)N1CCN(Cc2ccncc2)CC1,1 +2,645164,CCC(c1nnnn1CC1CCCO1)N(CCN1CCOCC1)Cc1cc2cc(C)ccc2[nH]c1=O,0 +3,6602688,Br.N=c1n(CCN2CCOCC2)c2ccccc2n1CC(=O)c1ccc(Cl)c(Cl)c1,1 +4,645448,CCC(C)(C)NC(=O)c1ccc2c(c1)N(CC(=O)OC)C(=O)C(C)(C)O2,0 +5,645569,CCc1cc2c(=O)[nH]cnc2s1,0 +6,645818,COc1cccc2c(=O)c(C(=O)NCc3cccs3)c[nH]c12,1 +7,645911,CCc1nnc(SCc2ccc(OC(C)C)cc2)n1N,0 +8,645965,Cc1ccc2c(c1)nnn2C1CCN(CC(=O)N2c3ccccc3CC2C)CC1,1 +9,646164,CCOC(=O)CSC1=C(C#N)C(C)C2=C(CCCC2=O)N1,0 +10,646293,Cc1ccc2cc(C)c3nnc(SCC(=O)NCc4ccco4)n3c2c1,0 +11,646353,CCOC(=O)N1CCN(S(=O)(=O)Cc2ccccc2)CC1,0 +12,646472,CCOC(=O)c1cc2sc(C)cc2n1CC(=O)N1CCN(C(=O)c2ccco2)CC1,0 +13,646515,CCOC(=O)C(NC(C)=O)C(OC(C)=O)c1cccc(N(CCO)CCO)c1,0 +14,646597,CCC(=O)Nc1cc2c(cc1C(=O)c1ccccc1)OCCO2,0 +15,6602690,CN(C)CCNC(=O)C(C(=O)c1ccc(F)cc1)n1ccccc1=O.Cl,0 +16,646768,Cc1cc(C)nc(N2CCC(C(=O)NCCc3ccc(F)cc3)CC2)n1,0 +17,646780,COc1cc(-c2nnc(-c3ccc(N4CCOCC4)cc3)o2)cc(OC)c1OC,0 +18,646897,CN(C)C=C1C(=O)N(C2CCCCC2)C(=O)N(C2CCCCC2)C1=O,0 +19,646955,COC(=O)CN(c1ccccn1)S(=O)(=O)c1ccccc1,0 +20,6398903,Cc1ccc(C)c(/C(O)=C2/C(=O)C(=O)N(CCN3CCOCC3)C2c2ccco2)c1,0 +21,647114,Cn1c(SC2=CS(=O)(=O)c3ccccc32)nc2ccccc21,0 +22,647205,CC(C(=O)NC1CCCC1)N(C(=O)c1snc(C(N)=O)c1N)c1ccc2c(c1)OCCO2,1 +23,647430,Cc1cc(C)n(-c2nc3c(c(=O)[nH]c(=O)n3C)n2C(C)C)n1,0 +24,647727,O=C(Nc1ccc(S(=O)(=O)N2CCCC2)cc1)c1ccc(CN2CCOCC2)cc1,1 +25,6602522,Cl.O=C(CN1CCN(c2ncccn2)CC1)NCCC1=CCCCC1,1 +26,647937,CC(=O)NC1(c2cccc(F)c2)CCN(CC(=O)NC2CCCCC2)CC1,0 +27,647996,Cc1cccc(Nc2nc3c(c(=O)n(C)c(=O)n3C)n2CC(O)CO)c1,0 +28,648175,CCn1c(Cc2ccccc2)nnc1SCC(=O)NC(C)(C)C,0 +29,648282,COc1cc2cc(CN(CCCO)S(=O)(=O)c3ccccc3Cl)c(=O)[nH]c2cc1OC,0 +30,648407,CCc1nnc(NC(=O)CSc2nnc(COc3ccccc3)n2Cc2ccccc2)s1,0 +31,648481,CSc1nc2nc3c(c(=O)n2[nH]1)CN(Cc1ccccc1)CC3,0 +32,648708,Cc1nn(C)c(C)c1CNC(=O)c1cnn2c1NC(c1ccccc1)CC2C(F)(F)F,0 +33,648836,Cc1ccc(C(c2nnnn2CC2CCCO2)N2CCN(C(=O)c3ccco3)CC2)cc1,0 +34,648878,CCN(CC)c1ccc2c(Cl)c(Br)c(=O)oc2c1,0 +35,648947,COc1ccc(OC)c(NC(=O)C(CC(=O)O)NCc2ccco2)c1,0 +36,649015,O=C1CC(c2cccs2)c2cc3c(cc2N1)OCO3,1 +37,649453,CCn1c(COc2ccccc2)nnc1SCC(=O)Nc1cc(OC)ccc1OC,0 +38,649754,Nc1ncc(-c2ccccc2)n1CC1CCCO1,1 +39,649786,CCCCn1c(SCC(=O)N2CCCC2)nc2c1c(=O)n(C)c(=O)n2C,0 +40,649878,O=C(O)CCn1nnc(-c2cccs2)n1,0 +41,650002,Cn1c(=O)[nH]c(=O)c2c1nc(NCCCO)n2CCCc1ccccc1,0 +42,650100,COc1ccccc1OCCn1cc(C(=O)c2ccco2)c2ccccc21,0 +43,650250,Cc1ccc(-c2csc(N3CCC(NS(=O)(=O)c4ccc5c(c4)OCCO5)CC3)n2)cc1,0 +44,650341,COc1ccc(S(=O)(=O)N2CCC(NC(=O)Nc3ccc(C)cc3)CC2)cc1,0 +45,650486,Cn1c(CNC(=O)Nc2ccccc2)nnc1SCc1ccccc1,0 +46,650558,O=C(CSc1nnc(CNc2ccccc2)o1)N1CCCc2ccccc21,0 +47,650691,CCCCNC(=O)NS(=O)(=O)c1ccc(C(=O)OC(C)C)o1,0 +48,6602999,CCOC(=O)C1Cc2c([nH]c3ccccc23)CN1.Cl,1 +49,650985,COc1ccc(CCNc2c(C)c(C)nc3ncnn23)cc1,1 +50,5768893,COCCN1C(=O)C(=O)/C(=C(/O)c2cccc(OC)c2)C1c1ccco1,0 +51,651076,CCC(=O)Nc1cccc(NC(=O)CSc2nnnn2Cc2ccccc2)c1,0 +52,651205,CCc1c(C)nc2c(C#N)c(C)[nH]n2c1=O,0 +53,651338,CCC(=O)N(Cc1ccco1)c1nc(-c2ccccc2)cs1,0 +54,651587,O=C(OCCCN1CCCCC1)c1ccc(O)cc1,1 +55,651589,Cc1cc(NC(=O)CCC(=O)N2CCC3(CC2)OCCO3)no1,0 +56,651769,O=C(CNC(=O)c1ccco1)OCc1c(F)cccc1Cl,0 +57,652002,O=C(CSc1nnc(-c2cccnc2)o1)Oc1ccccc1,0 +58,652521,CN(C)S(=O)(=O)c1ccc(C(=O)Nc2ccc(CN3CCCC3)cc2)cc1,1 +59,652549,CC(C)C(C(=O)NCC1CCCO1)N(Cc1ccco1)C(=O)CNS(=O)(=O)c1ccc(F)cc1,0 +60,652700,COCCNC(=O)COC(=O)c1nsc(Cl)c1Cl,0 +61,652799,CCOc1ccc(-c2nnn(CC(=O)Nc3cc(OC)ccc3OC)n2)cc1OCC,0 +62,6603138,Cl.O=C(CN1CCN(C2CCCCC2)CC1)NCCC1=CCCCC1,1 +63,653279,O=C(O)C1C2C=CC3(CN(Cc4ccccn4)C(=O)C13)O2,0 +64,653412,Cn1c(SCC(=O)NCc2ccc3c(c2)OCO3)nnc1-c1cc2ccccc2cc1O,1 +65,6603457,CCC(c1nnnn1CCOC)N1CCN(C(=O)c2ccco2)CC1.Cl,0 +66,653646,O=S(=O)(c1ccccc1)N1CCN(c2cc(-c3ccccc3)nc3ncnn23)CC1,0 +67,653695,O=C(CSc1n[nH]c(-c2ccccc2O)n1)N1CCCc2ccccc21,1 +68,653778,COc1ccccc1-n1c(SCC(=O)N2CCCC2)nc2cccnc21,0 +69,653799,CCc1nnc2sc(-c3ccc(NC(=O)c4ccco4)cc3)nn12,0 +70,653914,COc1ccc(-c2nnn(CC(=O)N(CC(=O)NCCC(C)C)Cc3cccs3)n2)cc1OC,1 +71,654078,CCn1c(SCc2ccc(C#N)cc2)nnc1-c1ccc(S(=O)(=O)N2CCCCC2)cc1,0 +72,654182,CCOC(=O)Cc1cc(=O)n2[nH]c(C)c(-c3ccccc3)c2n1,0 +73,6398932,CCOc1cccc(/C(O)=C2/C(=O)C(=O)N(Cc3ccco3)C2c2ccncc2)c1,0 +74,654363,O=C(NC1CCCCC1)C(c1cccs1)N(Cc1cccs1)C(=O)c1ccco1,1 +75,654435,O=C(CSc1nc2ccccc2o1)Nc1nc2ccccc2s1,0 +76,5373216,COc1n[nH]c2nncnc12,0 +77,654546,CCc1ccc(N2CC(C(=O)NC3=NCCS3)CC2=O)cc1,0 +78,654623,COc1ccc(CNC(=O)CN(CC2CCCO2)C(=O)CNS(=O)(=O)c2ccccc2)cc1,0 +79,654635,COc1cc(C2C(C(=O)c3ccc(C)o3)=C(O)C(=O)N2CCc2ccccc2)ccc1O,0 +80,654761,Cc1ccc(C)n1C(Cc1ccccc1)C(=O)O,0 +81,655183,COc1ccccc1CN(Cc1cc2cc(C)cc(C)c2[nH]c1=O)Cc1nnnn1CC1CCCO1,0 +82,655265,COc1ccc(OCc2nnc(SCC(=O)O)n2N)cc1,0 +83,5768421,COCCCN1C(=O)C(=O)/C(=C(/O)c2ccc(OC(C)C)c(C)c2)C1c1ccncc1,0 +84,655401,CCN(C1CCCCC1)S(=O)(=O)c1ccc(S(=O)(=O)NCc2ccncc2)cc1,1 +85,655439,CCOC(=O)C1=C(C)NC(=O)NC1c1ccoc1,0 +86,655857,c1ccc(Cn2nnc3c(N4CCc5ccccc5C4)ncnc32)cc1,0 +87,655866,COc1ccc(C(=O)NC2CC3CCCC(C2)N3CC(C)C)cc1OC,1 +88,655948,Cc1cc(N2CCN(c3nc4ccccc4s3)CC2)n2ncnc2n1,0 +89,656017,CC(C)C(=O)Nc1cc2c(cc1C(=O)c1ccccc1)OCCO2,0 +90,656027,COc1ccc(CCN2C(=O)C(O)=C(C(=O)c3ccco3)C2c2cccs2)cc1OC,0 +91,656095,COc1ccc(C2C(C(=O)N3CCOCC3)=C(C)NC3=C2C(=O)CC(C)(C)C3)c(OC)c1,0 +92,656157,O=c1oc(-c2ccco2)nc2c1cnn2-c1ccccc1,0 +93,656183,CC(C)OC(=O)NCCOC(=O)Nc1cccc(Cl)c1,0 +94,656257,Cc1ccccc1OCC1Cn2c(nc3c2c(=O)[nH]c(=O)n3C)O1,0 +95,656272,Cc1ccc(-c2[nH]n3c(=O)c4c(nc3c2C)CCCC4)cc1,1 +96,656290,Cc1cccc(NC(=O)Cn2c(=O)oc3ccccc32)c1,1 +97,6603060,I.OCCNC1=NCCN1,0 +98,6449251,COc1cc2c(cc1OC)/C(=C/C(=O)N1CCOCC1)NC(C)(C)C2.Cl,1 +99,135449532,Cc1cc(=O)[nH]c(-n2nc(C)cc2C)n1,0 +100,5940036,CCOC(=O)C1=C(N)n2c(s/c(=C\c3ccco3)c2=O)=C(C(=O)OCC)C1c1ccco1,0 +101,208296,O=c1nc(-c2ccccc2)cn[nH]1,0 +102,658411,COC(=O)c1ccc(Oc2cc(C)nc(-n3nc(C)cc3C)n2)cc1,0 +103,658723,COC(=O)c1[nH]c2ccc(Br)cc2c1NC(=O)CN1CCN(C2CCCCC2)CC1,1 +104,658813,CCc1c(C)c(C#N)c2nc3ccccc3n2c1Nc1c(C)n(C)n(-c2ccccc2)c1=O,0 +105,135415833,O=c1c(Cc2ccccc2)c(O)nc2n1CCS2,0 +106,658879,CCC(NC(=O)Nc1cc(OC)c(OC)c(OC)c1)(C(F)(F)F)C(F)(F)F,0 +107,659040,O=c1[nH]c(=S)[nH]nc1Cc1ccccc1,0 +108,135435901,COc1ccc(C2CC(=O)C(C=NCCN3CCOCC3)=C(O)C2)cc1,0 +109,16411130,Cc1nc(/N=C(\N)Nc2ccccc2)nc2ccccc12,0 +110,659321,CCOc1ccc(CSC(CC(=O)O)C(=O)O)cc1,0 +111,2838016,CC1=CC=CN2CC(O)CN=C12.Cl,0 +112,6603569,Cc1ccc2c(c1)[C@@H]1CN(C)CC[C@@H]1N2S(=O)(=O)c1ccc(F)cc1.Cl,1 +113,1922089,Cc1cc(C)c(-n2c(O)c(C=NCCN3CCOCC3)c(=O)[nH]c2=O)c(C)c1,0 +114,659756,O=C1CCCN1CC(CN1CCOCC1)Sc1nnnn1-c1ccccc1,0 +115,660120,Cc1cc(-c2cc(-c3ccc(Cl)cc3)nc(N)c2C#N)co1,0 +116,660285,COc1ccc(S(=O)(=O)N2CCC(N3CCCCC3)CC2)cc1,1 +117,660304,O=C(COc1ccccc1)Nc1ccc(-c2nnc(-c3ccco3)o2)cc1,0 +118,5389248,COc1ccc(C2/C(=C(/O)c3ccc(Cl)cc3)C(=O)C(=O)N2CCN2CCOCC2)c(OC)c1,0 +119,5389254,Cc1oc2cc(O)ccc2c(=O)c1-c1cnn(-c2ccccc2)c1,0 +120,660546,Cc1cccc(OCC(=O)N2CCC(N3CCCCCC3)CC2)c1,1 +121,135420605,CCCCc1c(O)nc(SCCN(C)C)n(-c2ccccc2)c1=O,0 +122,660831,COC1(OC)N=C(NC(=O)Nc2ccccc2)C2(C#N)C(c3ccccc3)C12C#N,0 +123,660995,Cc1ccc(S(=O)(=O)NCCSc2nnnn2C)cc1,0 +124,661065,CNc1oc(-c2cccs2)nc1C#N,0 +125,661098,CC(=O)Nc1ccc(Nc2ncnc3c2cnn3-c2ccccc2)cc1,0 +126,661170,O=C(Cc1ccc(Cl)cc1)Nc1cccc(-c2nnc(-c3ccco3)o2)c1,0 +127,661178,CCCOc1ccc(CSC(CC(=O)O)C(=O)O)cc1,0 +128,661187,c1ccc(-n2ncc3c(NCCCN4CCOCC4)ncnc32)cc1,0 +129,661203,CC1CCc2cccc3c2N1c1cc(C#N)c(C#N)cc1O3,0 +130,661217,N#Cc1nc(-c2cccs2)oc1NCc1ccccc1,1 +131,661296,CCc1c(C(=O)O)[nH]c2ccc(Br)cc12,0 +132,661300,S=c1nc(-c2ccccc2)[nH]n1-c1ccccc1,0 +133,661349,CC(=O)c1c(C(C)=O)c(C)n(NC(=O)c2ccncc2)c1C,0 +134,661355,Cn1c(=O)n(CCC(=O)O)c2ccccc21,0 +135,661406,CC1(C)CC(=O)C(CCCN2C(=O)c3ccccc3C2=O)C(=O)C1,0 +136,6603015,CN(C)CC(O)COc1cccc(OCC(O)CN(C)C)c1.Cl,0 +137,6603014,CN(C)CC(O)COc1ccc(C(C)(C)c2ccc(OCC(O)CN(C)C)cc2)cc1.Cl,0 +138,661455,Oc1ccccc1CNn1cnnc1,0 +139,661513,CCn1c(SCC(=O)Nc2ccccc2C(=O)OC)nnc1-c1ccc(N)cc1,0 +140,661518,CCc1cc2c(=O)c(-c3nc4ccccc4[nH]3)coc2cc1OC(=O)N1CCOCC1,0 +141,661528,Cc1ccc(C2Nc3ccccc3C(=O)N2Cc2ccco2)cc1,0 +142,661552,Cc1cn2c(-c3ccncc3)nnc2s1,0 +143,661761,CCOC(=O)c1[nH]c2cc3c(cc2c1NC(=O)CN1CCc2ccccc2C1)OCO3,1 +144,5389368,COc1ccc(/C(O)=C2\C(=O)C(=O)N(c3cc(C)on3)C2c2ccc(OC)c(OC)c2)cc1OC,0 +145,5389389,COc1cccc(/C(O)=C2/C(=O)C(=O)N(c3cc(C)on3)C2c2cccs2)c1,0 +146,5389423,CCN(CC)CCN1C(=O)C(=O)/C(=C(/O)c2cccc(OC)c2)C1c1ccccn1,0 +147,661999,COc1cccc(C2C(C(=O)c3cc4ccccc4o3)=C(O)C(=O)N2c2cc(C)on2)c1OC,0 +148,662011,CCN1CCCC1Cn1cnc2c([nH]c3ccc(C)cc32)c1=O,1 +149,6881185,COCCCNC(=O)c1c(N)n(/N=C/c2ccccn2)c2nc3ccccc3nc12,0 +150,6881246,CCO/C(C)=N/n1c2nc3ccccc3nc2c2c(=O)n(CC(C)C)c(C)nc21,0 +151,662144,CCOC(=O)c1c(C)n(C)c2ccc(OC)c(NC(=O)CN3CCN(Cc4ccccc4)CC3)c12,1 +152,662340,CCOC(=O)c1cc2c(=O)n3cccc(C)c3nc2n(CCCOC)c1=NC(=O)c1cccnc1,0 +153,5389504,CCOc1ccc(/C(O)=C2/C(=O)C(=O)N(CCOCCO)C2c2cccnc2)cc1,0 +154,662407,NC(=O)C1CCN(C(=O)CN2C(=O)c3ccccc3S2(=O)=O)CC1,0 +155,662515,CCOC(=O)c1[nH]c2cc(OC)c(OC)cc2c1NC(=O)c1nonc1C,0 +156,6603499,COc1ccc(C(=O)OC(C)CN2CCN(C)CC2)cc1OC.Cl,1 +157,9614332,C[n+]1cccc(CNC(=O)/C=N/O)c1.[I-],0 +158,662647,COc1ccccc1-c1nnc2n1N=C(c1ccc(O)c(O)c1)CS2,0 +159,662710,COc1cccc(-c2nnc3sc(-c4ccc(C)cc4)nn23)c1,0 +160,662745,CCOCCCn1cnc2c([nH]c3cc(OC)ccc32)c1=O,0 +161,200556,Cl.OC1(c2ccc(F)cc2)CCNC1,0 +162,662794,CC(C)(C)OC(=O)NCCc1nnc(SCC(=O)Nc2cccc(Cl)c2)o1,0 +163,662799,O=C1C2ON(c3ccccc3)C(c3ccncc3)C2C(=O)N1c1ccccc1,0 +164,662838,CCCn1nc(NC(=O)CC(C)C)c2cc3ccccc3nc21,0 +165,662878,CCOc1ccccc1NC(=O)CSc1nnc(-c2cnccn2)n1C,1 +166,662996,CCOc1ccc(-c2nnc3n2N=C(c2ccc(OC)cc2)CS3)cc1,0 +167,663008,CC(C)COP(=O)(c1ccc(N(C)C)cc1)C(O)c1ccccc1F,0 +168,6603621,Cl.NCCCCCc1nnc(SCc2ccccc2Cl)o1,1 +169,663121,Nc1c(S(=O)(=O)c2ccccc2)c2nc3ccccc3nc2n1Cc1ccco1,1 +170,9615342,Cn1c[n+](C)cc1/C=N/O.[I-],0 +171,663125,Oc1ccc(-c2[nH]ncc2-c2ccc(Cl)cc2)c(O)c1,1 +172,663143,CCOc1ccc(C2=Nn3c(nnc3-c3ccccc3OC)SC2)cc1,0 +173,663146,COc1ccc2c(c1)[nH]c1c(N3CCN(Cc4ccc5c(c4)OCO5)CC3)ncnc12,1 +174,663168,COc1cccc(-c2nnc3n2N=C(C(C)(C)C)CS3)c1,0 +175,663337,CCOC(=O)c1[nH]c2cc(OC)c(OC)cc2c1NC(=O)c1ccc2c(c1)OCO2,0 +176,663340,COc1ccc(CCn2c(=N)c(C(=O)NCc3ccco3)cc3c(=O)n4ccccc4nc32)cc1,1 +177,663539,Cc1nc(SCC(=O)Nc2ccc3c(c2)OCCO3)c2oc3ccccc3c2n1,0 +178,663581,COC(=O)[C@@H](NC(=O)Nc1ccc(C(C)=O)cc1)C(C)C,0 +179,5389740,CCCN1C(=O)C2(/C(=C(\O)c3ccc4c(c3)OCCO4)C(=O)C(=O)N2CCCOC)c2ccccc21,0 +180,663736,COC(=O)[C@H](Cc1ccccc1)NC(=O)N1CCN(Cc2ccccc2)CC1,1 +181,663792,CCOCCCn1c(=N)c(C(=O)NCc2ccc3c(c2)OCO3)cc2c(=O)n3cccc(C)c3nc21,1 +182,54676164,CCOC(=O)C1=C(O)C(=O)N(c2ccc(S(N)(=O)=O)cc2)C1c1ccc(OC)cc1,0 +183,664033,CC1(C)CCCN(C(=O)c2coc(=O)c(Br)c2)C1,0 +184,664154,COC(=O)c1[nH]c2cc(C)ccc2c1NC(=O)CN1CCCc2ccccc21,1 +185,5389802,Cc1nc2ccccn2c1/C(O)=C1\C(=O)C(=O)N(CCCn2ccnc2)C1c1ccncc1,0 +186,664250,CC(C)(C)OC(=O)N1CCCC1C(=O)NCCc1ccccc1,0 +187,135513628,CCOC(=O)/C(C(N)=NCCCO)=C(\O)OCC,0 +188,664461,O=C1COc2ccc(OCc3ccc(F)cc3)cc21,0 +189,6603365,Cl.c1ccc2c(c1)oc1c(NCCCn3ccnc3)ncnc12,1 +190,5389869,COc1ccc(-c2c(C)oc3c(CN4CCN(CCO)CC4)c(O)ccc3c2=O)cc1OC,0 +191,5389875,CC1Cc2cc(/C(O)=C3/C(=O)C(=O)N(CCN4CCOCC4)C3c3cccc(Cl)c3)ccc2O1,0 +192,5389878,COc1ccc(C2/C(=C(/O)c3ccc4c(c3)CC(C)O4)C(=O)C(=O)N2CCN2CCOCC2)cc1,0 +193,664733,Cc1nc2c3cnn(-c4ccc(C)c(C)c4)c3ncn2n1,0 +194,664737,CCn1c(=N)c(S(=O)(=O)c2ccc(F)cc2)cc2c(=O)n3ccccc3nc21,1 +195,664759,O=C(CN1CCN(Cc2ccccc2)CC1)c1ccc(Br)cc1,1 +196,5389891,C/C(Cl)=C\Cn1c(N2CCCC2)nc2c1c(=O)[nH]c(=O)n2C,0 +197,664983,CCc1ccc(N2CC(C)Cn3c2nc2c3c(=O)n(CCN3CCOCC3)c(=O)n2C)cc1,0 +198,6603255,Br.CC1(C(=O)CSc2nc3ccccc3s2)CCC(=O)O1,0 +199,665081,Cn1c(-c2ccc(CN3CCCCC3)o2)nc2ccccc21,0 diff --git a/expts/data/finetuning_example-cls/split.csv b/expts/data/finetuning_example-cls/split.csv new file mode 100644 index 000000000..2b47272b1 --- /dev/null +++ b/expts/data/finetuning_example-cls/split.csv @@ -0,0 +1,121 @@ +,train,val,test +0,0,120,160 +1,1,121,161 +2,2,122,162 +3,3,123,163 +4,4,124,164 +5,5,125,165 +6,6,126,166 +7,7,127,167 +8,8,128,168 +9,9,129,169 +10,10,130,170 +11,11,131,171 +12,12,132,172 +13,13,133,173 +14,14,134,174 +15,15,135,175 +16,16,136,176 +17,17,137,177 +18,18,138,178 +19,19,139,179 +20,20,140,180 +21,21,141,181 +22,22,142,182 +23,23,143,183 +24,24,144,184 +25,25,145,185 +26,26,146,186 +27,27,147,187 +28,28,148,188 +29,29,149,189 +30,30,150,190 +31,31,151,191 +32,32,152,192 +33,33,153,193 +34,34,154,194 +35,35,155,195 +36,36,156,196 +37,37,157,197 +38,38,158,198 +39,39,159,199 +40,40,, +41,41,, +42,42,, +43,43,, +44,44,, +45,45,, +46,46,, +47,47,, +48,48,, +49,49,, +50,50,, +51,51,, +52,52,, +53,53,, +54,54,, +55,55,, +56,56,, +57,57,, +58,58,, +59,59,, +60,60,, +61,61,, +62,62,, +63,63,, +64,64,, +65,65,, +66,66,, +67,67,, +68,68,, +69,69,, +70,70,, +71,71,, +72,72,, +73,73,, +74,74,, +75,75,, +76,76,, +77,77,, +78,78,, +79,79,, +80,80,, +81,81,, +82,82,, +83,83,, +84,84,, +85,85,, +86,86,, +87,87,, +88,88,, +89,89,, +90,90,, +91,91,, +92,92,, +93,93,, +94,94,, +95,95,, +96,96,, +97,97,, +98,98,, +99,99,, +100,100,, +101,101,, +102,102,, +103,103,, +104,104,, +105,105,, +106,106,, +107,107,, +108,108,, +109,109,, +110,110,, +111,111,, +112,112,, +113,113,, +114,114,, +115,115,, +116,116,, +117,117,, +118,118,, +119,119,, diff --git a/expts/data/finetuning_example-reg/raw.csv b/expts/data/finetuning_example-reg/raw.csv new file mode 100644 index 000000000..de8e39137 --- /dev/null +++ b/expts/data/finetuning_example-reg/raw.csv @@ -0,0 +1,161 @@ +,smiles,target +0,CCCS(=O)(=O)Nc1ccc(F)c(C(=O)c2c[nH]c3ncc(-c4ccc(Cl)cc4)cc23)c1F,-1.59345982 +1,CCCc1cc(N2CCc3c(nc(C4CC4)n3C)C2)n2ncnc2n1,1.677187053 +2,COc1ccccc1C(=O)Nc1ccc2cc[nH]c2c1,0.505149978 +3,Cn1cc(Nc2nc(N)nc(-c3cccc(-n4ccc5cc(C6CC6)cc(F)c5c4=O)c3CO)n2)cn1,0.850401148 +4,C=CC(=O)N(C)CCOc1c(N)ncnc1-c1cc(F)cc(NC(=O)c2ccc(C3CC3)cc2F)c1C,0.838502776 +5,C=CC(=O)N1CCC(CNc2ncnc(N)c2-c2ccc(Oc3ccccc3)cc2)CC1,0.752995432 +6,C=CC(=O)N1CCC[C@@H](n2nc(-c3ccc(Oc4ccccc4)cc3)c3c(N)ncnc32)C1,0.62603 +7,C=CC(=O)N1CCC[C@H](n2c(=O)n(-c3ccc(Oc4ccccc4)cc3)c3c(N)ncnc32)C1,0.983851719 +8,Cc1cccc(/C=N/Nc2cc(N3CCOCC3)nc(OCCc3ccccn3)n2)c1,-0.548213564 +9,C=CC(=O)Nc1cccc(Nc2nc(Nc3ccc(Oc4ccnc(C(=O)NC)c4)cc3)ncc2F)c1,-0.403402904 +10,C=CC(=O)Nc1cccc(Oc2nc(Nc3ccc(N4CCN(C)CC4)c(F)c3)nc3[nH]ccc23)c1,0.039414119 +11,C=CC(=O)Nc1cccc(Oc2nc(Nc3ccc(N4CCN(C)CC4)cc3)nc3ccoc23)c1,0.146128036 +12,C=CC(=O)Nc1cccc(Oc2nc(Nc3ccc(N4CCN(C)CC4)cc3)nc3ccsc23)c1,-0.049635146 +13,CNC(=O)C1(Cc2ccc(-c3cccnc3)cc2)CCN(Cc2cccc(F)c2)C1,0.829753919 +14,COc1nn(C)cc1C(=O)Nc1cccc(-c2cnc3n2CCC3)c1,0.866877814 +15,CNC(=O)C1(Cc2ccc(-c3ccncc3)cc2)CCN(Cc2cccc(F)c2)C1,1.10503305 +16,COc1nn(C)cc1C(=O)Nc1cccc2cnccc12,1.504729052 +17,O=C(Nc1ccc(CN2CCCCC2)cc1)c1ccnc2[nH]cnc12,1.767526899 +18,CCN(C/C=C\c1ccc(C2CCCCC2)c(Cl)c1)C1CCCCC1,-0.853871964 +19,CC#CC(=O)N1CC[C@@H](n2c(=O)n(-c3ccc(Oc4ccccc4)cc3)c3c(N)ncnc32)C1,0.935859798 +20,CC#CC(=O)N[C@H]1CCCN(c2c(F)cc(C(N)=O)c3[nH]c(C)c(C)c23)C1,0.163856803 +21,O=C(Nc1ccc2ccccc2n1)c1ccc(N2CCOC2=O)cc1,0.449015316 +22,O=C(Nc1cccc(-c2nncn2C2CC2)c1)c1cc(-n2cnc(C3CC3)c2)ccn1,0.262925469 +23,O=C(Nc1cccc(N2CCNC2=O)c1)C1NCC12CCCC2,1.718202645 +24,Fc1ccc(-c2ccc3c(c2)[nH]c2ccncc23)cn1,0.653405491 +25,CC(=O)N1CCN(c2c(Cl)cccc2NC(=O)COc2ccccc2Cl)CC1,0.256958153 +26,CC(=O)N1CCN(c2nc(C(F)(F)F)nc3sc(C)c(C)c23)CC1,0.13481437 +27,CNc1cc(Nc2cccn(-c3ccccn3)c2=O)nn2c(C(=O)N[C@@H]3C[C@@H]3F)cnc12,0.991226076 +28,CCN1CCN(S(=O)(=O)Cc2ccc(Cl)c(Cl)c2)CC1,1.249124949 +29,CNc1nc(C)cc(C(=O)Nc2ccc3[nH]ncc3c2)n1,1.493625323 +30,Fc1ccccc1-c1c[nH]nc1C1CCCN1Cc1ccc2ncccc2c1,0.545801757 +31,O=C(Nc1cnccc1-c1ccc(Cl)cc1)c1ccnc(NC(=O)C2CC2)c1,0.369957607 +32,O=C(Nc1nc2cccc(-c3ccc(CN4CCS(=O)(=O)CC4)cc3)n2n1)C1CC1,1.681349797 +33,C[C@@H]1CCN(C(=O)CC#N)C[C@@H]1N(C)c1ncnc2[nH]ccc12,1.991815076 +34,C[C@@H]1c2nnn(-c3ncc(F)cn3)c2CCN1C(=O)c1cccc(C(F)(F)F)c1Cl,0.969089603 +35,CC(=O)Nc1ccc(C(=O)N2CCCCC2c2nc(N)ncc2-c2ccc(Cl)cc2)cc1,0.596926814 +36,CC(=O)Nc1ccc(C(=O)N2CCCCC2c2nc(N)ncc2-c2cccc(Cl)c2)cc1,0.520483533 +37,C[C@H]1CN(C2COC2)CCN1c1ccc(Nc2cc(-c3ccnc(N4CCn5c(cc6c5CC(C)(C)C6)C4=O)c3CO)cn(C)c2=O)nc1,1.209139536 +38,CC(=O)Nc1ccc(O)cc1,1.887859133 +39,Cc1cccnc1Nc1cccc(C2CCCN(CC(=O)Nc3nccs3)C2)n1,0.180699201 +40,COCC(=O)N1CCC(Cc2ccccc2-c2cccc(F)c2)(C(=O)NC(C)C)CC1,0.596597096 +41,Cc1[nH]nc2c1C1(CCCCC1)CC(=O)N2,1.323066376 +42,CC(=O)Nc1ncc(C(=O)O)s1,1.81935965 +43,N#Cc1cc(F)c(NS(=O)(=O)c2c[nH]c3cc(Cl)ccc23)cc1F,-0.026872146 +44,Cc1c(C(=O)NCCCN(C)C)sc2ncnc(Nc3ccc(F)cc3OC(C)C)c12,0.741821047 +45,Cc1c(Cl)ccc2cc3n(c12)[C@@H](C)CNC3=O,0.328379603 +46,Cc1cn2nc(-c3cc(=O)n4cc(N5CCNC6(CC6)C5)ccc4n3)cc(C)c2n1,0.985246791 +47,CCOc1cc(CC(=O)N[C@@H](CC(C)C)c2ccccc2N2CCCCC2)ccc1C(=O)O,0.106870544 +48,COCCNC(=O)c1ccnc(C2CCNCC2)c1,2.0 +49,Cc1c[nH]c(=O)n1-c1ccc(C(=O)Nc2ccc3ccccc3n2)cc1,0.639386869 +50,Cc1cnc(C(=O)NCCc2ccc(S(=O)(=O)NC(=O)NC3CCCCC3)cc2)cn1,0.01745073 +51,Cc1c[nH]c2nccc(Oc3c(F)cc(Nc4cc(Cl)nc(N)n4)cc3F)c12,-1.22184875 +52,CC(C)(C)C(=O)N1CCC(Cc2ccc(-c3cccs3)cc2)(C(=O)N2CCCC2)CC1,-0.1837587 +53,O=C(c1ccc(Oc2ccccc2)cc1Cl)c1c[nH]c2ncnc(N[C@@H]3CC[C@@H](CO)OC3)c12,-0.467245621 +54,CC(C)(C)c1ccc(-c2nc3n(c(=O)c2C#N)CCS3)cc1,0.731266349 +55,CC(C)(C)c1ccc(C(O)CN2CCC(O)(c3ccc4c(c3)OCO4)CC2)cc1,0.779018972 +56,O=C1CCCC[C@@H]2[C@H](C[C@@H](Cc3ccccc3F)N2C(=O)c2cccc3ncccc23)N1,1.277998644 +57,CCc1c[nH]c2ncnc(N3CCC(CN4CCN(C)CC4)CC3)c12,1.728012707 +58,CC(C)(Oc1ccc(-c2cnc(N)c(-c3ccc(Cl)cc3)c2)cc1)C(=O)O,-1.158015195 +59,NC(=O)c1cnc(N2CCc3[nH]nc(C(F)(F)F)c3C2)c(Cl)c1,0.986009932 +60,NC1CC(NC(=O)c2ccc(-c3cn[nH]c3)cn2)C12CCC2,1.709702344 +61,NC1CCC(C(=O)N2CCC(c3c[nH]c4ncccc34)CC2)C1,1.917395215 +62,NC1CCCC(C(=O)N2CCC(c3c[nH]c4ncccc34)CC2)C1,1.752816431 +63,NC1CCCC(C(=O)Nc2ccc3[nH]ncc3c2)C1,2.0 +64,NC1CCCC(C(=O)Nc2cccc(N3CCNC3=O)c2)C1,1.763113391 +65,NC1CCCC1C(=O)N1CCC(c2c[nH]c3ncccc23)CC1,1.71701274 +66,O=C1NCCN(C(=O)c2ccc3nccn3c2)C1c1ccccc1C(F)(F)F,1.851001366 +67,O=C1NCCN(C(=O)c2ccncc2)C1c1ccccc1Cl,1.876996793 +68,NCC1CCCC1NC(=O)c1cc(N2CCNC2=O)ccc1F,1.87495702 +69,O=C1NCCSc2c1sc1ccc(O)cc21,0.891593204 +70,NCCN1CCN(C/C=C/C(=O)N2CCC[C@@H](n3nc(-c4ccc(Oc5ccccc5)cc4)c4c(N)ncnc43)C2)CC1,0.699837726 +71,CCc1nc(C)cn2nc(-c3cc(=O)n4cc(C5CCN(C)CC5)cc(C)c4n3)cc12,0.868232868 +72,O=S(=O)(c1cccc2cnccc12)N1CCCNCC1,1.892077899 +73,Nc1c(F)ccc2cnc(-n3ccc4ccncc43)cc12,0.713910354 +74,CCc1nc2c(C)cc(N3CCN(CC(=O)N4CC(O)C4)CC3)cn2c1N(C)c1nc(-c2ccc(F)cc2)c(C#N)s1,-0.204728421 +75,Cc1nc(Nc2nccs2)cc(C2CN(c3ncccn3)C2)n1,0.444669231 +76,CC(C)NC(=O)COc1cccc(-c2nc(Nc3ccc4[nH]ncc4c3)c3ccccc3n2)c1,-1.384078213 +77,Cc1cc(F)ccc1C1C(=O)NCCN1C(=O)c1ccc2nccn2c1,1.78096503 +78,Cc1cc(N2CCCC2c2cc(CCC(=O)O)cc(C)n2)ncn1,2.0 +79,CC(C)[C@H](CO)Nc1nc(Nc2cc(N)cc(Cl)c2)c2ncn(C(C)C)c2n1,-0.061980903 +80,CCn1c(-c2nonc2N)nc2cnc(Oc3cccc(NC(=O)c4ccc(OCCN5CCOCC5)cc4)c3)cc21,-0.449771647 +81,Nc1ncc(-c2cccc(C(F)(F)F)c2)c(C2CCCCN2C(=O)c2ccccc2)n1,-0.04431225 +82,CCn1c(=O)oc2cc(NC(=O)c3ccc(C(C)(C)C)cc3)ccc21,0.21005085 +83,Nc1ncc(C(=O)NC2CN(C(=O)C3CC3)C2)c2ccc(-c3cccc(F)c3)nc12,0.623352682 +84,CCn1c(CO)nn(-c2cc(O[C@@H](C)C(F)(F)F)c(C(=O)Nc3c(F)cccc3Cl)cc2F)c1=O,0.975753389 +85,Nc1ncnc2c1c(-c1ccc(Oc3ccccc3)cc1)nn2[C@@H]1CCCNC1,0.956648579 +86,Nc1ncnc2c1c(-c1cnc3[nH]ccc3c1)nn2C1CCCC1,0.862608364 +87,CCn1cc(C(=O)O)c(=O)c2c(N)c(F)c(NC3CCCCC3)cc21,0.427486109 +88,Cc1nccc(-c2cn(Cc3ccccc3)c3cnccc23)n1,1.122707254 +89,Cc1cc(Nc2cnccn2)cc(C2CCCN(CC(=O)N3CCCC3)C2)n1,1.642246824 +90,Cc1ncsc1C(=O)N1CCCCC1c1nc(N)ncc1-c1cccc(C(F)(F)F)c1,0.615318657 +91,CN(C(=O)c1cc(N2CCNC2=O)ccc1F)C1CCNC1,1.885728632 +92,OC[C@H](Nc1cncc(-c2ccc3[nH]ncc3n2)c1)c1ccccc1,1.017826038 +93,CN(C)C(=O)C1(Cc2ccccc2-c2cccc(F)c2)CCN(C(=O)C2CC=CCC2)CC1,-0.104025268 +94,CN(C)C(=O)C1(Cc2ccccc2-c2ccccc2)CCN(C(=O)C2CCCO2)CC1,0.902546779 +95,CN(C)C(=O)C1(Cc2ccccc2-c2ccccc2)CCN(C(=O)c2cnn(C)c2)CC1,1.041353202 +96,O=C(CCNC(=O)c1ccc(OC(F)(F)F)cc1)N[C@@H]1CCCc2ccccc21,0.526080692 +97,CC(O)(C#Cc1ccc2c(c1)N(c1nc(N)ncc1Cl)CC2)c1nccs1,-0.122628654 +98,COc1ccc(CCCN2CCN(c3cnn(C)c3)C(=O)C2)cc1F,1.343408594 +99,CC/C(=C(\c1ccccc1)c1ccc(OCCN(C)C)cc1)c1ccccc1,-0.580044252 +100,CC1(C)CC(Oc2ccc(-c3ccc(-c4cn[nH]c4)cc3O)nn2)CC(C)(C)N1,1.199947058 +101,Cc1nnc(-c2ccc(N3CCC(Oc4cc(F)ccc4Cl)CC3)nn2)o1,0.161068385 +102,Cc1nnc(-c2ccc(N3CCC(Oc4ccccc4C(F)(F)F)CC3)nn2)s1,-0.906578315 +103,c1ccc(-c2ccc(CN3CCCCCCC3)cc2)cc1,-0.614393726 +104,COc1ccc(CNC(=O)c2sc3nc(C)cc(C)c3c2N)cc1,-0.356547324 +105,Cc1nnc(CN(C)CC(C)Oc2ccc(Cl)c(Cl)c2)n1C,0.926290987 +106,CC1(C)Cc2cc(NC(=O)c3cnn4cccnc34)c(OCC3CC3)nc2O1,-0.199970641 +107,O=C(CN1CCCC(c2cccc(Cc3cccc(F)c3)n2)C1)N1CCCC1,1.062130535 +108,CC1(CNC(=O)c2cncc(C3CCNCC3)n2)CCCO1,1.988550039 +109,Cc1noc(C(C)C)c1C(=O)N1CC(C)OC(c2ccccc2)C1,1.180469962 +110,CC1CN(C(=O)c2cccnc2N2CCOCC2)CC(c2ccccc2)O1,1.573185017 +111,CCC(=O)N1CCN(c2ccc(Cl)cc2NC(=O)COc2ccccc2)CC1,0.250420002 +112,COc1ccc(S(=O)(=O)N2CCC(N3CCC(C)CC3)CC2)cc1,1.729799023 +113,COc1ccc([C@H]2CN(C(C)=O)[C@@H]3CCCN(Cc4cccc(F)c4)[C@H]23)cc1,0.595385981 +114,Cc1oc2ccccc2c1CNc1nnc(-c2ccncc2)o1,0.717254313 +115,c1sc(NCC2CCCO2)nc1C12CC3CC(CC(C3)C1)C2,-1.096910013 +116,CCC1=C(C)CN(C(=O)NCCc2ccc(S(=O)(=O)NC(=O)N[C@H]3CC[C@H](C)CC3)cc2)C1=O,-1.180456064 +117,CN(Cc1ccccc1)C1(C(=O)N2CCNC(=O)CC2)Cc2ccccc2C1,0.904931827 +118,CCCCNC(=O)NS(=O)(=O)c1ccc(C)cc1,0.439332694 +119,O=C(NC1CCNCC1)c1ccc2[nH]ncc2c1,1.842696589 +120,Cn1c(C2CC2)nc2c1CCN(c1ncnc3ccsc13)C2,1.237065953 +121,Cc1ccc(OCC2(O)CCN(CC3(O)CCN(c4ccccc4C)CC3)CC2)cc1,-0.040481623 +122,CCCNC(=O)NS(=O)(=O)c1ccc(Cl)cc1,0.969835093 +123,Cc1ccc(Oc2ccc(Cl)cc2NC(=O)CN(C)CC(=O)N(C)C)cc1,-0.053056729 +124,COc1ccccc1-c1cc(NC(=O)c2cccc(N3CCNC3=O)c2)[nH]n1,-0.397940009 +125,CN1C(N)=N[C@](C)(c2cc(NC(=O)c3ccc(F)cn3)ccc2F)CS1(=O)=O,1.560468571 +126,C=CC(=O)N1C[C@H](Nc2ncnc3[nH]ccc23)CC[C@@H]1C,1.828227965 +127,CC#CC(=O)N1CC[C@@H](n2cc(-c3ccc(Oc4c(F)cccc4F)cc3)c3c(N)n[nH]c(=O)c32)C1,-0.073143291 +128,CC(=O)NCCNc1cc(Cl)nn2c(-c3cccc(S(=O)(=O)N(C)C)c3)c(C)nc12,0.681693392 +129,CC(C)(O)CCn1cc2cc(NC(=O)c3cccc(C(F)(F)F)n3)c(C(C)(C)O)cc2n1,1.015611205 +130,CC1CC(N)CCN1C(=O)c1cc(N2CCNC2=O)c(F)cc1F,1.833504094 +131,CCOc1cc2nn(CCC(C)(C)O)cc2cc1NC(=O)c1cccc(C(F)F)n1,0.099680641 +132,CCS(=O)(=O)N1CC(CC#N)(n2cc(-c3ncnc4[nH]ccc34)cn2)C1,1.378597739 +133,CN(C(=O)c1cc(N2CCNC2=O)ccc1F)C1CCC(N)CC1,1.828717885 +134,CN1CCN(S(=O)(=O)c2ccc(-c3cnc(N)c(C(=O)Nc4cccnc4)n3)cc2)CC1,1.078638038 +135,CNC(=O)C1(Cc2ccc(-c3cccnc3)cc2)CCN(Cc2ccc(F)cc2)C1,1.021602716 +136,CNC(=O)C1(Cc2ccc(-c3cccnc3)cc2)CCN(Cc2ccccc2Cl)C1,0.387567779 +137,CNC(=O)c1cccc(NC(=O)N2CCC(Oc3ccccc3Cl)CC2)c1,1.068760828 +138,COCCc1noc(CN2CC(c3ccccc3)(c3ccccc3)CCC2=O)n1,1.167051359 +139,COc1cc2c(cc1OC)CC(=O)N(CCCN(C)C[C@H]1Cc3cc(OC)c(OC)cc31)CC2,1.730984039 +140,COc1cc2ncnc(Nc3cccc(O)c3)c2cc1OC,0.994317153 +141,COc1ccc(Cl)cc1C(=O)NCCc1ccc(S(=O)(=O)NC(=O)NC2CCCCC2)cc1,-0.91721463 +142,COc1ccc(Nc2c(C#N)cnc3cc(OC)c(OC)cc23)cc1Cl,-0.033858267 +143,COc1ccc(OCC(=O)Nc2cccc(Cl)c2N2CCN(C(C)=O)CC2)cc1,0.700790221 +144,COc1nc2sc(C(=O)NC3CC3)c(N)c2c(C)c1Cl,0.245512668 +145,Cc1cc(C)c2c(N)c(C(=O)NCc3ccc(Cl)cc3)sc2n1,-0.370590401 +146,Cc1cc(N2CCCC2)nc(C2CCCN(C)C2)n1,2.0 +147,Cc1ccc(C(=O)N2CCC(Cc3ccccc3-c3ccccc3)(C(=O)N(C)C)CC2)s1,-0.293282218 +148,Cc1ncc(CN2CCC(Nc3ccc4nnnn4n3)C(C)C2)s1,1.652285029 +149,Cc1ncsc1C(=O)N1CCCCC1c1nc(N(C)C)ncc1-c1cccc(Cl)c1,-0.083019953 +150,Cn1cc(-c2cn3nccc3c(-c3cnn([C@]4(CC#N)C[C@@H](C#N)C4)c3)n2)cn1,1.559583476 +151,N#CC[C@H](C1CCCC1)n1cc(-c2ncnc3[nH]ccc23)cn1,0.630986911 +152,NC(=O)C1(Cc2ccc(-c3ccncc3)cc2)CCN(C(=O)Cc2cccc(F)c2)CC1,1.006508828 +153,O=C(NC12CCCC1NCC2)c1ccc(-c2cn[nH]c2)cc1,2.0 +154,O=C(NCCc1ccccc1)c1ccc(NC(=O)N2CCCCc3ccccc32)cc1,0.075911761 +155,O=C1CN(c2ccc(Nc3nccc(C(F)(F)F)n3)cn2)CCN1,1.543819805 +156,OCC1CCCCN1Cc1ccc(-c2ccccc2)cc1,1.368007805 +157,OCC1CCCCN1Cc1ccc(Cl)c(Cl)c1,1.390069186 +158,[2H]C([2H])([2H])NC(=O)c1nnc(NC(=O)C2CC2)cc1Nc1cccc(-c2ncn(C)n2)c1OC,1.186566481 +159,c1ccc(Oc2cccc(CN(CCN3CCOCC3)Cc3cccnc3)c2)cc1,-0.222573178 diff --git a/expts/data/finetuning_example-reg/split.csv b/expts/data/finetuning_example-reg/split.csv new file mode 100644 index 000000000..e88c74b5e --- /dev/null +++ b/expts/data/finetuning_example-reg/split.csv @@ -0,0 +1,110 @@ +,train,val,test +0,0,60.0,126.0 +1,1,5.0,127.0 +2,2,35.0,128.0 +3,3,23.0,129.0 +4,4,15.0,130.0 +5,6,68.0,131.0 +6,7,12.0,132.0 +7,8,45.0,133.0 +8,9,119.0,134.0 +9,10,113.0,135.0 +10,11,41.0,136.0 +11,13,88.0,137.0 +12,16,30.0,138.0 +13,17,74.0,139.0 +14,18,54.0,140.0 +15,19,73.0,141.0 +16,20,14.0,142.0 +17,21,,143.0 +18,22,,144.0 +19,24,,145.0 +20,25,,146.0 +21,26,,147.0 +22,27,,148.0 +23,28,,149.0 +24,29,,150.0 +25,31,,151.0 +26,32,,152.0 +27,33,,153.0 +28,34,,154.0 +29,36,,155.0 +30,37,,156.0 +31,38,,157.0 +32,39,,158.0 +33,40,,159.0 +34,42,, +35,43,, +36,44,, +37,46,, +38,47,, +39,48,, +40,49,, +41,50,, +42,51,, +43,52,, +44,53,, +45,55,, +46,56,, +47,57,, +48,58,, +49,59,, +50,61,, +51,62,, +52,63,, +53,64,, +54,65,, +55,66,, +56,67,, +57,69,, +58,70,, +59,71,, +60,72,, +61,75,, +62,76,, +63,77,, +64,78,, +65,79,, +66,80,, +67,81,, +68,82,, +69,83,, +70,84,, +71,85,, +72,86,, +73,87,, +74,89,, +75,90,, +76,91,, +77,92,, +78,93,, +79,94,, +80,95,, +81,96,, +82,97,, +83,98,, +84,99,, +85,100,, +86,101,, +87,102,, +88,103,, +89,104,, +90,105,, +91,106,, +92,107,, +93,108,, +94,109,, +95,110,, +96,111,, +97,112,, +98,114,, +99,115,, +100,116,, +101,117,, +102,118,, +103,120,, +104,121,, +105,122,, +106,123,, +107,124,, +108,125,, diff --git a/expts/dataset_benchmark.py b/expts/dataset_benchmark.py index e8bf2e24d..948a03688 100644 --- a/expts/dataset_benchmark.py +++ b/expts/dataset_benchmark.py @@ -21,7 +21,6 @@ # CONFIG_FILE = "expts/neurips2023_configs/debug/config_large_gcn_debug.yaml" CONFIG_FILE = "expts/neurips2023_configs/config_large_gcn.yaml" # CONFIG_FILE = "expts/configs/config_pcqmv2_mpnn.yaml" -# CONFIG_FILE = "expts/configs/config_ipu_qm9.yaml" def benchmark(fn, *args, message="", log2wandb=False, **kwargs): diff --git a/expts/hydra-configs/README.md b/expts/hydra-configs/README.md index 40625917d..77d2569ba 100644 --- a/expts/hydra-configs/README.md +++ b/expts/hydra-configs/README.md @@ -1,14 +1,14 @@ # Configuring Graphium with Hydra This document provides users with a point of entry to composing configs in Graphium. As a flexible library with many features, configuration is an important part of Graphium. To make configurations as reusable as possible while providing maximum flexibility, we integrated Graphium with `hydra`. Our config structure is designed to make the following functionality as accessible as possible: -- Switching between **accelerators** (CPU, GPU and IPU) +- Switching between **accelerators** (CPU, GPU) - **Benchmarking** different models on the same dataset - **Fine-tuning** a pre-trained model on a new dataset In what follows, we describe how each of the above functionality is achieved and how users can benefit from this design to achieve the most with Graphium with as little configuration as possible. ## Accelerators -With Graphium supporting CPU, GPU and IPU hardware, easily switching between these accelerators is pre-configured. General, accelerator-specific configs are specified under `accelerator/`, whereas experiment-specific differences between the accelerators are specialized under `training/accelerator`. +With Graphium supporting CPU, GPU hardware, easily switching between these accelerators is pre-configured. General, accelerator-specific configs are specified under `accelerator/`, whereas experiment-specific differences between the accelerators are specialized under `training/accelerator`. ## Benchmarking Benchmarking multiple models on the same datasets and tasks requires us to easily switch between model configurations without redefining major parts of the architecture, task heads, featurization, metrics, predictor, etc. For example, when changing from a GCN to a GIN model, a simple switch of `architecture.gnn.layer_type: 'pyg:gin'` might suffice. Hence, we abstract the `model` configs under `model/` where such model configurations can be specified. diff --git a/expts/hydra-configs/accelerator/ipu.yaml b/expts/hydra-configs/accelerator/ipu.yaml deleted file mode 100644 index 3e6fb4429..000000000 --- a/expts/hydra-configs/accelerator/ipu.yaml +++ /dev/null @@ -1,18 +0,0 @@ -type: ipu -ipu_config: - - deviceIterations(60) # IPU would require large batches to be ready for the model. - # 60 for PCQM4mv2 - # 30 for largemix - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 96) - - Precision.enableStochasticRounding(True) - -ipu_inference_config: - # set device iteration and replication factor to 1 during inference - # gradient accumulation was set to 1 in the code - - deviceIterations(1) - - replicationFactor(1) - - Precision.enableStochasticRounding(False) diff --git a/expts/hydra-configs/accelerator/ipu_pipeline.yaml b/expts/hydra-configs/accelerator/ipu_pipeline.yaml deleted file mode 100644 index 996218646..000000000 --- a/expts/hydra-configs/accelerator/ipu_pipeline.yaml +++ /dev/null @@ -1,22 +0,0 @@ -type: ipu -ipu_config: - - deviceIterations(60) # IPU would require large batches to be ready for the model. - # 60 for PCQM4mv2 - # 30 for largemix - - replicationFactor(4) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 96) - - Precision.enableStochasticRounding(True) - -ipu_inference_config: - # set device iteration and replication factor to 1 during inference - # gradient accumulation was set to 1 in the code - - deviceIterations(60) - - replicationFactor(1) - - Precision.enableStochasticRounding(False) - -accelerator_kwargs: - _accelerator: "ipu" - gnn_layers_per_ipu: [4, 4, 4, 4] \ No newline at end of file diff --git a/expts/hydra-configs/architecture/largemix.yaml b/expts/hydra-configs/architecture/largemix.yaml index 32efef778..5b92050a9 100644 --- a/expts/hydra-configs/architecture/largemix.yaml +++ b/expts/hydra-configs/architecture/largemix.yaml @@ -83,14 +83,9 @@ architecture: datamodule: module_type: "MultitaskFromSmilesDataModule" args: - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 20 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: ${constants.datacache_path} - dataloading_from: "disk" num_workers: 20 # -1 to use all - persistent_workers: True + persistent_workers: False featurization: atom_property_list_onehot: [atomic-number, group, period, total-valence] atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring] diff --git a/expts/hydra-configs/architecture/pcqm4m.yaml b/expts/hydra-configs/architecture/pcqm4m.yaml index 494875765..f3fc04b63 100644 --- a/expts/hydra-configs/architecture/pcqm4m.yaml +++ b/expts/hydra-configs/architecture/pcqm4m.yaml @@ -81,13 +81,8 @@ architecture: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: ${constants.datacache_path} num_workers: 40 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. diff --git a/expts/hydra-configs/architecture/toymix.yaml b/expts/hydra-configs/architecture/toymix.yaml index a62b839cd..f4ae5a5db 100644 --- a/expts/hydra-configs/architecture/toymix.yaml +++ b/expts/hydra-configs/architecture/toymix.yaml @@ -74,12 +74,7 @@ architecture: datamodule: module_type: "MultitaskFromSmilesDataModule" args: - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: ${constants.datacache_path} - dataloading_from: ram num_workers: 30 # -1 to use all persistent_workers: False featurization: diff --git a/expts/hydra-configs/finetuning/admet.yaml b/expts/hydra-configs/finetuning/admet.yaml deleted file mode 100644 index 7360707df..000000000 --- a/expts/hydra-configs/finetuning/admet.yaml +++ /dev/null @@ -1,91 +0,0 @@ -# @package _global_ - -# == Fine-tuning configs in Graphium == -# -# A fine-tuning config is a appendum to a (pre-)training config. -# Since many things (e.g. the architecture), will stay constant between (pre-)training and fine-tuning, -# this config should be as minimal as possible to avoid unnecessary duplication. It only specifies -# what to override with regards to the config used for (pre-)training. -# -# Given the following training command: -# >>> graphium-train --cfg /path/to/train.yaml -# -# Fine-tuning now is as easy as: -# >>> graphium-train --cfg /path/to/train.yaml +finetune=admet -# -# NOTE: This config can be used for each of the benchmarks in the TDC ADMET benchmark suite. -# The only thing that needs to be changed is the `constants.task` key. - - -## == Overrides == - -defaults: - # This file contains all metrics and loss function info for all ADMET tasks. - # This config is filtered at runtime based on the `constants.task` key. - - override /tasks/loss_metrics_datamodule: admet - -constants: - - # For now, we assume a model is always fine-tuned on a single task at a time. - # You can override this value with any of the benchmark names in the TDC benchmark suite. - # See also https://tdcommons.ai/benchmark/admet_group/overview/ - task: lipophilicity_astrazeneca - - name: finetuning_${constants.task}_gcn - wandb: - name: ${constants.name} - project: ${constants.task} - entity: multitask-gnn - save_dir: logs/${constants.task} - seed: 42 - max_epochs: 100 - data_dir: expts/data/admet/${constants.task} - raise_train_error: true - -predictor: - optim_kwargs: - lr: 4.e-5 - -# == Fine-tuning config == - -finetuning: - - # For now, we assume a model is always fine-tuned on a single task at a time. - # You can override this value with any of the benchmark names in the TDC benchmark suite. - # See also https://tdcommons.ai/benchmark/admet_group/overview/ - task: ${constants.task} - level: graph - - # Pretrained model - pretrained_model: dummy-pretrained-model - finetuning_module: task_heads # gnn - sub_module_from_pretrained: zinc # optional - new_sub_module: ${constants.task} # optional - - # keep_modules_after_finetuning_module: # optional - # graph_output_nn/graph: {} - # task_heads/zinc: - # new_sub_module: lipophilicity_astrazeneca - # out_dim: 1 - - - # Changes to finetuning_module - drop_depth: 1 - new_out_dim: 8 - added_depth: 2 - - # Training - unfreeze_pretrained_depth: 0 - epoch_unfreeze_all: none - - # Optional finetuning head appended to model after finetuning_module - finetuning_head: - task: ${constants.task} - previous_module: task_heads - incoming_level: graph - model_type: mlp - in_dim: 8 - out_dim: 1 - hidden_dims: 8 - depth: 2 - last_layer_is_readout: true diff --git a/expts/hydra-configs/finetuning/admet_baseline.yaml b/expts/hydra-configs/finetuning/admet_baseline.yaml deleted file mode 100644 index 410d0dd64..000000000 --- a/expts/hydra-configs/finetuning/admet_baseline.yaml +++ /dev/null @@ -1,71 +0,0 @@ -# @package _global_ - -defaults: - - override /tasks/loss_metrics_datamodule: admet - -constants: - task: tbd - name: finetune_${constants.task} - wandb: - name: ${constants.name} - project: finetuning - entity: recursion - seed: 42 - max_epochs: 100 - data_dir: ../data/graphium/admet/${constants.task} - datacache_path: ../datacache/admet/${constants.task} - raise_train_error: true - metric: ${get_metric_name:${constants.task}} - -datamodule: - args: - batch_size_training: 32 - dataloading_from: ram - persistent_workers: true - num_workers: 4 - -trainer: - model_checkpoint: - # save_top_k: 1 - # monitor: graph_${constants.task}/${constants.metric}/val - # mode: ${get_metric_mode:${constants.task}} - # save_last: true - # filename: best - dirpath: model_checkpoints/finetuning/${constants.task}/${now:%Y-%m-%d_%H-%M-%S.%f}/ - every_n_epochs: 200 - trainer: - precision: 32 - check_val_every_n_epoch: 1 - # early_stopping: - # monitor: graph_${constants.task}/${constants.metric}/val - # mode: ${get_metric_mode:${constants.task}} - # min_delta: 0.001 - # patience: 10 - accumulate_grad_batches: none - # test_from_checkpoint: best.ckpt - # test_from_checkpoint: ${trainer.model_checkpoint.dirpath}/best.ckpt - -predictor: - optim_kwargs: - lr: 0.000005 - - -# == Fine-tuning config == - -finetuning: - task: ${constants.task} - level: graph - pretrained_model: tbd - finetuning_module: graph_output_nn - sub_module_from_pretrained: graph - new_sub_module: graph - - keep_modules_after_finetuning_module: # optional - task_heads-pcqm4m_g25: - new_sub_module: ${constants.task} - hidden_dims: 256 - depth: 2 - last_activation: ${get_last_activation:${constants.task}} - out_dim: 1 - - epoch_unfreeze_all: tbd \ No newline at end of file diff --git a/expts/hydra-configs/finetuning/example-custom.yaml b/expts/hydra-configs/finetuning/example-custom.yaml new file mode 100644 index 000000000..4b9e20197 --- /dev/null +++ b/expts/hydra-configs/finetuning/example-custom.yaml @@ -0,0 +1,87 @@ +# @package _global_ + +defaults: + - override /tasks/loss_metrics_datamodule: finetune + +constants: + benchmark: custom + task: finetuning_example-cls # finetuning_example-cls OR finetuning_example-reg + task_type: cls # cls OR reg + data_path: expts/data + # wandb: + # name: finetune_${constants.task} + # project: tbd + # entity: tbd + # tags: + # - finetuning + # - ${constants.task} + # - ${finetuning.pretrained_model} + seed: 42 + max_epochs: 20 + raise_train_error: true + model_dropout: 0. + +datamodule: + args: + batch_size_training: 256 + batch_size_inference: 256 + dataloading_from: ram + persistent_workers: true + num_workers: 8 + + task_specific_args: + finetune: + df: null + df_path: ${constants.data_path}/${constants.task}/raw.csv + splits_path: ${constants.data_path}/${constants.task}/split.csv + smiles_col: smiles + label_cols: target + task_level: graph + epoch_sampling_fraction: 1.0 + +trainer: + model_checkpoint: + save_top_k: 0 + dirpath: none + every_n_epochs: 200 + save_last: false + trainer: + precision: 32 + check_val_every_n_epoch: 1 + accumulate_grad_batches: 1 + +predictor: + optim_kwargs: + lr: 0.00001 + torch_scheduler_kwargs: + module_type: WarmUpLinearLR + max_num_epochs: ${constants.max_epochs} + warmup_epochs: 3 + verbose: False + + + +# == Fine-tuning config == + +finetuning: + task: finetune + level: graph + pretrained_model: dummy-pretrained-model + finetuning_module: graph_output_nn + sub_module_from_pretrained: graph + new_sub_module: graph + drop_depth: 1 + added_depth: 1 + new_out_dim: 256 + + keep_modules_after_finetuning_module: + task_heads-zinc: + new_sub_module: finetune + hidden_dims: ${finetuning.new_out_dim} + depth: 1 + dropout: 0. + last_activation: none + out_dim: 1 + + epoch_unfreeze_all: 0 + always_freeze_modules: [] \ No newline at end of file diff --git a/expts/hydra-configs/finetuning/example-tdc.yaml b/expts/hydra-configs/finetuning/example-tdc.yaml new file mode 100644 index 000000000..d5bfaf98f --- /dev/null +++ b/expts/hydra-configs/finetuning/example-tdc.yaml @@ -0,0 +1,76 @@ +# @package _global_ + +defaults: + - override /tasks/loss_metrics_datamodule: tdc +constants: + task: bbb_martins + # wandb: + # name: finetune_${constants.task} + # project: tbd + # entity: tbd + # tags: + # - finetuning + # - ${constants.task} + # - ${finetuning.pretrained_model} + seed: 42 + max_epochs: 20 + raise_train_error: true + metric: ${get_metric_name:${constants.task}} + model_dropout: 0. + +datamodule: + args: + batch_size_training: 256 + batch_size_inference: 256 + dataloading_from: ram + persistent_workers: true + num_workers: 2 + split_type: default + tdc_train_val_seed: 1 + +trainer: + model_checkpoint: + save_top_k: 0 + dirpath: none + every_n_epochs: 200 + save_last: false + trainer: + precision: 32 + check_val_every_n_epoch: 1 + accumulate_grad_batches: 1 + +predictor: + optim_kwargs: + lr: 0.00001 + torch_scheduler_kwargs: + module_type: WarmUpLinearLR + max_num_epochs: ${constants.max_epochs} + warmup_epochs: 3 + verbose: False + + + +# == Fine-tuning config == + +finetuning: + task: ${constants.task} + level: graph + pretrained_model: dummy-pretrained-model + finetuning_module: graph_output_nn + sub_module_from_pretrained: graph + new_sub_module: graph + drop_depth: 1 + added_depth: 1 + new_out_dim: 256 + + keep_modules_after_finetuning_module: + task_heads-zinc: + new_sub_module: ${constants.task} + hidden_dims: ${finetuning.new_out_dim} + depth: 1 + dropout: 0. + last_activation: none + out_dim: 1 + + epoch_unfreeze_all: 0 + always_freeze_modules: [] \ No newline at end of file diff --git a/expts/hydra-configs/fingerprinting/example-custom.yaml b/expts/hydra-configs/fingerprinting/example-custom.yaml new file mode 100644 index 000000000..d6f26a3c5 --- /dev/null +++ b/expts/hydra-configs/fingerprinting/example-custom.yaml @@ -0,0 +1,16 @@ +pretrained: + model: dummy-pretrained-model + layers: + - graph_output_nn-graph:0 + - task_heads-zinc:0 + +datamodule: + df_path: ./expts/data/finetuning_example-reg/raw.csv + benchmark: null + task: null + split_val: 0.0 + split_test: 1.0 + device: cpu # cpu or cuda + num_workers: 0 + fps_cache_dir: ./expts/data/fps/finetuning_example-reg + mol_cache_dir: ${datamodule.fps_cache_dir} diff --git a/expts/hydra-configs/fingerprinting/example-tdc.yaml b/expts/hydra-configs/fingerprinting/example-tdc.yaml new file mode 100644 index 000000000..1f603e9c6 --- /dev/null +++ b/expts/hydra-configs/fingerprinting/example-tdc.yaml @@ -0,0 +1,15 @@ +pretrained: + model: dummy-pretrained-model + layers: + - graph_output_nn-graph:0 + - task_heads-zinc:0 + +datamodule: + df_path: null + benchmark: tdc + task: herg + data_seed: 1 + device: cpu # cpu or cuda + num_workers: 0 + fps_cache_dir: ./expts/data/fps/${datamodule.benchmark}/${datamodule.task} + mol_cache_dir: ${datamodule.fps_cache_dir} \ No newline at end of file diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/finetune.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/finetune.yaml new file mode 100644 index 000000000..0f93b26de --- /dev/null +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/finetune.yaml @@ -0,0 +1,77 @@ +# @package _global_ + +#Task-specific +predictor: + metrics_on_progress_bar: + reg: ["mae"] + cls: ["auroc"] + loss_fun: + reg: mae + cls: bce_logits + random_seed: ${constants.seed} + optim_kwargs: + lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs + torch_scheduler_kwargs: + module_type: WarmUpLinearLR + max_num_epochs: &max_epochs 10 + warmup_epochs: 10 + verbose: False + target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label + multitask_handling: flatten # flatten, mean-per-label + +# Task-specific +metrics: + reg: + - name: mae + metric: mae + target_nan_mask: null + multitask_handling: flatten + threshold_kwargs: null + - name: spearman + metric: spearmanr + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + - name: pearson + metric: pearsonr + threshold_kwargs: null + target_nan_mask: null + multitask_handling: mean-per-label + - name: r2_score + metric: r2_score + target_nan_mask: null + multitask_handling: mean-per-label + threshold_kwargs: null + cls: + - name: auroc + metric: auroc + task: binary + multitask_handling: mean-per-label + threshold_kwargs: null + - name: auprc + metric: averageprecision + task: binary + multitask_handling: mean-per-label + threshold_kwargs: null + - name: accuracy + metric: accuracy + multitask_handling: mean-per-label + target_to_int: True + average: micro + threshold_kwargs: &threshold_05 + operator: greater + threshold: 0.5 + th_on_preds: True + th_on_target: True + +datamodule: + args: + task_specific_args: + finetune: + df: null + df_path: expts/data/finetuning_example-reg/raw.csv + smiles_col: smiles + label_cols: target + task_level: graph + splits_path: expts/data/finetuning_example-reg/split.csv + epoch_sampling_fraction: 1.0 \ No newline at end of file diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_mcf7.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_mcf7.yaml index 43933a7fa..53c753402 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_mcf7.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_mcf7.yaml @@ -7,7 +7,7 @@ predictor: l1000_mcf7: [] loss_fun: l1000_mcf7: - name: hybrid_ce_ipu + name: hybrid_ce n_brackets: 3 alpha: 0.5 diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_vcap.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_vcap.yaml index 27b89d862..e385bf23e 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_vcap.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_vcap.yaml @@ -7,7 +7,7 @@ predictor: l1000_vcap: [] loss_fun: l1000_vcap: - name: hybrid_ce_ipu + name: hybrid_ce n_brackets: 3 alpha: 0.5 diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/largemix.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/largemix.yaml index 921960cd1..f307f0441 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/largemix.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/largemix.yaml @@ -15,16 +15,16 @@ predictor: pcqm4m_n4: [] loss_fun: l1000_vcap: - name: hybrid_ce_ipu + name: hybrid_ce n_brackets: 3 alpha: 0.5 l1000_mcf7: - name: hybrid_ce_ipu + name: hybrid_ce n_brackets: 3 alpha: ${predictor.loss_fun.l1000_vcap.alpha} - pcba_1328: bce_logits_ipu - pcqm4m_g25: mae_ipu - pcqm4m_n4: mae_ipu + pcba_1328: bce_logits + pcqm4m_g25: mae + pcqm4m_n4: mae metrics: l1000_vcap: &classif_metrics @@ -48,7 +48,7 @@ metrics: threshold_kwargs: null l1000_mcf7: *classif_metrics pcba_1328: - # use auroc and averageprecision (non_ipu version) so tha nans are handled correctly + - name: auroc metric: auroc task: binary @@ -63,17 +63,17 @@ metrics: threshold_kwargs: null pcqm4m_g25: &pcqm_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: mean-per-label threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2 - metric: r2_score_ipu + metric: r2_score threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcba_1328.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcba_1328.yaml index adc3321a0..72f9fba35 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcba_1328.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcba_1328.yaml @@ -6,11 +6,11 @@ predictor: metrics_on_training_set: pcba_1328: [] loss_fun: - pcba_1328: bce_logits_ipu + pcba_1328: bce_logits metrics: pcba_1328: - # use auroc and averageprecision (non_ipu version) so tha nans are handled correctly + - name: auroc metric: auroc task: binary diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m.yaml index d5b302dd1..8eb878b62 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m.yaml @@ -7,25 +7,24 @@ predictor: metrics_on_training_set: homolumo: ["pearsonr"] loss_fun: - homolumo: mae_ipu + homolumo: mae # Task-specific metrics: homolumo: - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: mean-per-label threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" homolumo: diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_g25.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_g25.yaml index 047701f6e..8247b4c47 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_g25.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_g25.yaml @@ -6,22 +6,22 @@ predictor: metrics_on_training_set: pcqm4m_g25: [] loss_fun: - pcqm4m_g25: mae_ipu + pcqm4m_g25: mae metrics: pcqm4m_g25: - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: mean-per-label threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2 - metric: r2_score_ipu + metric: r2_score threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_n4.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_n4.yaml index 494843464..2ef471be4 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_n4.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_n4.yaml @@ -6,22 +6,22 @@ predictor: metrics_on_training_set: pcqm4m_n4: [] loss_fun: - pcqm4m_n4: mae_ipu + pcqm4m_n4: mae metrics: pcqm4m_n4: - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: mean-per-label threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2 - metric: r2_score_ipu + metric: r2_score threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/tdc.yaml similarity index 89% rename from expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml rename to expts/hydra-configs/tasks/loss_metrics_datamodule/tdc.yaml index 89176f2b6..d4a6e296d 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/tdc.yaml @@ -26,29 +26,29 @@ predictor: herg: ["auroc"] ames: ["auroc"] dili: ["auroc"] - ld50_zhu: ["auroc"] + ld50_zhu: ["mae"] loss_fun: caco2_wang: mae - hia_hou: bce - pgp_broccatelli: bce - bioavailability_ma: bce + hia_hou: bce_logits + pgp_broccatelli: bce_logits + bioavailability_ma: bce_logits lipophilicity_astrazeneca: mae solubility_aqsoldb: mae - bbb_martins: bce + bbb_martins: bce_logits ppbr_az: mae vdss_lombardo: mae - cyp2d6_veith: bce - cyp3a4_veith: bce - cyp2c9_veith: bce - cyp2d6_substrate_carbonmangels: bce - cyp3a4_substrate_carbonmangels: bce - cyp2c9_substrate_carbonmangels: bce + cyp2d6_veith: bce_logits + cyp3a4_veith: bce_logits + cyp2c9_veith: bce_logits + cyp2d6_substrate_carbonmangels: bce_logits + cyp3a4_substrate_carbonmangels: bce_logits + cyp2c9_substrate_carbonmangels: bce_logits half_life_obach: mae clearance_microsome_az: mae clearance_hepatocyte_az: mae - herg: bce - ames: bce - dili: bce + herg: bce_logits + ames: bce_logits + dili: bce_logits ld50_zhu: mae random_seed: ${constants.seed} optim_kwargs: @@ -134,7 +134,7 @@ metrics: ld50_zhu: *regression_metrics datamodule: - module_type: "ADMETBenchmarkDataModule" + module_type: "TDCBenchmarkDataModule" args: # TDC specific tdc_benchmark_names: null diff --git a/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml b/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml index 9ac744a52..bf2e044b4 100644 --- a/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml +++ b/expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml @@ -1,41 +1,45 @@ # @package _global_ predictor: + target_nan_mask: ignore + multitask_handling: flatten metrics_on_progress_bar: qm9: ["mae"] tox21: ["auroc"] zinc: ["mae"] loss_fun: - qm9: mae_ipu - tox21: bce_logits_ipu - zinc: mae_ipu + qm9: mae + tox21: bce_logits + zinc: mae metrics: qm9: &qm9_metrics - name: mae - metric: mae_ipu - target_nan_mask: null + metric: mae + target_nan_mask: ignore multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null - target_nan_mask: null + target_nan_mask: ignore multitask_handling: mean-per-label - name: r2_score - metric: r2_score_ipu - target_nan_mask: null + metric: r2_score + target_nan_mask: ignore multitask_handling: mean-per-label threshold_kwargs: null tox21: - name: auroc - metric: auroc_ipu + metric: auroc task: binary + target_nan_mask: ignore multitask_handling: mean-per-label threshold_kwargs: null - name: avpr - metric: average_precision_ipu + metric: averageprecision task: binary + target_nan_mask: ignore multitask_handling: mean-per-label threshold_kwargs: null - name: f1 > 0.5 @@ -44,6 +48,8 @@ metrics: target_to_int: True num_classes: 2 average: micro + task: binary + target_nan_mask: ignore threshold_kwargs: &threshold_05 operator: greater threshold: 0.5 @@ -53,6 +59,8 @@ metrics: metric: precision multitask_handling: mean-per-label average: micro + task: binary + target_nan_mask: ignore threshold_kwargs: *threshold_05 zinc: *qm9_metrics diff --git a/expts/hydra-configs/tasks/task_heads/admet.yaml b/expts/hydra-configs/tasks/task_heads/tdc.yaml similarity index 100% rename from expts/hydra-configs/tasks/task_heads/admet.yaml rename to expts/hydra-configs/tasks/task_heads/tdc.yaml diff --git a/expts/hydra-configs/tasks/admet.yaml b/expts/hydra-configs/tasks/tdc.yaml similarity index 77% rename from expts/hydra-configs/tasks/admet.yaml rename to expts/hydra-configs/tasks/tdc.yaml index 30dec61e0..f7fef1b57 100644 --- a/expts/hydra-configs/tasks/admet.yaml +++ b/expts/hydra-configs/tasks/tdc.yaml @@ -3,5 +3,5 @@ # want to override both. defaults: - - task_heads: admet - - loss_metrics_datamodule: admet \ No newline at end of file + - task_heads: tdc + - loss_metrics_datamodule: tdc \ No newline at end of file diff --git a/expts/hydra-configs/training/accelerator/largemix_cpu.yaml b/expts/hydra-configs/training/accelerator/largemix_cpu.yaml index 6f5e0606a..ea83fdf58 100644 --- a/expts/hydra-configs/training/accelerator/largemix_cpu.yaml +++ b/expts/hydra-configs/training/accelerator/largemix_cpu.yaml @@ -4,7 +4,6 @@ datamodule: args: batch_size_training: 200 batch_size_inference: 200 - featurization_n_jobs: 20 num_workers: 20 predictor: diff --git a/expts/hydra-configs/training/accelerator/largemix_gpu.yaml b/expts/hydra-configs/training/accelerator/largemix_gpu.yaml index ac728c982..17ac12ad8 100644 --- a/expts/hydra-configs/training/accelerator/largemix_gpu.yaml +++ b/expts/hydra-configs/training/accelerator/largemix_gpu.yaml @@ -7,7 +7,6 @@ datamodule: args: batch_size_training: 2048 batch_size_inference: 2048 - featurization_n_jobs: 6 num_workers: 6 predictor: diff --git a/expts/hydra-configs/training/accelerator/largemix_ipu.yaml b/expts/hydra-configs/training/accelerator/largemix_ipu.yaml index 090600e98..115cd9e53 100644 --- a/expts/hydra-configs/training/accelerator/largemix_ipu.yaml +++ b/expts/hydra-configs/training/accelerator/largemix_ipu.yaml @@ -2,14 +2,6 @@ datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 35 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 # Data handling-related batch_size_training: 30 batch_size_inference: 30 diff --git a/expts/hydra-configs/training/accelerator/pcqm4m_ipu.yaml b/expts/hydra-configs/training/accelerator/pcqm4m_ipu.yaml index a7e23f383..c49d10405 100644 --- a/expts/hydra-configs/training/accelerator/pcqm4m_ipu.yaml +++ b/expts/hydra-configs/training/accelerator/pcqm4m_ipu.yaml @@ -2,14 +2,6 @@ datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 16 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 60 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 120 # Data handling-related batch_size_inference: 16 diff --git a/expts/hydra-configs/training/accelerator/toymix_cpu.yaml b/expts/hydra-configs/training/accelerator/toymix_cpu.yaml index 9022eeb84..f81662285 100644 --- a/expts/hydra-configs/training/accelerator/toymix_cpu.yaml +++ b/expts/hydra-configs/training/accelerator/toymix_cpu.yaml @@ -4,7 +4,6 @@ datamodule: args: batch_size_training: 200 batch_size_inference: 200 - featurization_n_jobs: 4 num_workers: 4 predictor: diff --git a/expts/hydra-configs/training/accelerator/toymix_gpu.yaml b/expts/hydra-configs/training/accelerator/toymix_gpu.yaml index c2c8e4066..ac4e48c26 100644 --- a/expts/hydra-configs/training/accelerator/toymix_gpu.yaml +++ b/expts/hydra-configs/training/accelerator/toymix_gpu.yaml @@ -7,7 +7,6 @@ datamodule: args: batch_size_training: 200 batch_size_inference: 200 - featurization_n_jobs: 4 num_workers: 4 predictor: diff --git a/expts/hydra-configs/training/accelerator/toymix_ipu.yaml b/expts/hydra-configs/training/accelerator/toymix_ipu.yaml index 1bf28ce0b..8f5fe4941 100644 --- a/expts/hydra-configs/training/accelerator/toymix_ipu.yaml +++ b/expts/hydra-configs/training/accelerator/toymix_ipu.yaml @@ -2,14 +2,6 @@ datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 44 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 80 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 44 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 80 # Data handling-related batch_size_training: 50 batch_size_inference: 50 diff --git a/expts/hydra-configs/training/toymix.yaml b/expts/hydra-configs/training/toymix.yaml index 4afcbd56a..bfffdebe2 100644 --- a/expts/hydra-configs/training/toymix.yaml +++ b/expts/hydra-configs/training/toymix.yaml @@ -11,7 +11,7 @@ predictor: warmup_epochs: 10 verbose: False scheduler_kwargs: null - target_nan_mask: null + target_nan_mask: ignore multitask_handling: flatten # flatten, mean-per-label trainer: @@ -23,4 +23,4 @@ trainer: precision: 16 max_epochs: ${constants.max_epochs} min_epochs: 1 - check_val_every_n_epoch: 20 \ No newline at end of file + check_val_every_n_epoch: 2 \ No newline at end of file diff --git a/expts/neurips2023_configs/base_config/large.yaml b/expts/neurips2023_configs/base_config/large.yaml index 8a836f368..a6389a4a6 100644 --- a/expts/neurips2023_configs/base_config/large.yaml +++ b/expts/neurips2023_configs/base_config/large.yaml @@ -7,18 +7,11 @@ constants: datacache_path: "/localdata/neurips2023-large/" accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 35 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 + # Data handling-related batch_size_training: 30 batch_size_inference: 30 @@ -31,38 +24,8 @@ accelerator: precision: 16-true accumulate_grad_batches: 2 - ipu_config: - - deviceIterations(30) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 96) - - Precision.enableStochasticRounding(True) - # - Precision.enableFloatingPointExceptions(True) - - ipu_inference_config: - # set device iteration and replication factor to 1 during inference - # gradient accumulation was set to 1 in the code - - deviceIterations(1) - - replicationFactor(1) - - Precision.enableStochasticRounding(False) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# args: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_vcap: @@ -133,11 +96,6 @@ datamodule: epoch_sampling_fraction: 1.0 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" - dataloading_from: disk processed_graph_data_path: ${constants.datacache_path} featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -170,7 +128,7 @@ datamodule: ksteps: 16 num_workers: 32 # -1 to use all - persistent_workers: True # if use persistent worker at the start of each epoch. + persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. @@ -331,16 +289,16 @@ predictor: pcqm4m_n4: [] loss_fun: l1000_vcap: - name: hybrid_ce_ipu + name: hybrid_ce n_brackets: 3 alpha: 0.5 l1000_mcf7: - name: hybrid_ce_ipu + name: hybrid_ce n_brackets: 3 alpha: ${predictor.loss_fun.l1000_vcap.alpha} - pcba_1328: bce_logits_ipu - pcqm4m_g25: mae_ipu - pcqm4m_n4: mae_ipu + pcba_1328: bce_logits + pcqm4m_g25: mae + pcqm4m_n4: mae random_seed: ${constants.seed} optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -380,7 +338,7 @@ metrics: threshold_kwargs: null l1000_mcf7: *classif_metrics pcba_1328: - # use auroc and averageprecision (non_ipu version) so tha nans are handled correctly + - name: auroc metric: auroc task: binary @@ -395,17 +353,17 @@ metrics: threshold_kwargs: null pcqm4m_g25: &pcqm_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: mean-per-label threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2 - metric: r2_score_ipu + metric: r2_score threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/neurips2023_configs/base_config/large_pcba.yaml b/expts/neurips2023_configs/base_config/large_pcba.yaml index f90675e73..9eb574983 100644 --- a/expts/neurips2023_configs/base_config/large_pcba.yaml +++ b/expts/neurips2023_configs/base_config/large_pcba.yaml @@ -7,18 +7,11 @@ constants: datacache_path: "/localdata/neurips2023-large/" accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 35 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 + # Data handling-related batch_size_training: 30 batch_size_inference: 30 @@ -31,38 +24,8 @@ accelerator: precision: 16-true accumulate_grad_batches: 2 - ipu_config: - - deviceIterations(30) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 96) - - Precision.enableStochasticRounding(True) - # - Precision.enableFloatingPointExceptions(True) - - ipu_inference_config: - # set device iteration and replication factor to 1 during inference - # gradient accumulation was set to 1 in the code - - deviceIterations(1) - - replicationFactor(1) - - Precision.enableStochasticRounding(False) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# args: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" @@ -132,11 +95,6 @@ datamodule: #epoch_sampling_fraction: 1.0 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" - dataloading_from: disk processed_graph_data_path: ${constants.datacache_path} featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -169,7 +127,7 @@ datamodule: ksteps: 16 num_workers: 32 # -1 to use all - persistent_workers: True # if use persistent worker at the start of each epoch. + persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. @@ -330,16 +288,16 @@ predictor: #pcqm4m_n4: [] loss_fun: # l1000_vcap: - # name: hybrid_ce_ipu + # name: hybrid_ce # n_brackets: 3 # alpha: 0.5 # l1000_mcf7: - # name: hybrid_ce_ipu + # name: hybrid_ce # n_brackets: 3 # alpha: ${predictor.loss_fun.l1000_vcap.alpha} - pcba_1328: bce_logits_ipu - # pcqm4m_g25: mae_ipu - #pcqm4m_n4: mae_ipu + pcba_1328: bce_logits + # pcqm4m_g25: mae + #pcqm4m_n4: mae random_seed: ${constants.seed} optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -379,7 +337,7 @@ metrics: # threshold_kwargs: null # l1000_mcf7: *classif_metrics pcba_1328: - # use auroc and averageprecision (non_ipu version) so tha nans are handled correctly + - name: auroc metric: auroc task: binary @@ -394,17 +352,17 @@ metrics: threshold_kwargs: null # pcqm4m_n4: &pcqm_metrics #- name: mae - #metric: mae_ipu + #metric: mae #target_nan_mask: null #multitask_handling: mean-per-label #threshold_kwargs: null #- name: pearsonr - #metric: pearsonr_ipu + #metric: pearsonr #threshold_kwargs: null #target_nan_mask: null #multitask_handling: mean-per-label #- name: r2 - #metric: r2_score_ipu + #metric: r2_score #threshold_kwargs: null #target_nan_mask: null #multitask_handling: mean-per-label diff --git a/expts/neurips2023_configs/base_config/large_pcqm_g25.yaml b/expts/neurips2023_configs/base_config/large_pcqm_g25.yaml index 1fac9176b..97836f906 100644 --- a/expts/neurips2023_configs/base_config/large_pcqm_g25.yaml +++ b/expts/neurips2023_configs/base_config/large_pcqm_g25.yaml @@ -7,18 +7,11 @@ constants: datacache_path: "/localdata/neurips2023-large/" accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 35 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 + # Data handling-related batch_size_training: 30 batch_size_inference: 30 @@ -31,38 +24,8 @@ accelerator: precision: 16-true accumulate_grad_batches: 2 - ipu_config: - - deviceIterations(30) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 96) - - Precision.enableStochasticRounding(True) - # - Precision.enableFloatingPointExceptions(True) - - ipu_inference_config: - # set device iteration and replication factor to 1 during inference - # gradient accumulation was set to 1 in the code - - deviceIterations(1) - - replicationFactor(1) - - Precision.enableStochasticRounding(False) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# args: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" @@ -132,11 +95,6 @@ datamodule: # epoch_sampling_fraction: 1.0 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" - dataloading_from: disk processed_graph_data_path: ${constants.datacache_path} featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -169,7 +127,7 @@ datamodule: ksteps: 16 num_workers: 32 # -1 to use all - persistent_workers: True # if use persistent worker at the start of each epoch. + persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. @@ -330,16 +288,16 @@ predictor: # pcqm4m_n4: [] loss_fun: # l1000_vcap: - # name: hybrid_ce_ipu + # name: hybrid_ce # n_brackets: 3 # alpha: 0.5 # l1000_mcf7: - # name: hybrid_ce_ipu + # name: hybrid_ce # n_brackets: 3 # alpha: ${predictor.loss_fun.l1000_vcap.alpha} - # pcba_1328: bce_logits_ipu - pcqm4m_g25: mae_ipu - # pcqm4m_n4: mae_ipu + # pcba_1328: bce_logits + pcqm4m_g25: mae + # pcqm4m_n4: mae random_seed: ${constants.seed} optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -379,7 +337,7 @@ metrics: # threshold_kwargs: null # l1000_mcf7: *classif_metrics # pcba_1328: - # # use auroc and averageprecision (non_ipu version) so tha nans are handled correctly + # # - name: auroc # metric: auroc # task: binary @@ -394,17 +352,17 @@ metrics: # threshold_kwargs: null pcqm4m_g25: &pcqm_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: mean-per-label threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2 - metric: r2_score_ipu + metric: r2_score threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/neurips2023_configs/base_config/large_pcqm_n4.yaml b/expts/neurips2023_configs/base_config/large_pcqm_n4.yaml index f9a9e58b8..826b1d95b 100644 --- a/expts/neurips2023_configs/base_config/large_pcqm_n4.yaml +++ b/expts/neurips2023_configs/base_config/large_pcqm_n4.yaml @@ -7,18 +7,11 @@ constants: datacache_path: "/localdata/neurips2023-large/" accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 35 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 + # Data handling-related batch_size_training: 30 batch_size_inference: 30 @@ -31,38 +24,8 @@ accelerator: precision: 16-true accumulate_grad_batches: 2 - ipu_config: - - deviceIterations(30) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 96) - - Precision.enableStochasticRounding(True) - # - Precision.enableFloatingPointExceptions(True) - - ipu_inference_config: - # set device iteration and replication factor to 1 during inference - # gradient accumulation was set to 1 in the code - - deviceIterations(1) - - replicationFactor(1) - - Precision.enableStochasticRounding(False) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# args: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" @@ -132,11 +95,6 @@ datamodule: epoch_sampling_fraction: 1.0 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" - dataloading_from: disk processed_graph_data_path: ${constants.datacache_path} featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -169,7 +127,7 @@ datamodule: ksteps: 16 num_workers: 32 # -1 to use all - persistent_workers: True # if use persistent worker at the start of each epoch. + persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. @@ -330,16 +288,16 @@ predictor: pcqm4m_n4: [] loss_fun: # l1000_vcap: - # name: hybrid_ce_ipu + # name: hybrid_ce # n_brackets: 3 # alpha: 0.5 # l1000_mcf7: - # name: hybrid_ce_ipu + # name: hybrid_ce # n_brackets: 3 # alpha: ${predictor.loss_fun.l1000_vcap.alpha} - # pcba_1328: bce_logits_ipu - # pcqm4m_g25: mae_ipu - pcqm4m_n4: mae_ipu + # pcba_1328: bce_logits + # pcqm4m_g25: mae + pcqm4m_n4: mae random_seed: ${constants.seed} optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -379,7 +337,7 @@ metrics: # threshold_kwargs: null # l1000_mcf7: *classif_metrics # pcba_1328: - # # use auroc and averageprecision (non_ipu version) so tha nans are handled correctly + # # - name: auroc # metric: auroc # task: binary @@ -394,17 +352,17 @@ metrics: # threshold_kwargs: null pcqm4m_n4: &pcqm_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: mean-per-label threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2 - metric: r2_score_ipu + metric: r2_score threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/neurips2023_configs/base_config/small.yaml b/expts/neurips2023_configs/base_config/small.yaml index fd7ce3fbe..a8b706b2a 100644 --- a/expts/neurips2023_configs/base_config/small.yaml +++ b/expts/neurips2023_configs/base_config/small.yaml @@ -6,18 +6,11 @@ constants: entity: multitask-gnn accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 44 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 80 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 44 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 80 + # Data handling-related batch_size_training: 50 batch_size_inference: 50 @@ -29,29 +22,8 @@ accelerator: precision: 16 accumulate_grad_batches: 4 - ipu_config: - - deviceIterations(5) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" qm9: @@ -97,10 +69,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-small/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -252,9 +220,9 @@ predictor: tox21: ["auroc"] zinc: ["mae"] loss_fun: - qm9: mae_ipu - tox21: bce_ipu - zinc: mae_ipu + qm9: mae + tox21: bce + zinc: mae random_seed: *seed optim_kwargs: lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs @@ -275,28 +243,28 @@ predictor: metrics: qm9: &qm9_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2_score - metric: r2_score_ipu + metric: r2_score target_nan_mask: null multitask_handling: mean-per-label threshold_kwargs: null tox21: - name: auroc - metric: auroc_ipu + metric: auroc task: binary multitask_handling: mean-per-label threshold_kwargs: null - name: avpr - metric: average_precision_ipu + metric: average_precision task: binary multitask_handling: mean-per-label threshold_kwargs: null diff --git a/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml b/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml index 7b2d2cbdf..22fbde029 100644 --- a/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml +++ b/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml @@ -1,22 +1,14 @@ -# Testing the gcn model with the PCQMv2 dataset on IPU. +# Testing the gcn model with the PCQMv2 dataset. constants: name: &name neurips2023_small_data_gcn seed: &seed 3000 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 44 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 80 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 100 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 200 # Data handling-related batch_size_training: 5 batch_size_inference: 2 @@ -28,29 +20,8 @@ accelerator: precision: 16 accumulate_grad_batches: 4 - ipu_config: - - deviceIterations(5) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" qm9: @@ -96,10 +67,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-small/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -134,7 +101,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -253,9 +219,9 @@ predictor: tox21: ["auroc"] zinc: ["mae"] loss_fun: - qm9: mae_ipu - tox21: bce_ipu - zinc: mae_ipu + qm9: mae + tox21: bce + zinc: mae random_seed: *seed optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -276,28 +242,28 @@ predictor: metrics: qm9: &qm9_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2_score - metric: r2_score_ipu + metric: r2_score target_nan_mask: null multitask_handling: mean-per-label threshold_kwargs: null tox21: - name: auroc - metric: auroc_ipu + metric: auroc task: binary multitask_handling: mean-per-label threshold_kwargs: null - name: avpr - metric: average_precision_ipu + metric: average_precision task: binary multitask_handling: mean-per-label threshold_kwargs: null diff --git a/expts/neurips2023_configs/baseline/config_small_gin_baseline.yaml b/expts/neurips2023_configs/baseline/config_small_gin_baseline.yaml index bc96d1057..3dbdca7d7 100644 --- a/expts/neurips2023_configs/baseline/config_small_gin_baseline.yaml +++ b/expts/neurips2023_configs/baseline/config_small_gin_baseline.yaml @@ -1,4 +1,4 @@ -# Testing the gin model with the PCQMv2 dataset on IPU. +# Testing the gin model with the PCQMv2 dataset. constants: name: &name neurips2023_small_data_gin config_override: "expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml" diff --git a/expts/neurips2023_configs/baseline/config_small_gine_baseline.yaml b/expts/neurips2023_configs/baseline/config_small_gine_baseline.yaml index 431235bb4..90389e008 100644 --- a/expts/neurips2023_configs/baseline/config_small_gine_baseline.yaml +++ b/expts/neurips2023_configs/baseline/config_small_gine_baseline.yaml @@ -1,4 +1,4 @@ -# Testing the gine model with the PCQMv2 dataset on IPU. +# Testing the gine model with the PCQMv2 dataset. constants: name: &name neurips2023_small_data_gine config_override: "expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml" diff --git a/expts/neurips2023_configs/config_classifigression_l1000.yaml b/expts/neurips2023_configs/config_classifigression_l1000.yaml index 48f06d9d1..16d6d1c73 100644 --- a/expts/neurips2023_configs/config_classifigression_l1000.yaml +++ b/expts/neurips2023_configs/config_classifigression_l1000.yaml @@ -1,44 +1,11 @@ -# Testing the mpnn only model with the PCQMv2 dataset on IPU. +# Testing the mpnn only model with the PCQMv2 dataset. constants: name: &name neurips2023_small_data_mpnn seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training -#accelerator: -# type: ipu # cpu or ipu or gpu -# config_override: -# datamodule: -# args: -# ipu_dataloader_training_opts: -# mode: async -# max_num_nodes_per_graph: 24 # train max nodes: 20, max_edges: 54 -# max_num_edges_per_graph: 60 -# ipu_dataloader_inference_opts: -# mode: async -# max_num_nodes_per_graph: 24 # valid max nodes: 51, max_edges: 118 -# max_num_edges_per_graph: 60 -# # Data handling-related -# batch_size_training: 50 -# batch_size_inference: 50 -## predictor: -## optim_kwargs: -## loss_scaling: 1024 -# trainer: -# trainer: -# precision: 16 -# accumulate_grad_batches: 4 -# -# ipu_config: -# - deviceIterations(20) # IPU would require large batches to be ready for the model. -# - replicationFactor(16) -# # - enableProfiling("graph_analyser") # The folder where the profile will be stored -# # - enableExecutableCaching("pop_compiler_cache") -# - TensorLocations.numIOTiles(128) -# - _Popart.set("defaultBufferingDepth", 128) -# - Precision.enableStochasticRounding(True) - accelerator: - type: gpu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: batch_size_training: 64 @@ -50,7 +17,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_vcap: @@ -76,10 +42,6 @@ datamodule: splits_path: graphium/data/neurips2023/small-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 1 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-small/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -114,7 +76,6 @@ datamodule: num_workers: 5 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: diff --git a/expts/neurips2023_configs/config_large_gcn.yaml b/expts/neurips2023_configs/config_large_gcn.yaml index 1219401df..cf56dcc3a 100644 --- a/expts/neurips2023_configs/config_large_gcn.yaml +++ b/expts/neurips2023_configs/config_large_gcn.yaml @@ -1,4 +1,4 @@ -# Running the gcn model with the largemix dataset on IPU. +# Running the gcn model with the largemix dataset. defaults: - base_config: large diff --git a/expts/neurips2023_configs/config_large_gcn_g25.yaml b/expts/neurips2023_configs/config_large_gcn_g25.yaml index 35c4e27d5..80a72cc3b 100644 --- a/expts/neurips2023_configs/config_large_gcn_g25.yaml +++ b/expts/neurips2023_configs/config_large_gcn_g25.yaml @@ -1,4 +1,4 @@ -# Running the gcn model with the largemix dataset on IPU. +# Running the gcn model with the largemix dataset. defaults: # - base_config: large diff --git a/expts/neurips2023_configs/config_large_gcn_gpu.yaml b/expts/neurips2023_configs/config_large_gcn_gpu.yaml index 2830530aa..6c5be45fe 100644 --- a/expts/neurips2023_configs/config_large_gcn_gpu.yaml +++ b/expts/neurips2023_configs/config_large_gcn_gpu.yaml @@ -49,7 +49,6 @@ datamodule: df_path: expts/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet splits_path: expts/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt` - featurization_n_jobs: 4 # 30 processed_graph_data_path: "../datacache/neurips2023-small/" num_workers: 4 # 30 @@ -60,7 +59,7 @@ architecture: predictor: loss_fun: - pcba_1328: bce_logits_ipu + pcba_1328: bce_logits torch_scheduler_kwargs: max_num_epochs: &max_epochs 20 diff --git a/expts/neurips2023_configs/config_large_gcn_n4.yaml b/expts/neurips2023_configs/config_large_gcn_n4.yaml index 61d335c12..616feec09 100644 --- a/expts/neurips2023_configs/config_large_gcn_n4.yaml +++ b/expts/neurips2023_configs/config_large_gcn_n4.yaml @@ -1,4 +1,4 @@ -# Running the gcn model with the largemix dataset on IPU. +# Running the gcn model with the largemix dataset. defaults: # - base_config: large diff --git a/expts/neurips2023_configs/config_large_gcn_pcba.yaml b/expts/neurips2023_configs/config_large_gcn_pcba.yaml index f11d8595d..d95401cf6 100644 --- a/expts/neurips2023_configs/config_large_gcn_pcba.yaml +++ b/expts/neurips2023_configs/config_large_gcn_pcba.yaml @@ -1,4 +1,4 @@ -# Running the gcn model with the largemix dataset on IPU. +# Running the gcn model with the largemix dataset. defaults: # - base_config: large diff --git a/expts/neurips2023_configs/config_large_gin.yaml b/expts/neurips2023_configs/config_large_gin.yaml index 6c3f47898..b2a0186f3 100644 --- a/expts/neurips2023_configs/config_large_gin.yaml +++ b/expts/neurips2023_configs/config_large_gin.yaml @@ -1,4 +1,4 @@ -# Running the gin model with the largemix dataset on IPU. +# Running the gin model with the largemix dataset. defaults: - base_config: large - _self_ diff --git a/expts/neurips2023_configs/config_large_gin_g25.yaml b/expts/neurips2023_configs/config_large_gin_g25.yaml index 538e994b1..702e9fe68 100644 --- a/expts/neurips2023_configs/config_large_gin_g25.yaml +++ b/expts/neurips2023_configs/config_large_gin_g25.yaml @@ -1,4 +1,4 @@ -# Running the gin model with the largemix dataset on IPU. +# Running the gin model with the largemix dataset. defaults: # - base_config: large - base_config: large_pcqm_g25 diff --git a/expts/neurips2023_configs/config_large_gin_n4.yaml b/expts/neurips2023_configs/config_large_gin_n4.yaml index c51e0e07d..3e41cf95b 100644 --- a/expts/neurips2023_configs/config_large_gin_n4.yaml +++ b/expts/neurips2023_configs/config_large_gin_n4.yaml @@ -1,4 +1,4 @@ -# Running the gin model with the largemix dataset on IPU. +# Running the gin model with the largemix dataset. defaults: # - base_config: large # - base_config: large_pcqm_g25 diff --git a/expts/neurips2023_configs/config_large_gin_pcba.yaml b/expts/neurips2023_configs/config_large_gin_pcba.yaml index 8bd33609b..af0e4945e 100644 --- a/expts/neurips2023_configs/config_large_gin_pcba.yaml +++ b/expts/neurips2023_configs/config_large_gin_pcba.yaml @@ -1,4 +1,4 @@ -# Running the gin model with the largemix dataset on IPU. +# Running the gin model with the largemix dataset. defaults: # - base_config: large # - base_config: large_pcqm_g25 diff --git a/expts/neurips2023_configs/config_large_gine.yaml b/expts/neurips2023_configs/config_large_gine.yaml index 793304ce0..6f82d3233 100644 --- a/expts/neurips2023_configs/config_large_gine.yaml +++ b/expts/neurips2023_configs/config_large_gine.yaml @@ -1,4 +1,4 @@ -# Running the gine model with the largemix dataset on IPU. +# Running the gine model with the largemix dataset. defaults: - base_config: large diff --git a/expts/neurips2023_configs/config_large_gine_g25.yaml b/expts/neurips2023_configs/config_large_gine_g25.yaml index e8002be3b..cceaa448f 100644 --- a/expts/neurips2023_configs/config_large_gine_g25.yaml +++ b/expts/neurips2023_configs/config_large_gine_g25.yaml @@ -1,4 +1,4 @@ -# Running the gine model with the largemix dataset on IPU. +# Running the gine model with the largemix dataset. defaults: # - base_config: large diff --git a/expts/neurips2023_configs/config_large_gine_n4.yaml b/expts/neurips2023_configs/config_large_gine_n4.yaml index df07380a4..d298fc183 100644 --- a/expts/neurips2023_configs/config_large_gine_n4.yaml +++ b/expts/neurips2023_configs/config_large_gine_n4.yaml @@ -1,4 +1,4 @@ -# Running the gine model with the largemix dataset on IPU. +# Running the gine model with the largemix dataset. defaults: # - base_config: large diff --git a/expts/neurips2023_configs/config_large_gine_pcba.yaml b/expts/neurips2023_configs/config_large_gine_pcba.yaml index 505935a57..668b7dfc2 100644 --- a/expts/neurips2023_configs/config_large_gine_pcba.yaml +++ b/expts/neurips2023_configs/config_large_gine_pcba.yaml @@ -1,4 +1,4 @@ -# Running the gine model with the largemix dataset on IPU. +# Running the gine model with the largemix dataset. defaults: # - base_config: large diff --git a/expts/neurips2023_configs/config_large_mpnn.yaml b/expts/neurips2023_configs/config_large_mpnn.yaml index 365927473..ca280f68a 100644 --- a/expts/neurips2023_configs/config_large_mpnn.yaml +++ b/expts/neurips2023_configs/config_large_mpnn.yaml @@ -1,4 +1,4 @@ -# Running the mpnn model with the largemix dataset on IPU. +# Running the mpnn model with the largemix dataset. defaults: - base_config: large diff --git a/expts/neurips2023_configs/config_luis_jama.yaml b/expts/neurips2023_configs/config_luis_jama.yaml index 5135c5cae..dc25ed212 100644 --- a/expts/neurips2023_configs/config_luis_jama.yaml +++ b/expts/neurips2023_configs/config_luis_jama.yaml @@ -1,44 +1,11 @@ -# Testing the mpnn only model with the PCQMv2 dataset on IPU. +# Testing the mpnn only model with the PCQMv2 dataset. constants: name: &name neurips2023_small_data_mpnn seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training -# accelerator: -# type: ipu # cpu or ipu or gpu -# config_override: -# datamodule: -# args: -# ipu_dataloader_training_opts: -# mode: async -# max_num_nodes_per_graph: 24 # train max nodes: 20, max_edges: 54 -# max_num_edges_per_graph: 60 -# ipu_dataloader_inference_opts: -# mode: async -# max_num_nodes_per_graph: 24 # valid max nodes: 51, max_edges: 118 -# max_num_edges_per_graph: 60 -# # Data handling-related -# batch_size_training: 50 -# batch_size_inference: 50 -# predictor: -# optim_kwargs: -# loss_scaling: 1024 -# trainer: -# trainer: -# precision: 16 -# accumulate_grad_batches: 4 - - # ipu_config: - # - deviceIterations(20) # IPU would require large batches to be ready for the model. - # - replicationFactor(16) - # # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # # - enableExecutableCaching("pop_compiler_cache") - # - TensorLocations.numIOTiles(128) - # - _Popart.set("defaultBufferingDepth", 128) - # - Precision.enableStochasticRounding(True) - accelerator: - type: cpu # cpu or ipu or gpu + type: cpu # cpu or gpu config_override: datamodule: batch_size_training: 64 @@ -50,7 +17,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" @@ -84,10 +50,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-small/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -122,7 +84,6 @@ datamodule: num_workers: 4 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -257,8 +218,8 @@ predictor: pcqm20k_g13: [] pcqm20k_n4: [] loss_fun: - pcqm20k_g13: mae_ipu - pcqm20k_n4: mae_ipu + pcqm20k_g13: mae + pcqm20k_n4: mae random_seed: *seed optim_kwargs: lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs @@ -279,12 +240,12 @@ predictor: metrics: pcqm20k_g13: &pcqm_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/neurips2023_configs/config_small_gated_gcn.yaml b/expts/neurips2023_configs/config_small_gated_gcn.yaml index 8e00d26f6..5a542d96d 100644 --- a/expts/neurips2023_configs/config_small_gated_gcn.yaml +++ b/expts/neurips2023_configs/config_small_gated_gcn.yaml @@ -1,4 +1,4 @@ -# Testing the gated_gcn model with the PCQMv2 dataset on IPU. +# Testing the gated_gcn model with the PCQMv2 dataset. defaults: - base_config: small diff --git a/expts/neurips2023_configs/config_small_gcn.yaml b/expts/neurips2023_configs/config_small_gcn.yaml index 114ce26dc..d43080a4c 100644 --- a/expts/neurips2023_configs/config_small_gcn.yaml +++ b/expts/neurips2023_configs/config_small_gcn.yaml @@ -1,4 +1,4 @@ -# Testing the gcn model with the toymix dataset on IPU. +# Testing the gcn model with the toymix dataset. defaults: - base_config: small diff --git a/expts/neurips2023_configs/config_small_gcn_gpu.yaml b/expts/neurips2023_configs/config_small_gcn_gpu.yaml index 8b5a46e26..e1223da2f 100644 --- a/expts/neurips2023_configs/config_small_gcn_gpu.yaml +++ b/expts/neurips2023_configs/config_small_gcn_gpu.yaml @@ -12,7 +12,7 @@ architecture: layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps accelerator: - type: gpu # cpu or ipu or gpu + type: gpu # cpu or gpu float32_matmul_precision: medium config_override: datamodule: @@ -41,7 +41,6 @@ datamodule: zinc: df_path: expts/data/neurips2023/small-dataset/ZINC12k.csv.gz splits_path: expts/data/neurips2023/small-dataset/ZINC12k_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt` - featurization_n_jobs: 4 # 30 processed_graph_data_path: "../datacache/neurips2023-small/" num_workers: 4 # 30 @@ -52,7 +51,7 @@ architecture: predictor: loss_fun: - tox21: bce_logits_ipu + tox21: bce_logits torch_scheduler_kwargs: max_num_epochs: &max_epochs 300 diff --git a/expts/neurips2023_configs/config_small_gin.yaml b/expts/neurips2023_configs/config_small_gin.yaml index e018f722a..ff86251f7 100644 --- a/expts/neurips2023_configs/config_small_gin.yaml +++ b/expts/neurips2023_configs/config_small_gin.yaml @@ -1,4 +1,4 @@ -# Testing the gin model with the PCQMv2 dataset on IPU. +# Testing the gin model with the PCQMv2 dataset. defaults: - base_config: small diff --git a/expts/neurips2023_configs/config_small_gine.yaml b/expts/neurips2023_configs/config_small_gine.yaml index 111bebbc2..4ec66a4f2 100644 --- a/expts/neurips2023_configs/config_small_gine.yaml +++ b/expts/neurips2023_configs/config_small_gine.yaml @@ -1,4 +1,4 @@ -# Testing the gine model with the PCQMv2 dataset on IPU. +# Testing the gine model with the PCQMv2 dataset. defaults: - base_config: small diff --git a/expts/neurips2023_configs/config_small_mpnn.yaml b/expts/neurips2023_configs/config_small_mpnn.yaml index 357a8f560..12c7a17e1 100644 --- a/expts/neurips2023_configs/config_small_mpnn.yaml +++ b/expts/neurips2023_configs/config_small_mpnn.yaml @@ -1,4 +1,4 @@ -# Testing the mpnn only model with the PCQMv2 dataset on IPU. +# Testing the mpnn only model with the PCQMv2 dataset. defaults: - base_config: small diff --git a/expts/neurips2023_configs/debug/config_debug.yaml b/expts/neurips2023_configs/debug/config_debug.yaml index 3d31e5e8c..21a8c30b2 100644 --- a/expts/neurips2023_configs/debug/config_debug.yaml +++ b/expts/neurips2023_configs/debug/config_debug.yaml @@ -51,7 +51,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" @@ -70,10 +69,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 0 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-small/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), diff --git a/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml b/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml index ec05bf6eb..236673699 100644 --- a/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml +++ b/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml @@ -60,7 +60,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_vcap: @@ -131,10 +130,6 @@ datamodule: epoch_sampling_fraction: 1.0 # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), diff --git a/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml b/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml index 26b50756f..773ca8814 100644 --- a/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml +++ b/expts/neurips2023_configs/debug/config_small_gcn_debug.yaml @@ -40,7 +40,6 @@ accelerator: datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" qm9: @@ -84,10 +83,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-small/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), diff --git a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml index e05d1be8d..d4dc601fa 100644 --- a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml +++ b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml @@ -5,18 +5,10 @@ constants: raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 60 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 # Data handling-related batch_size_training: 10 batch_size_inference: 2 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(1) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_mcf7: @@ -65,10 +36,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/mcf7/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +70,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -202,7 +168,7 @@ predictor: l1000_mcf7: [] loss_fun: l1000_mcf7: - name: hybrid_ce_ipu + name: hybrid_ce n_brackets: 5 random_seed: *seed optim_kwargs: @@ -224,13 +190,13 @@ predictor: metrics: l1000_mcf7: &classif_metrics - name: auroc - metric: auroc_ipu + metric: auroc num_classes: 5 task: multiclass multitask_handling: mean-per-label threshold_kwargs: null - name: avpr - metric: average_precision_ipu + metric: average_precision num_classes: 5 task: multiclass target_to_int: True diff --git a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml index cf924850e..a59215b90 100644 --- a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml +++ b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml @@ -5,18 +5,10 @@ constants: raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 200 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 400 # Data handling-related batch_size_training: 10 batch_size_inference: 2 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(10) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcba_1328: @@ -65,10 +36,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/pcba/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +70,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -199,7 +165,7 @@ predictor: metrics_on_training_set: pcba_1328: [] loss_fun: - pcba_1328: bce_ipu + pcba_1328: bce random_seed: *seed optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -220,12 +186,12 @@ predictor: metrics: pcba_1328: - name: auroc - metric: auroc_ipu + metric: auroc task: binary multitask_handling: mean-per-label threshold_kwargs: null - name: avpr - metric: average_precision_ipu + metric: average_precision task: binary multitask_handling: mean-per-label threshold_kwargs: null diff --git a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml index f1c9bcfd4..4a260e3ee 100644 --- a/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml +++ b/expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml @@ -5,18 +5,10 @@ constants: raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 60 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 150 # Data handling-related batch_size_training: 10 batch_size_inference: 2 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(1) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_vcap: @@ -65,10 +36,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/vcap/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +70,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -200,7 +166,7 @@ predictor: l1000_vcap: [] loss_fun: l1000_vcap: - name: hybrid_ce_ipu + name: hybrid_ce n_brackets: 5 random_seed: *seed optim_kwargs: @@ -222,13 +188,13 @@ predictor: metrics: l1000_vcap: &classif_metrics - name: auroc - metric: auroc_ipu + metric: auroc num_classes: 5 task: multiclass multitask_handling: mean-per-label threshold_kwargs: null - name: avpr - metric: average_precision_ipu + metric: average_precision num_classes: 5 task: multiclass target_to_int: True diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml index 01988e527..2e4cdaf53 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml @@ -5,18 +5,10 @@ constants: raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 # Data handling-related batch_size_training: 10 batch_size_inference: 10 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(10) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcqm4m_g25: @@ -68,10 +39,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/g25/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -106,7 +73,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -202,7 +168,7 @@ predictor: metrics_on_training_set: pcqm4m_g25: [] loss_fun: - pcqm4m_g25: mae_ipu + pcqm4m_g25: mae random_seed: *seed optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -223,17 +189,17 @@ predictor: metrics: pcqm4m_g25: &pcqm_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2 - metric: r2_score_ipu + metric: r2_score threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml index fdeb4b399..09425ec37 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml @@ -1,22 +1,14 @@ -# Testing the gine model with the PCQMv2 dataset on IPU. +# Testing the gine model with the PCQMv2 dataset. constants: name: &name neurips2023_large_data_gine_mcf7 seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 60 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 # Data handling-related batch_size_training: 10 batch_size_inference: 2 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(1) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_mcf7: @@ -65,10 +36,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/mcf7/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +70,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -202,7 +168,7 @@ predictor: l1000_mcf7: [] loss_fun: l1000_mcf7: - name: hybrid_ce_ipu + name: hybrid_ce n_brackets: 5 random_seed: *seed optim_kwargs: @@ -224,13 +190,13 @@ predictor: metrics: l1000_mcf7: &classif_metrics - name: auroc - metric: auroc_ipu + metric: auroc num_classes: 5 task: multiclass multitask_handling: mean-per-label threshold_kwargs: null - name: avpr - metric: average_precision_ipu + metric: average_precision num_classes: 5 task: multiclass target_to_int: True diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml index 5920a80f6..9ef7254fc 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml @@ -5,18 +5,10 @@ constants: raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 # Data handling-related batch_size_training: 10 batch_size_inference: 10 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(10) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcqm4m_n4: @@ -69,10 +40,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/n4/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -107,7 +74,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -203,7 +169,7 @@ predictor: metrics_on_training_set: pcqm4m_n4: [] loss_fun: - pcqm4m_n4: mae_ipu + pcqm4m_n4: mae random_seed: *seed optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -224,17 +190,17 @@ predictor: metrics: pcqm4m_n4: &pcqm_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2 - metric: r2_score_ipu + metric: r2_score threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml index de2f7fbc4..ecf18ce9b 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml @@ -5,18 +5,9 @@ constants: raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: - args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 200 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 400 # Data handling-related batch_size_training: 10 batch_size_inference: 2 @@ -28,29 +19,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(10) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcba_1328: @@ -65,10 +35,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/pcba/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +69,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -199,7 +164,7 @@ predictor: metrics_on_training_set: pcba_1328: [] loss_fun: - pcba_1328: bce_ipu + pcba_1328: bce random_seed: *seed optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -220,12 +185,12 @@ predictor: metrics: pcba_1328: - name: auroc - metric: auroc_ipu + metric: auroc task: binary multitask_handling: mean-per-label threshold_kwargs: null - name: avpr - metric: average_precision_ipu + metric: average_precision task: binary multitask_handling: mean-per-label threshold_kwargs: null diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml index ca820e86b..13a44e1c0 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml @@ -5,18 +5,10 @@ constants: raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 # Data handling-related batch_size_training: 10 batch_size_inference: 10 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(10) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcqm4m_g25: @@ -83,10 +54,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/pcq/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -121,7 +88,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -241,8 +207,8 @@ predictor: pcqm4m_g25: [] pcqm4m_n4: [] loss_fun: - pcqm4m_g25: mae_ipu - pcqm4m_n4: mae_ipu + pcqm4m_g25: mae + pcqm4m_n4: mae random_seed: *seed optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -263,17 +229,17 @@ predictor: metrics: pcqm4m_g25: &pcqm_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2 - metric: r2_score_ipu + metric: r2_score threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml b/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml index c21b765b3..3e716d5b1 100644 --- a/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml +++ b/expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml @@ -1,22 +1,14 @@ -# Testing the gine model with the PCQMv2 dataset on IPU. +# Testing the gine model with the PCQMv2 dataset. constants: name: &name neurips2023_large_data_gine_vcap seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 60 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 150 # Data handling-related batch_size_training: 10 batch_size_inference: 2 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(1) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_vcap: @@ -65,10 +36,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/vcap/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +70,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -200,7 +166,7 @@ predictor: l1000_vcap: [] loss_fun: l1000_vcap: - name: hybrid_ce_ipu + name: hybrid_ce n_brackets: 5 random_seed: *seed optim_kwargs: @@ -222,13 +188,13 @@ predictor: metrics: l1000_vcap: &classif_metrics - name: auroc - metric: auroc_ipu + metric: auroc num_classes: 5 task: multiclass multitask_handling: mean-per-label threshold_kwargs: null - name: avpr - metric: average_precision_ipu + metric: average_precision num_classes: 5 task: multiclass target_to_int: True diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml index b88314797..345620aff 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml @@ -1,22 +1,14 @@ -# Testing the gine model with the PCQMv2 dataset on IPU. +# Testing the gine model with the PCQMv2 dataset. constants: name: &name neurips2023_large_data_gine_g25 seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 # Data handling-related batch_size_training: 10 batch_size_inference: 10 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(10) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcqm4m_g25: @@ -68,10 +39,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/g25/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -106,7 +73,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -210,7 +176,7 @@ predictor: metrics_on_training_set: pcqm4m_g25: [] loss_fun: - pcqm4m_g25: mae_ipu + pcqm4m_g25: mae random_seed: *seed optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -231,17 +197,17 @@ predictor: metrics: pcqm4m_g25: &pcqm_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2 - metric: r2_score_ipu + metric: r2_score threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml index b96fc8daf..c7a03a80f 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml @@ -1,22 +1,14 @@ -# Testing the gine model with the PCQMv2 dataset on IPU. +# Testing the gine model with the PCQMv2 dataset. constants: name: &name neurips2023_large_data_gine_mcf7 seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 60 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 # Data handling-related batch_size_training: 10 batch_size_inference: 2 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(1) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_mcf7: @@ -65,10 +36,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/mcf7/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +70,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -210,7 +176,7 @@ predictor: l1000_mcf7: [] loss_fun: l1000_mcf7: - name: hybrid_ce_ipu + name: hybrid_ce n_brackets: 5 random_seed: *seed optim_kwargs: @@ -232,13 +198,13 @@ predictor: metrics: l1000_mcf7: &classif_metrics - name: auroc - metric: auroc_ipu + metric: auroc num_classes: 5 task: multiclass multitask_handling: mean-per-label threshold_kwargs: null - name: avpr - metric: average_precision_ipu + metric: average_precision num_classes: 5 task: multiclass target_to_int: True diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml index e98ae03da..edba240f9 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml @@ -1,22 +1,14 @@ -# Testing the gine model with the PCQMv2 dataset on IPU. +# Testing the gine model with the PCQMv2 dataset. constants: name: &name neurips2023_large_data_gine_n4 seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 # Data handling-related batch_size_training: 10 batch_size_inference: 10 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(10) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcqm4m_n4: @@ -69,10 +40,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/n4/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -107,7 +74,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -211,7 +177,7 @@ predictor: metrics_on_training_set: pcqm4m_n4: [] loss_fun: - pcqm4m_n4: mae_ipu + pcqm4m_n4: mae random_seed: *seed optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -232,17 +198,17 @@ predictor: metrics: pcqm4m_n4: &pcqm_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2 - metric: r2_score_ipu + metric: r2_score threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml index 427f7ca0f..32c1af644 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml @@ -1,22 +1,14 @@ -# Testing the gine model with the PCQMv2 dataset on IPU. +# Testing the gine model with the PCQMv2 dataset. constants: name: &name neurips2023_large_data_gine_pcba seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 200 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 400 # Data handling-related batch_size_training: 10 batch_size_inference: 2 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(10) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcba_1328: @@ -65,10 +36,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/pcba/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +70,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -207,7 +173,7 @@ predictor: metrics_on_training_set: pcba_1328: [] loss_fun: - pcba_1328: bce_ipu + pcba_1328: bce random_seed: *seed optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -228,12 +194,12 @@ predictor: metrics: pcba_1328: - name: auroc - metric: auroc_ipu + metric: auroc task: binary multitask_handling: mean-per-label threshold_kwargs: null - name: avpr - metric: average_precision_ipu + metric: average_precision task: binary multitask_handling: mean-per-label threshold_kwargs: null diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml index 07fc6d009..15026ae74 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml @@ -1,22 +1,14 @@ -# Testing the gine model with the PCQMv2 dataset on IPU. +# Testing the gine model with the PCQMv2 dataset. constants: name: &name neurips2023_large_data_gine_pcq seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 100 # Data handling-related batch_size_training: 10 batch_size_inference: 10 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(10) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" pcqm4m_g25: @@ -83,10 +54,6 @@ datamodule: method: "normal" # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/pcq/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -121,7 +88,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -249,8 +215,8 @@ predictor: pcqm4m_g25: [] pcqm4m_n4: [] loss_fun: - pcqm4m_g25: mae_ipu - pcqm4m_n4: mae_ipu + pcqm4m_g25: mae + pcqm4m_n4: mae random_seed: *seed optim_kwargs: lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs @@ -271,17 +237,17 @@ predictor: metrics: pcqm4m_g25: &pcqm_metrics - name: mae - metric: mae_ipu + metric: mae target_nan_mask: null multitask_handling: flatten threshold_kwargs: null - name: pearsonr - metric: pearsonr_ipu + metric: pearsonr threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label - name: r2 - metric: r2_score_ipu + metric: r2_score threshold_kwargs: null target_nan_mask: null multitask_handling: mean-per-label diff --git a/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml b/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml index b63263b3d..089aa8ed3 100644 --- a/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml +++ b/expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml @@ -1,22 +1,14 @@ -# Testing the gine model with the PCQMv2 dataset on IPU. +# Testing the gine model with the PCQMv2 dataset. constants: name: &name neurips2023_large_data_gine_vcap seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: gpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 100 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 60 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 150 # Data handling-related batch_size_training: 10 batch_size_inference: 2 @@ -28,29 +20,8 @@ accelerator: precision: 32 accumulate_grad_batches: 8 - ipu_config: - - deviceIterations(1) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - datamodule: module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data args: # Matches that in the test_multitask_datamodule.py case. task_specific_args: # To be replaced by a new class "DatasetParams" l1000_vcap: @@ -65,10 +36,6 @@ datamodule: splits_path: graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt` # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" processed_graph_data_path: "../datacache/neurips2023-large/vcap/" featurization: # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), @@ -103,7 +70,6 @@ datamodule: num_workers: 30 # -1 to use all persistent_workers: False # if use persistent worker at the start of each epoch. # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" architecture: @@ -208,7 +174,7 @@ predictor: l1000_vcap: [] loss_fun: l1000_vcap: - name: hybrid_ce_ipu + name: hybrid_ce n_brackets: 5 random_seed: *seed optim_kwargs: @@ -230,13 +196,13 @@ predictor: metrics: l1000_vcap: &classif_metrics - name: auroc - metric: auroc_ipu + metric: auroc num_classes: 5 task: multiclass multitask_handling: mean-per-label threshold_kwargs: null - name: avpr - metric: average_precision_ipu + metric: average_precision num_classes: 5 task: multiclass target_to_int: True diff --git a/expts/run_validation_test.py b/expts/run_validation_test.py index 06804301c..48cb0183e 100644 --- a/expts/run_validation_test.py +++ b/expts/run_validation_test.py @@ -8,7 +8,7 @@ import timeit from loguru import logger from datetime import datetime -from pytorch_lightning.utilities.model_summary import ModelSummary +from lightning.pytorch.utilities.model_summary import ModelSummary # Current project imports import graphium diff --git a/graphium/cli/__init__.py b/graphium/cli/__init__.py index e190d9ac4..1e60140fb 100644 --- a/graphium/cli/__init__.py +++ b/graphium/cli/__init__.py @@ -1,4 +1,5 @@ from .data import data_app from .parameters import param_app from .finetune_utils import finetune_app +from .fingerprint import fp_app from .main import app diff --git a/graphium/cli/finetune_utils.py b/graphium/cli/finetune_utils.py index 5566e7961..9500a8371 100644 --- a/graphium/cli/finetune_utils.py +++ b/graphium/cli/finetune_utils.py @@ -3,7 +3,6 @@ import fsspec import numpy as np import torch -import tqdm import typer import yaml from datamol.utils import fs @@ -13,7 +12,7 @@ from omegaconf import OmegaConf from graphium.config._loader import load_accelerator, load_datamodule -from graphium.finetuning.fingerprinting import Fingerprinter +from graphium.fingerprinting.fingerprinter import Fingerprinter from graphium.utils import fs from graphium.trainer.predictor import PredictorModule @@ -24,7 +23,7 @@ app.add_typer(finetune_app, name="finetune") -@finetune_app.command(name="admet") +@finetune_app.command(name="tdc") def benchmark_tdc_admet_cli( overrides: List[str], name: Optional[List[str]] = None, @@ -52,7 +51,7 @@ def benchmark_tdc_admet_cli( # Use the Compose API to construct the config for n in name: - overrides += ["+finetuning=admet", f"constants.task={n}"] + overrides += ["+finetuning=tdc", f"constants.task={n}"] with initialize(version_base=None, config_path="../../expts/hydra-configs"): cfg = compose( @@ -138,14 +137,14 @@ def get_fingerprints_from_model( def get_tdc_task_specific(task: str, output: Literal["name", "mode", "last_activation"]): if output == "last_activation": - config_arch_path = "expts/hydra-configs/tasks/task_heads/admet.yaml" + config_arch_path = "expts/hydra-configs/tasks/task_heads/tdc.yaml" with open(config_arch_path, "r") as yaml_file: config_tdc_arch = yaml.load(yaml_file, Loader=yaml.FullLoader) return config_tdc_arch["architecture"]["task_heads"][task]["last_activation"] else: - config_metrics_path = "expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml" + config_metrics_path = "expts/hydra-configs/tasks/loss_metrics_datamodule/tdc.yaml" with open(config_metrics_path, "r") as yaml_file: config_tdc_task_metric = yaml.load(yaml_file, Loader=yaml.FullLoader) diff --git a/graphium/cli/fingerprint.py b/graphium/cli/fingerprint.py new file mode 100644 index 000000000..0b0319fe9 --- /dev/null +++ b/graphium/cli/fingerprint.py @@ -0,0 +1,57 @@ +from typing import Any, List, Dict + +from loguru import logger + +from omegaconf import OmegaConf + +import wandb + +from graphium.fingerprinting.data import FingerprintDatamodule + +import typer +from hydra import initialize, compose + +from graphium.cli.main import app + +fp_app = typer.Typer(help="Automated fingerprinting from pretrained models.") +app.add_typer(fp_app, name="fps") + +@fp_app.command(name="create", help="Create fingerprints for pretrained model.") +def smiles_to_fps(cfg_name: str, overrides: List[str]) -> Dict[str, Any]: + with initialize(version_base=None, config_path="../../expts/hydra-configs/fingerprinting"): + cfg = compose( + config_name=cfg_name, + overrides=overrides, + ) + cfg = OmegaConf.to_container(cfg, resolve=True) + + if "wandb" in cfg.keys(): + wandb_cfg = cfg.get("wandb") + wandb.init(**wandb_cfg) + + pretrained_models = cfg.get("pretrained") + + # Allow alternative definition of `pretrained_models` with the single model specifier and desired layers + if "layers" in pretrained_models.keys(): + assert "model" in pretrained_models.keys(), "this workflow allows easier definition of fingerprinting sweeps" + model, layers = pretrained_models.pop("model"), pretrained_models.pop("layers") + pretrained_models[model] = layers + + data_kwargs = cfg.get("datamodule") + + datamodule = FingerprintDatamodule( + pretrained_models=pretrained_models, + **data_kwargs, + ) + + datamodule.prepare_data() + + logger.info(f"Fingerprints saved in {datamodule.fps_cache_dir}/fps.pt.") + try: + wandb.run.finish() + except: + pass + + +if __name__ == "__main__": + smiles_to_fps(cfg_name="example-tdc", overrides=[]) \ No newline at end of file diff --git a/graphium/cli/fingerprints.py b/graphium/cli/fingerprints.py deleted file mode 100644 index 62b078eb9..000000000 --- a/graphium/cli/fingerprints.py +++ /dev/null @@ -1,6 +0,0 @@ -from .main import app - - -@app.command(name="fp") -def get_fingerprints_from_model(): - ... diff --git a/graphium/cli/train_finetune_test.py b/graphium/cli/train_finetune_test.py index 09183c69e..d0b5a6597 100644 --- a/graphium/cli/train_finetune_test.py +++ b/graphium/cli/train_finetune_test.py @@ -20,6 +20,7 @@ from graphium.config._loader import ( load_accelerator, load_architecture, + load_mup, load_datamodule, load_metrics, load_predictor, @@ -43,7 +44,6 @@ TESTING_ONLY_CONFIG_KEY = "testing_only" - @hydra.main(version_base=None, config_path="../../expts/hydra-configs", config_name="main") def cli(cfg: DictConfig) -> None: """ @@ -52,76 +52,6 @@ def cli(cfg: DictConfig) -> None: return run_training_finetuning_testing(cfg) -def get_replication_factor(cfg): - try: - ipu_config = cfg.get("accelerator", {}).get("ipu_config", []) - for item in ipu_config: - if "replicationFactor" in item: - # Extract the number between parentheses - start = item.find("(") + 1 - end = item.find(")") - if start != 0 and end != -1: - return int(item[start:end]) - except Exception as e: - print(f"An error occurred: {e}") - - # Return default value if replicationFactor is not found or an error occurred - return 1 - - -def get_gradient_accumulation_factor(cfg): - """ - WARNING: This MUST be called after accelerator overrides have been applied - (i.e. after `load_accelerator` has been called) - """ - try: - # Navigate through the nested dictionaries and get the gradient accumulation factor - grad_accumulation_factor = cfg.get("trainer", {}).get("trainer", {}).get("accumulate_grad_batches", 1) - - # Ensure that the extracted value is an integer - return int(grad_accumulation_factor) - except Exception as e: - print(f"An error occurred: {e}") - - # Return default value if an error occurred - return 1 - - -def get_training_batch_size(cfg): - """ - WARNING: This MUST be called after accelerator overrides have been applied - (i.e. after `load_accelerator` has been called) - """ - try: - # Navigate through the nested dictionaries and get the training batch size - batch_size_training = cfg.get("datamodule", {}).get("args", {}).get("batch_size_training", 1) - - # Ensure that the extracted value is an integer - return int(batch_size_training) - except Exception as e: - print(f"An error occurred: {e}") - - # Return default value if an error occurred - return 1 - - -def get_training_device_iterations(cfg): - try: - ipu_config = cfg.get("accelerator", {}).get("ipu_config", []) - for item in ipu_config: - if "deviceIterations" in item: - # Extract the number between parentheses - start = item.find("(") + 1 - end = item.find(")") - if start != 0 and end != -1: - return int(item[start:end]) - except Exception as e: - print(f"An error occurred: {e}") - - # Return default value if deviceIterations is not found or an error occurred - return 1 - - def run_training_finetuning_testing(cfg: DictConfig) -> None: """ The main (pre-)training and fine-tuning loop. @@ -187,6 +117,8 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None: predictor = PredictorModule.load_pretrained_model( name_or_path=get_checkpoint_path(cfg), device=accelerator_type ) + mup_base_path = cfg["architecture"].pop("mup_base_path", None) + predictor = load_mup(mup_base_path, predictor) else: ## Architecture @@ -195,14 +127,6 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None: ## Metrics metrics = load_metrics(cfg) - # Note: these MUST be called after `cfg, accelerator = load_accelerator(cfg)` - replicas = get_replication_factor(cfg) - gradient_acc = get_gradient_accumulation_factor(cfg) - micro_bs = get_training_batch_size(cfg) - device_iterations = get_training_device_iterations(cfg) - - global_bs = replicas * gradient_acc * micro_bs * device_iterations - ## Predictor predictor = load_predictor( config=cfg, @@ -213,17 +137,15 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None: accelerator_type=accelerator_type, featurization=datamodule.featurization, task_norms=datamodule.task_norms, - replicas=replicas, - gradient_acc=gradient_acc, - global_bs=global_bs, ) logger.info(predictor.model) logger.info(ModelSummary(predictor, max_depth=4)) + metrics_on_progress_bar = predictor.get_metrics_on_progress_bar ## Trainer date_time_suffix = datetime.now().strftime("%d.%m.%Y_%H.%M.%S") - trainer = load_trainer(cfg, accelerator_type, date_time_suffix) + trainer = load_trainer(cfg, accelerator_type, date_time_suffix, metrics_on_progress_bar=metrics_on_progress_bar) if not testing_only: # Add the fine-tuning callback to trainer diff --git a/graphium/config/_loader.py b/graphium/config/_loader.py index 1e542592d..7550c6cbb 100644 --- a/graphium/config/_loader.py +++ b/graphium/config/_loader.py @@ -15,7 +15,7 @@ # Misc import os from copy import deepcopy -from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union, Iterable import joblib import mup @@ -33,17 +33,15 @@ from graphium.data.datamodule import BaseDataModule, MultitaskFromSmilesDataModule from graphium.finetuning.finetuning_architecture import FullGraphFinetuningNetwork -from graphium.ipu.ipu_dataloader import IPUDataloaderOptions -from graphium.ipu.ipu_utils import import_poptorch, load_ipu_options from graphium.nn.architectures import FullGraphMultiTaskNetwork from graphium.nn.utils import MupMixin from graphium.trainer.metrics import MetricWrapper from graphium.trainer.predictor import PredictorModule from graphium.utils.command_line_utils import get_anchors_and_aliases, update_config +from graphium.trainer.progress_bar import ProgressBarMetrics # Graphium from graphium.utils.mup import set_base_shapes -from graphium.utils.spaces import DATAMODULE_DICT, GRAPHIUM_PRETRAINED_MODELS_DICT from graphium.utils import fs @@ -62,107 +60,38 @@ def get_accelerator( if (accelerator_type == "gpu") and (not torch.cuda.is_available()): raise ValueError(f"GPUs selected, but GPUs are not available on this device") - # Get the IPU info - if accelerator_type == "ipu": - poptorch = import_poptorch() - if poptorch is None: - raise ValueError("IPUs selected, but PopTorch is not available") - if not poptorch.ipuHardwareIsAvailable(): - raise ValueError( - "IPUs selected, but no IPU is available/visible on this device. " - "If you do have IPUs, please check that the IPUOF_VIPU_API_PARTITION_ID and " - "IPUOF_VIPU_API_HOST environment variables are set." - ) - # Fall on cpu at the end if accelerator_type is None: accelerator_type = "cpu" return accelerator_type -def _get_ipu_opts(config: Union[omegaconf.DictConfig, Dict[str, Any]]) -> Tuple[str, str]: - r""" - Get the paths of the IPU-specific config files from the main YAML config - """ - - accelerator_options = config["accelerator"] - accelerator_type = accelerator_options["type"] - - if accelerator_type != "ipu": - return None, None - ipu_opts = accelerator_options["ipu_config"] - ipu_inference_opts = accelerator_options.get("ipu_inference_config", None) - - return ipu_opts, ipu_inference_opts - - def load_datamodule( config: Union[omegaconf.DictConfig, Dict[str, Any]], accelerator_type: str ) -> BaseDataModule: """ Load the datamodule from the specified configurations at the key `datamodule: args`. - If the accelerator is IPU, load the IPU options as well. Parameters: config: The config file, with key `datamodule: args` - accelerator_type: The accelerator type, e.g. "cpu", "gpu", "ipu" + accelerator_type: The accelerator type, e.g. "cpu", "gpu" Returns: datamodule: The datamodule used to process and load the data """ + from graphium.utils.spaces import DATAMODULE_DICT # Avoid circular imports with `spaces.py` + cfg_data = config["datamodule"]["args"] # Instanciate the datamodule module_class = DATAMODULE_DICT[config["datamodule"]["module_type"]] - if accelerator_type != "ipu": - datamodule = module_class( - **config["datamodule"]["args"], - ) - return datamodule - - # IPU specific adjustments - else: - ipu_opts, ipu_inference_opts = _get_ipu_opts(config) - - # Default empty values for the IPU configurations - ipu_training_opts = None - - ipu_dataloader_training_opts = cfg_data.pop("ipu_dataloader_training_opts", {}) - ipu_dataloader_inference_opts = cfg_data.pop("ipu_dataloader_inference_opts", {}) - ipu_training_opts, ipu_inference_opts = load_ipu_options( - ipu_opts=ipu_opts, - seed=config["constants"]["seed"], - model_name=config["constants"]["name"], - gradient_accumulation=config["trainer"]["trainer"].get("accumulate_grad_batches", None), - ipu_inference_opts=ipu_inference_opts, - precision=config["trainer"]["trainer"].get("precision"), - ) - - # Define the Dataloader options for the IPU on the training sets - bz_train = cfg_data["batch_size_training"] - ipu_dataloader_training_opts = IPUDataloaderOptions( - batch_size=bz_train, **ipu_dataloader_training_opts - ) - ipu_dataloader_training_opts.set_kwargs() - - # Define the Dataloader options for the IPU on the inference sets - bz_test = cfg_data["batch_size_inference"] - ipu_dataloader_inference_opts = IPUDataloaderOptions( - batch_size=bz_test, **ipu_dataloader_inference_opts - ) - ipu_dataloader_inference_opts.set_kwargs() - - datamodule = module_class( - ipu_training_opts=ipu_training_opts, - ipu_inference_opts=ipu_inference_opts, - ipu_dataloader_training_opts=ipu_dataloader_training_opts, - ipu_dataloader_inference_opts=ipu_dataloader_inference_opts, - **config["datamodule"]["args"], - ) + datamodule = module_class( + **config["datamodule"]["args"], + ) + return datamodule - return datamodule def load_metrics(config: Union[omegaconf.DictConfig, Dict[str, Any]]) -> Dict[str, MetricWrapper]: @@ -203,8 +132,6 @@ def load_architecture( architecture: The datamodule used to process and load the data """ - if isinstance(config, dict) and "finetuning" not in config: - config = omegaconf.OmegaConf.create(config) cfg_arch = config["architecture"] # Select the architecture @@ -262,10 +189,6 @@ def load_architecture( else: gnn_kwargs.setdefault("in_dim", edge_in_dim) - # Set the parameters for the full network - if "finetuning" not in config: - task_heads_kwargs = omegaconf.OmegaConf.to_object(task_heads_kwargs) - # Set all the input arguments for the model model_kwargs = dict( gnn_kwargs=gnn_kwargs, @@ -304,25 +227,17 @@ def load_predictor( accelerator_type: str, featurization: Dict[str, str] = None, task_norms: Optional[Dict[Callable, Any]] = None, - replicas: int = 1, - gradient_acc: int = 1, - global_bs: int = 1, ) -> PredictorModule: """ Defining the predictor module, which handles the training logic from `lightning.LighningModule` Parameters: model_class: The torch Module containing the main forward function - accelerator_type: The accelerator type, e.g. "cpu", "gpu", "ipu" + accelerator_type: The accelerator type, e.g. "cpu", "gpu" Returns: predictor: The predictor module """ - if accelerator_type == "ipu": - from graphium.ipu.ipu_wrapper import PredictorModuleIPU - - predictor_class = PredictorModuleIPU - else: - predictor_class = PredictorModule + predictor_class = PredictorModule cfg_pred = dict(deepcopy(config["predictor"])) predictor = predictor_class( @@ -332,9 +247,6 @@ def load_predictor( task_levels=task_levels, featurization=featurization, task_norms=task_norms, - replicas=replicas, - gradient_acc=gradient_acc, - global_bs=global_bs, **cfg_pred, ) @@ -351,9 +263,6 @@ def load_predictor( task_levels=task_levels, featurization=featurization, task_norms=task_norms, - replicas=replicas, - gradient_acc=gradient_acc, - global_bs=global_bs, **cfg_pred, ) @@ -390,47 +299,21 @@ def load_trainer( config: Union[omegaconf.DictConfig, Dict[str, Any]], accelerator_type: str, date_time_suffix: str = "", + metrics_on_progress_bar: Optional[Iterable[str]] = None, ) -> Trainer: """ Defining the pytorch-lightning Trainer module. Parameters: config: The config file, with key `trainer` - accelerator_type: The accelerator type, e.g. "cpu", "gpu", "ipu" + accelerator_type: The accelerator type, e.g. "cpu", "gpu" date_time_suffix: The date and time of the current run. To be used for logging. Returns: trainer: the trainer module """ cfg_trainer = deepcopy(config["trainer"]) - # Define the IPU plugin if required - strategy = cfg_trainer["trainer"].pop("strategy", "auto") - if accelerator_type == "ipu": - ipu_opts, ipu_inference_opts = _get_ipu_opts(config) - - training_opts, inference_opts = load_ipu_options( - ipu_opts=ipu_opts, - ipu_inference_opts=ipu_inference_opts, - seed=config["constants"]["seed"], - model_name=config["constants"]["name"], - gradient_accumulation=config["trainer"]["trainer"].get("accumulate_grad_batches", None), - precision=config["trainer"]["trainer"].get("precision"), - ) - - if strategy != "auto": - raise ValueError("IPUs selected, but strategy is not set to 'auto'") - - from lightning_graphcore import IPUStrategy - - strategy = IPUStrategy(training_opts=training_opts, inference_opts=inference_opts) - # Get devices devices = cfg_trainer["trainer"].pop("devices", 1) - if accelerator_type == "ipu": - devices = 1 # number of IPUs used is defined in the ipu options files - - # Remove the gradient accumulation from IPUs, since it's handled by the device - if accelerator_type == "ipu": - cfg_trainer["trainer"].pop("accumulate_grad_batches", None) # Define the early stopping parameters trainer_kwargs = {} @@ -453,14 +336,16 @@ def load_trainer( name = wandb_cfg.pop("name", "main") if len(date_time_suffix) > 0: name += f"_{date_time_suffix}" - trainer_kwargs["logger"] = WandbLogger(name=name, log_model=True, **wandb_cfg) + trainer_kwargs["logger"] = WandbLogger(name=name, **wandb_cfg) + + progress_bar_callback = ProgressBarMetrics(metrics_on_progress_bar = metrics_on_progress_bar) + callbacks.append(progress_bar_callback) - trainer_kwargs["callbacks"] = callbacks trainer = Trainer( detect_anomaly=True, - strategy=strategy, accelerator=accelerator_type, devices=devices, + callbacks=callbacks, **cfg_trainer["trainer"], **trainer_kwargs, ) @@ -631,6 +516,8 @@ def get_checkpoint_path(config: Union[omegaconf.DictConfig, Dict[str, Any]]) -> Otherwise, assume it refers to a file in the checkpointing dir. """ + from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT # Avoid circular imports with `spaces.py` + cfg_trainer = config["trainer"] path = config.get("ckpt_name_for_testing", "last.ckpt") diff --git a/graphium/config/dummy_finetuning_from_gnn.yaml b/graphium/config/dummy_finetuning_from_gnn.yaml index ca9493d30..becdada2c 100644 --- a/graphium/config/dummy_finetuning_from_gnn.yaml +++ b/graphium/config/dummy_finetuning_from_gnn.yaml @@ -55,7 +55,8 @@ finetuning: constants: seed: 42 - max_epochs: 2 + max_epochs: 5 + model_dropout: 0. accelerator: float32_matmul_precision: medium @@ -64,14 +65,14 @@ accelerator: predictor: random_seed: ${constants.seed} optim_kwargs: - lr: 4.e-5 + lr: 1.e-3 scheduler_kwargs: null - target_nan_mask: null + target_nan_mask: ignore multitask_handling: flatten # flatten, mean-per-label torch_scheduler_kwargs: module_type: WarmUpLinearLR - max_num_epochs: 2 + max_num_epochs: 4 warmup_epochs: 1 verbose: False @@ -84,22 +85,22 @@ metrics: lipophilicity_astrazeneca: - name: mae metric: mae - target_nan_mask: null + target_nan_mask: ignore multitask_handling: flatten threshold_kwargs: null - name: spearman metric: spearmanr threshold_kwargs: null - target_nan_mask: null + target_nan_mask: ignore multitask_handling: mean-per-label - name: pearson metric: pearsonr threshold_kwargs: null - target_nan_mask: null + target_nan_mask: ignore multitask_handling: mean-per-label - name: r2_score metric: r2_score - target_nan_mask: null + target_nan_mask: ignore multitask_handling: mean-per-label threshold_kwargs: null @@ -107,7 +108,7 @@ trainer: seed: ${constants.seed} trainer: precision: 32 - max_epochs: 2 + max_epochs: 5 min_epochs: 1 check_val_every_n_epoch: 1 accumulate_grad_batches: 1 @@ -120,18 +121,14 @@ datamodule: ### FROM FINETUNING ### - module_type: "ADMETBenchmarkDataModule" + module_type: "TDCBenchmarkDataModule" args: + processed_graph_data_path: datacache/processed_graph_data/dummy_finetuning_from_gnn # TDC specific tdc_benchmark_names: [lipophilicity_astrazeneca] tdc_train_val_seed: ${constants.seed} - - batch_size_training: 200 - batch_size_inference: 200 - featurization_n_jobs: 0 + batch_size_training: 20 + batch_size_inference: 20 num_workers: 0 - prepare_dict_or_graph: pyg:graph - featurization_progress: True - featurization_backend: "loky" persistent_workers: False \ No newline at end of file diff --git a/graphium/config/dummy_finetuning_from_task_head.yaml b/graphium/config/dummy_finetuning_from_task_head.yaml index 2682ccee3..77ee6852b 100644 --- a/graphium/config/dummy_finetuning_from_task_head.yaml +++ b/graphium/config/dummy_finetuning_from_task_head.yaml @@ -61,7 +61,8 @@ finetuning: constants: seed: 42 - max_epochs: 2 + max_epochs: 5 + model_dropout: 0. accelerator: float32_matmul_precision: medium @@ -70,14 +71,14 @@ accelerator: predictor: random_seed: ${constants.seed} optim_kwargs: - lr: 4.e-5 + lr: 1.e-3 scheduler_kwargs: null target_nan_mask: null multitask_handling: flatten # flatten, mean-per-label torch_scheduler_kwargs: module_type: WarmUpLinearLR - max_num_epochs: 2 + max_num_epochs: 5 warmup_epochs: 1 verbose: False @@ -113,7 +114,7 @@ trainer: seed: ${constants.seed} trainer: precision: 32 - max_epochs: 2 + max_epochs: 4 min_epochs: 1 check_val_every_n_epoch: 1 accumulate_grad_batches: 1 @@ -126,20 +127,16 @@ datamodule: ### FROM FINETUNING ### - module_type: "ADMETBenchmarkDataModule" + module_type: "TDCBenchmarkDataModule" args: + processed_graph_data_path: datacache/processed_graph_data/dummy_finetuning_task_head # TDC specific tdc_benchmark_names: [lipophilicity_astrazeneca] tdc_train_val_seed: ${constants.seed} - - batch_size_training: 200 - batch_size_inference: 200 - featurization_n_jobs: 0 + batch_size_training: 20 + batch_size_inference: 20 num_workers: 0 - prepare_dict_or_graph: pyg:graph - featurization_progress: True - featurization_backend: "loky" persistent_workers: False diff --git a/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml b/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml index 044a0129c..a34399dd1 100644 --- a/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml +++ b/graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml @@ -45,8 +45,6 @@ datamodule: weights_col: null # This may not always be provided # Featurization - featurization_n_jobs: 16 - featurization_progress: True featurization: atom_property_list_onehot: ["atomic-number", "degree"] atom_property_list_float: [] diff --git a/graphium/config/fake_multilevel_multitask_pyg.yaml b/graphium/config/fake_multilevel_multitask_pyg.yaml index 918807cb4..3cce7b5e2 100644 --- a/graphium/config/fake_multilevel_multitask_pyg.yaml +++ b/graphium/config/fake_multilevel_multitask_pyg.yaml @@ -45,8 +45,6 @@ datamodule: weights_col: null # This may not always be provided # Featurization - featurization_n_jobs: 16 - featurization_progress: True featurization: atom_property_list_onehot: ["atomic-number", "degree"] atom_property_list_float: [] diff --git a/graphium/config/loc-config_largemix.yaml b/graphium/config/loc-config_largemix.yaml new file mode 100644 index 000000000..2c96e39ba --- /dev/null +++ b/graphium/config/loc-config_largemix.yaml @@ -0,0 +1,424 @@ +constants: + wandb: + entity: valencelabs + project: graphium3.0 + name: ${constants.scale}/mpnn/large-no_l1000 + tags: + - mpnn + - large + - no-l1000 + - ${constants.scale} + data_dir: /home/domix/Gitx/graphium/graphium/data/largemix + datacache_path: /home/domix/Gitx/graphium/datacache/largemix + scale: 10M + max_epochs: 50 + name: scale_mpnn + raise_train_error: true + seed: 42 + variants: + 1M: + mup_scale_factor: 0.27 + epochs: 50 + batch_size: 1024 + accumulate_grad_batches: 1 + depth: 16 + # mup_base_path: /rxrx/data/user/frederik.wenkel/outgoing/mup/large-no_l1000/mpnn.yaml + train_frac: 1.0 + 3M: + mup_scale_factor: 0.505 + epochs: 50 + batch_size: 1024 + accumulate_grad_batches: 1 + depth: 16 + # mup_base_path: /rxrx/data/user/frederik.wenkel/outgoing/mup/large-no_l1000/mpnn.yaml + train_frac: 1.0 + 10M: + mup_scale_factor: null + epochs: 50 + batch_size: 1024 + accumulate_grad_batches: 1 + depth: 16 + mup_base_path: null + train_frac: 1.0 + 30M: + mup_scale_factor: 1.798 + epochs: 50 + batch_size: 1024 + accumulate_grad_batches: 1 + depth: 16 + # mup_base_path: /rxrx/data/user/frederik.wenkel/outgoing/mup/large-no_l1000/mpnn.yaml + train_frac: 1.0 + 100M: + mup_scale_factor: 3.38 + epochs: 30 + batch_size: 512 + accumulate_grad_batches: 2 + depth: 16 + # mup_base_path: /rxrx/data/user/frederik.wenkel/outgoing/mup/large-no_l1000/mpnn.yaml + train_frac: 1.0 + 300M: + mup_scale_factor: 5.91 + epochs: 30 + batch_size: 256 + accumulate_grad_batches: 1 + depth: 16 + # mup_base_path: /rxrx/data/user/frederik.wenkel/outgoing/mup/large-no_l1000/mpnn.yaml + train_frac: 1.0 + 1B: + mup_scale_factor: 11.0 + epochs: 20 + batch_size: 256 + accumulate_grad_batches: 1 + depth: 16 + # mup_base_path: /rxrx/data/user/frederik.wenkel/outgoing/mup/large-no_l1000/mpnn.yaml + train_frac: 1.0 + 3B: + mup_scale_factor: 18.8 + epochs: 20 + batch_size: 128 + accumulate_grad_batches: 1 + depth: 16 + # mup_base_path: /rxrx/data/user/frederik.wenkel/outgoing/mup/large-no_l1000/mpnn.yaml + train_frac: 1.0 + 125Mol: + mup_scale_factor: 3.38 + epochs: 30 + batch_size: 512 + accumulate_grad_batches: 2 + depth: 16 + # mup_base_path: /rxrx/data/user/frederik.wenkel/outgoing/mup/large-no_l1000/mpnn.yaml + train_frac: 0.125 + 250Mol: + mup_scale_factor: 3.38 + epochs: 30 + batch_size: 512 + accumulate_grad_batches: 2 + depth: 16 + # mup_base_path: /rxrx/data/user/frederik.wenkel/outgoing/mup/large-no_l1000/mpnn.yaml + train_frac: 0.25 + 500Mol: + mup_scale_factor: 3.38 + epochs: 30 + batch_size: 512 + accumulate_grad_batches: 2 + depth: 16 + # mup_base_path: /rxrx/data/user/frederik.wenkel/outgoing/mup/large-no_l1000/mpnn.yaml + train_frac: 0.5 + dataset_fraction: 1.0 +accelerator: + float32_matmul_precision: medium + type: gpu +architecture: + mup_scale_factor: ${constants.variants.${constants.scale}.mup_scale_factor} + mup_base_path: ${constants.variants.${constants.scale}.mup_base_path} + pre_nn: + activation: relu + depth: 2 + dropout: 0.1 + hidden_dims: 4 + last_activation: none + last_normalization: layer_norm + normalization: layer_norm + out_dim: 64 + residual_type: none + pre_nn_edges: + activation: relu + depth: 2 + dropout: 0.1 + hidden_dims: 4 + last_activation: none + last_normalization: layer_norm + normalization: layer_norm + out_dim: 32 + residual_type: none + gnn: + activation: gelu + depth: ${constants.variants.${constants.scale}.depth} + dropout: 0.1 + hidden_dims: 4 + hidden_dims_edges: 2 + out_dim_edges: 2 + in_dim: 64 + last_activation: none + last_normalization: layer_norm + layer_kwargs: + mlp_expansion_ratio: 1 + layer_type: pyg:mpnnplus + normalization: layer_norm + out_dim: 4 + residual_type: simple + virtual_node: none + graph_output_nn: + graph: + activation: relu + depth: 2 + dropout: 0.1 + hidden_dims: 4 + last_activation: none + last_normalization: none + normalization: layer_norm + out_dim: 653 + pooling: + - sum + residual_type: none + node: + activation: relu + depth: 2 + dropout: 0.1 + hidden_dims: 4 + last_activation: none + last_normalization: none + normalization: layer_norm + out_dim: 84 + pooling: + - sum + residual_type: none + model_type: FullGraphMultiTaskNetwork + pe_encoders: + encoders: + la_pos: + dropout: 0.1 + encoder_type: laplacian_pe + first_normalization: none + hidden_dim: 2 + input_keys: + - laplacian_eigvec + - laplacian_eigval + model_type: DeepSet + num_layers: 2 + num_layers_post: 1 + out_dim: 32 + output_keys: + - feat + rw_pos: + dropout: 0.1 + encoder_type: mlp + first_normalization: layer_norm + hidden_dim: 2 + input_keys: + - rw_return_probs + normalization: layer_norm + num_layers: 2 + out_dim: 32 + output_keys: + - feat + last_norm: None + out_dim: 32 + pool: sum + task_heads: + pcba_1328: + activation: relu + depth: 2 + dropout: 0.1 + hidden_dims: 4 + last_activation: none + last_normalization: none + normalization: layer_norm + out_dim: 448 + residual_type: none + task_level: graph + pcqm4m_g25: + activation: relu + depth: 2 + dropout: 0.1 + hidden_dims: 4 + last_activation: none + last_normalization: none + normalization: layer_norm + out_dim: 25 + residual_type: none + task_level: graph + pcqm4m_n4: + activation: relu + depth: 2 + dropout: 0.1 + hidden_dims: 4 + last_activation: none + last_normalization: none + normalization: layer_norm + out_dim: 4 + residual_type: none + task_level: node +datamodule: + args: + batch_size_inference: 1024 + batch_size_training: ${constants.variants.${constants.scale}.batch_size} + train_frac: ${constants.variants.${constants.scale}.train_frac} + featurization: + add_self_loop: false + atom_property_list_float: + - degree + - formal-charge + - radical-electron + - aromatic + - in-ring + atom_property_list_onehot: + - atomic-number + - group + - period + - total-valence + edge_property_list: + - bond-type-onehot + - stereo + - in-ring + explicit_H: false + max_num_atoms: 50 + pos_encoding_as_features: + pos_types: + lap_eigval: + disconnected_comp: true + normalization: none + num_pos: 8 + pos_level: node + pos_type: laplacian_eigval + lap_eigvec: + disconnected_comp: true + normalization: none + num_pos: 8 + pos_level: node + pos_type: laplacian_eigvec + rw_pos: + ksteps: 16 + pos_level: node + pos_type: rw_return_probs + use_bonds_weights: false + multiprocessing_context: spawn + num_workers: 4 + persistent_workers: true + processed_graph_data_path: ${constants.datacache_path} + task_specific_args: + pcba_1328: + df: null + df_path: ${constants.data_dir}/PCBA_1328_1564k.parquet + epoch_sampling_fraction: 1 + sample_size: ${constants.dataset_fraction} # use sample_size for test + label_cols: 'assayID-1*' + smiles_col: SMILES + splits_path: ${constants.data_dir}/pcba_1328_random_splits.pt + task_level: graph + pcqm4m_g25: + df: null + df_path: ${constants.data_dir}/PCQM4M_G25_N4.parquet + epoch_sampling_fraction: 1 + sample_size: ${constants.dataset_fraction} # use sample_size for test + label_cols: graph_* + label_normalization: + method: normal + normalize_val_test: true + smiles_col: ordered_smiles + splits_path: ${constants.data_dir}/pcqm4m_g25_n4_random_splits.pt + task_level: graph + pcqm4m_n4: + df: null + df_path: ${constants.data_dir}/PCQM4M_G25_N4.parquet + epoch_sampling_fraction: 1 + sample_size: ${constants.dataset_fraction} # use sample_size for test + label_cols: node_* + label_normalization: + method: normal + normalize_val_test: true + seed: 42 + smiles_col: ordered_smiles + splits_path: ${constants.data_dir}/pcqm4m_g25_n4_random_splits.pt + task_level: node + module_type: MultitaskFromSmilesDataModule +metrics: + # pcba_1328: [] + # pcqm4m_g25: [] + # pcqm4m_n4: [] + pcba_1328: + - metric: auroc + multitask_handling: mean-per-label + name: auroc + target_nan_mask: ignore + task: binary + threshold_kwargs: null + target_to_int: true + - metric: averageprecision + multitask_handling: mean-per-label + name: avpr + target_nan_mask: 0 + task: binary + threshold_kwargs: null + target_to_int: true + pcqm4m_g25: + - metric: mae + multitask_handling: mean-per-label + name: mae + target_nan_mask: ignore + threshold_kwargs: null + - metric: pearsonr + multitask_handling: mean-per-label + name: pearsonr + target_nan_mask: ignore + threshold_kwargs: null + - metric: r2_score + multitask_handling: mean-per-label + name: r2 + target_nan_mask: ignore + threshold_kwargs: null + pcqm4m_n4: + - metric: mae + multitask_handling: mean-per-label + name: mae + target_nan_mask: ignore + threshold_kwargs: null + - metric: pearsonr + multitask_handling: mean-per-label + name: pearsonr + target_nan_mask: ignore + threshold_kwargs: null + - metric: r2_score + multitask_handling: mean-per-label + name: r2 + target_nan_mask: ignore + threshold_kwargs: null +predictor: + loss_fun: + pcba_1328: bce_logits + pcqm4m_g25: mae + pcqm4m_n4: mae + metrics_every_n_train_steps: 5 + metrics_on_progress_bar: + pcba_1328: [] + pcqm4m_g25: [] + pcqm4m_n4: [] + metrics_on_training_set: + pcba_1328: [] + pcqm4m_g25: [pearsonr] + pcqm4m_n4: [pearsonr] + multitask_handling: flatten + optim_kwargs: + lr: ${eval:"0.003/(((${architecture.gnn.depth}+8)/24)**0.5)"} + random_seed: 42 + scheduler_kwargs: null + target_nan_mask: ignore + torch_scheduler_kwargs: + max_num_epochs: ${constants.max_epochs} + module_type: WarmUpLinearLR + verbose: false + warmup_epochs: 5 +tasks: {} +trainer: + logger: + name: ${constants.scale}/mpnn/large-no_l1000 + project: molgps-pretraining + save_dir: logs/molgps-pretraining/large-no_l1000/ + model_checkpoint: + save_last: false + save_top_k: -1 + dirpath: model_checkpoints/graphium3/large-no_l1000/${constants.scale}/mpnn/${constants.seed}/ + every_n_epochs: 5 + filename: '{epoch}' + seed: ${constants.seed} + trainer: + check_val_every_n_epoch: 1 + max_epochs: 50 + min_epochs: 1 + precision: 32 + accumulate_grad_batches: ${constants.variants.${constants.scale}.accumulate_grad_batches} + num_sanity_val_steps: 2 + devices: 1 + strategy: ddp_find_unused_parameters_true + limit_train_batches: 20 + limit_val_batches: 20 diff --git a/graphium/config/zinc_default_multitask_pyg.yaml b/graphium/config/zinc_default_multitask_pyg.yaml index b9435ec7e..01d20bc53 100644 --- a/graphium/config/zinc_default_multitask_pyg.yaml +++ b/graphium/config/zinc_default_multitask_pyg.yaml @@ -45,8 +45,6 @@ datamodule: weights_type: null # Featurization - featurization_n_jobs: 16 - featurization_progress: True featurization: atom_property_list_onehot: ["atomic-number", "degree"] atom_property_list_float: [] diff --git a/graphium/data/__init__.py b/graphium/data/__init__.py index 0a8fcd24d..d6becddf3 100644 --- a/graphium/data/__init__.py +++ b/graphium/data/__init__.py @@ -5,9 +5,6 @@ from .datamodule import GraphOGBDataModule from .datamodule import MultitaskFromSmilesDataModule -from .datamodule import ADMETBenchmarkDataModule -from .datamodule import FakeDataModule +from .datamodule import TDCBenchmarkDataModule -from .dataset import SingleTaskDataset from .dataset import MultitaskDataset -from .dataset import FakeDataset diff --git a/graphium/data/collate.py b/graphium/data/collate.py index 22486b034..0c1f6ef44 100644 --- a/graphium/data/collate.py +++ b/graphium/data/collate.py @@ -1,12 +1,12 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ @@ -22,19 +22,17 @@ from typing import Union, List, Optional, Dict, Type, Any, Iterable from torch_geometric.data import Data, Batch -from graphium.features import GraphDict, to_dense_array -from graphium.utils.packing import fast_packing, get_pack_sizes, node_to_pack_indices_mask from loguru import logger from graphium.data.utils import get_keys +from graphium.data.dataset import torch_enum_to_dtype def graphium_collate_fn( elements: Union[List[Any], Dict[str, List[Any]]], - labels_size_dict: Optional[Dict[str, Any]] = None, + labels_num_cols_dict: Optional[Dict[str, Any]] = None, labels_dtype_dict: Optional[Dict[str, Any]] = None, mask_nan: Union[str, float, Type[None]] = "raise", do_not_collate_keys: List[str] = [], - batch_size_per_pack: Optional[int] = None, ) -> Union[Any, Dict[str, Any]]: """This collate function is identical to the default pytorch collate function but add support for `pyg.data.Data` to batch graphs. @@ -52,7 +50,7 @@ def graphium_collate_fn( elements: The elements to batch. See `torch.utils.data.dataloader.default_collate`. - labels_size_dict: + labels_num_cols_dict: (Note): This is an attribute of the `MultitaskDataset`. A dictionary of the form Dict[tasks, sizes] which has task names as keys and the size of the label tensor as value. The size of the tensor corresponds to how many @@ -76,35 +74,38 @@ def graphium_collate_fn( do_not_batch_keys: Keys to ignore for the collate - batch_size_per_pack: The number of graphs to pack together. - This is useful for using packing with the Transformer. - If None, no packing is done. - Otherwise, indices are generated to map the nodes to the pack they belong to under the key `"pack_from_node_idx"`, - with an additional mask to indicate which nodes are from the same graph under the key `"pack_attn_mask"`. - Returns: The batched elements. See `torch.utils.data.dataloader.default_collate`. """ + # Skip any elements that failed + if None in elements: + elements = [e for e in elements if e is not None] + elem = elements[0] if isinstance(elem, Mapping): + if "features" in elem: + num_nodes = [d["features"].num_nodes for d in elements] + num_edges = [d["features"].num_edges for d in elements] + else: + num_nodes = [d["num_nodes"] for d in elements] + num_edges = [d["num_edges"] for d in elements] + batch = {} for key in elem: # Multitask setting: We have to pad the missing labels if key == "labels": labels = [d[key] for d in elements] - batch[key] = collate_labels(labels, labels_size_dict, labels_dtype_dict) - - # If the features are a dictionary containing GraphDict elements, - # Convert to pyg graphs and use the pyg batching. - elif isinstance(elem[key], GraphDict): - pyg_graphs = [d[key].make_pyg_graph(mask_nan=mask_nan) for d in elements] - batch[key] = collage_pyg_graph(pyg_graphs) + batch[key] = collate_labels( + labels, labels_num_cols_dict, labels_dtype_dict, num_nodes, num_edges + ) + elif key == "num_nodes" or key == "num_edges": + continue # If a PyG Graph is provided, use the PyG batching elif isinstance(elem[key], Data): pyg_graphs = [d[key] for d in elements] - batch[key] = collage_pyg_graph(pyg_graphs, batch_size_per_pack=batch_size_per_pack) + batch[key] = collage_pyg_graph(pyg_graphs, num_nodes) # Ignore the collate for specific keys elif key in do_not_collate_keys: @@ -125,132 +126,86 @@ def graphium_collate_fn( return default_collate(elements) -def collage_pyg_graph(pyg_graphs: Iterable[Union[Data, Dict]], batch_size_per_pack: Optional[int] = None): +def collage_pyg_graph( + pyg_graphs: List[Data], num_nodes: List[int], +): """ Function to collate pytorch geometric graphs. Convert all numpy types to torch Convert edge indices to int64 Parameters: - pyg_graphs: Iterable of PyG graphs - batch_size_per_pack: The number of graphs to pack together. - This is useful for using packing with the Transformer, + pyg_graphs: List of PyG graphs """ # Calculate maximum number of nodes per graph in current batch - num_nodes_list = [] - for pyg_graph in pyg_graphs: - num_nodes_list.append(pyg_graph["num_nodes"]) - max_num_nodes_per_graph = max(num_nodes_list) + max_num_nodes_per_graph = max(num_nodes) - pyg_batch = [] for pyg_graph in pyg_graphs: for pyg_key in get_keys(pyg_graph): - tensor = pyg_graph[pyg_key] - - # Convert numpy/scipy to Pytorch - if isinstance(tensor, (ndarray, spmatrix)): - tensor = torch.as_tensor(to_dense_array(tensor, tensor.dtype)) - # pad nodepair-level positional encodings if pyg_key.startswith("nodepair_"): - pyg_graph[pyg_key] = pad_nodepairs(tensor, pyg_graph["num_nodes"], max_num_nodes_per_graph) - else: - pyg_graph[pyg_key] = tensor + pyg_graph[pyg_key] = pad_nodepairs( + pyg_graph[pyg_key], pyg_graph.num_nodes, max_num_nodes_per_graph + ) # Convert edge index to int64 pyg_graph.edge_index = pyg_graph.edge_index.to(torch.int64) - pyg_batch.append(pyg_graph) - - # Apply the packing at the mini-batch level. This is useful for using packing with the Transformer, - # especially in the case of the large graphs being much larger than the small graphs. - # CAREFUL!!! This changes the order of the graphs in the batch, without changing the order of the labels or other objects. - # An error is raised temporarily. - if batch_size_per_pack is not None: - raise NotImplementedError( - "Packing is not yet functional, as it changes the order of the graphs in the batch without changing the label order" - ) - num_nodes = [g.num_nodes for g in pyg_batch] - packed_graph_idx = fast_packing(num_nodes, batch_size_per_pack) - - # Get the node to pack indices and the mask - pack_from_node_idx, pack_attn_mask = node_to_pack_indices_mask(packed_graph_idx, num_nodes) - for pyg_graph in pyg_batch: - pyg_graph.pack_from_node_idx = pack_from_node_idx - pyg_graph.pack_attn_mask = pack_attn_mask - return Batch.from_data_list(pyg_batch) + return Batch.from_data_list(pyg_graphs) -def pad_to_expected_label_size(labels: torch.Tensor, label_size: List[int]): +def pad_to_expected_label_size(labels: torch.Tensor, label_rows: int, label_cols: int): """Determine difference of ``labels`` shape to expected shape `label_size` and pad with ``torch.nan`` accordingly. """ - if label_size == list(labels.shape): + if len(labels.shape) == 2 and label_rows == labels.shape[0] and label_cols == labels.shape[1]: return labels - missing_dims = len(label_size) - len(labels.shape) + missing_dims = 2 - len(labels.shape) for _ in range(missing_dims): labels.unsqueeze(-1) - pad_sizes = [(0, expected - actual) for expected, actual in zip(label_size, labels.shape)] - pad_sizes = [item for before_after in pad_sizes for item in before_after] - pad_sizes.reverse() + pad_sizes = [label_cols - labels.shape[1], 0, label_rows - labels.shape[0], 0] if any([s < 0 for s in pad_sizes]): - logger.warning(f"More labels available than expected. Will remove data to fit expected size.") + logger.warning( + f"More labels available than expected. Will remove data to fit expected size. cols: {labels.shape[1]}->{label_cols}, rows: {labels.shape[0]}->{label_rows}" + ) return torch.nn.functional.pad(labels, pad_sizes, value=torch.nan) -def collate_pyg_graph_labels(pyg_labels: List[Data]): - """ - Function to collate pytorch geometric labels. - Convert all numpy types to torch - - Parameters: - pyg_labels: Iterable of PyG label Data objects - """ - pyg_batch = [] - for pyg_label in pyg_labels: - for pyg_key in set(get_keys(pyg_label)) - set(["x", "edge_index"]): - tensor = pyg_label[pyg_key] - # Convert numpy/scipy to Pytorch - if isinstance(tensor, (ndarray, spmatrix)): - tensor = torch.as_tensor(to_dense_array(tensor, tensor.dtype)) - - pyg_label[pyg_key] = tensor - - pyg_batch.append(pyg_label) - - return Batch.from_data_list(pyg_batch) - - -def get_expected_label_size(label_data: Data, task: str, label_size: List[int]): +def get_expected_label_rows(label_data: Data, task: str, num_nodes: int, num_edges: int): """Determines expected label size based on the specfic graph properties and the number of targets in the task-dataset. """ if task.startswith("graph_"): num_labels = 1 elif task.startswith("node_"): - num_labels = label_data.x.size(0) + num_labels = num_nodes elif task.startswith("edge_"): - num_labels = label_data.edge_index.size(1) + num_labels = num_edges elif task.startswith("nodepair_"): raise NotImplementedError() - return [num_labels] + label_size + else: + print("Task name " + task + " in get_expected_label_rows") + raise NotImplementedError() + return num_labels def collate_labels( labels: List[Data], - labels_size_dict: Optional[Dict[str, Any]] = None, + labels_num_cols_dict: Optional[Dict[str, Any]] = None, labels_dtype_dict: Optional[Dict[str, Any]] = None, + num_nodes: List[int] = None, + num_edges: List[int] = None, ): """Collate labels for multitask learning. Parameters: labels: List of labels - labels_size_dict: Dict of the form Dict[tasks, sizes] which has task names as keys + labels_num_cols_dict: Dict of the form Dict[tasks, sizes] which has task names as keys and the size of the label tensor as value. The size of the tensor corresponds to how many labels/values there are to predict for that task. labels_dtype_dict: @@ -260,25 +215,21 @@ def collate_labels( Returns: A dictionary of the form Dict[tasks, labels] where tasks is the name of the task and labels - is a tensor of shape (batch_size, *labels_size_dict[task]). + is a tensor of shape (batch_size, *labels_num_cols_dict[task]). """ - if labels_size_dict is not None: - for this_label in labels: - for task in labels_size_dict.keys(): - labels_size_dict[task] = list(labels_size_dict[task]) - if len(labels_size_dict[task]) >= 2: - labels_size_dict[task] = labels_size_dict[task][1:] - elif not task.startswith("graph_"): - labels_size_dict[task] = [1] + if labels_num_cols_dict is not None: + for index, this_label in enumerate(labels): label_keys_set = set(get_keys(this_label)) - empty_task_labels = set(labels_size_dict.keys()) - label_keys_set + empty_task_labels = set(labels_num_cols_dict.keys()) - label_keys_set for task in empty_task_labels: - labels_size_dict[task] = get_expected_label_size(this_label, task, labels_size_dict[task]) - dtype = labels_dtype_dict[task] - this_label[task] = torch.full([*labels_size_dict[task]], torch.nan, dtype=dtype) + label_rows = get_expected_label_rows(this_label, task, num_nodes[index], num_edges[index]) + dtype = torch_enum_to_dtype(labels_dtype_dict[task]) + this_label[task] = torch.full( + (label_rows, labels_num_cols_dict[task]), fill_value=torch.nan, dtype=dtype + ) for task in label_keys_set - set(["x", "edge_index"]) - empty_task_labels: - labels_size_dict[task] = get_expected_label_size(this_label, task, labels_size_dict[task]) + label_rows = get_expected_label_rows(this_label, task, num_nodes[index], num_edges[index]) if not isinstance(this_label[task], (torch.Tensor)): this_label[task] = torch.as_tensor(this_label[task]) @@ -286,21 +237,23 @@ def collate_labels( # Ensure explicit task dimension also for single task labels if len(this_label[task].shape) == 1: # Distinguish whether target dim or entity dim is missing - if labels_size_dict[task][0] == this_label[task].shape[0]: + if label_rows == this_label[task].shape[0]: # num graphs/nodes/edges/nodepairs already matching this_label[task] = this_label[task].unsqueeze(1) else: # data lost unless entity dim is supposed to be 1 - if labels_size_dict[task][0] == 1: + if label_rows == 1: this_label[task] = this_label[task].unsqueeze(0) else: raise ValueError( - f"Labels for {labels_size_dict[task][0]} nodes/edges/nodepairs expected, got 1." + f"Labels for {label_rows} nodes/edges/nodepairs expected, got 1." ) - this_label[task] = pad_to_expected_label_size(this_label[task], labels_size_dict[task]) + this_label[task] = pad_to_expected_label_size( + this_label[task], label_rows, labels_num_cols_dict[task] + ) - return collate_pyg_graph_labels(labels) + return Batch.from_data_list(labels) def pad_nodepairs(pe: torch.Tensor, num_nodes: int, max_num_nodes_per_graph: int): diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index 4e89f6728..04f427b0d 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -1,17 +1,16 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ - import tempfile from contextlib import redirect_stderr, redirect_stdout from typing import Type, List, Dict, Union, Any, Callable, Optional, Tuple, Iterable, Literal @@ -51,27 +50,22 @@ from torch.utils.data.dataloader import DataLoader, Dataset from torch.utils.data import Subset +from rdkit import RDLogger + from graphium.utils import fs -from graphium.features import ( - mol_to_graph_dict, - GraphDict, - mol_to_pyggraph, -) +from graphium.features import mol_to_pyggraph from graphium.data.sampler import DatasetSubSampler from graphium.data.utils import graphium_package_path, found_size_mismatch from graphium.utils.arg_checker import check_arg_iterator from graphium.utils.hashing import get_md5_hash -from graphium.data.smiles_transform import ( - did_featurization_fail, - BatchingSmilesTransform, - smiles_to_unique_mol_ids, -) from graphium.data.collate import graphium_collate_fn import graphium.data.dataset as Datasets from graphium.data.normalization import LabelNormalization from graphium.data.multilevel_utils import extract_labels +import graphium_cpp + torch.multiprocessing.set_sharing_strategy("file_system") @@ -107,7 +101,6 @@ def __init__( self, batch_size_training: int = 16, batch_size_inference: int = 16, - batch_size_per_pack: Optional[int] = None, num_workers: int = 0, pin_memory: bool = True, persistent_workers: bool = False, @@ -130,15 +123,6 @@ def __init__( self.batch_size_training = batch_size_training self.batch_size_inference = batch_size_inference - self.batch_size_per_pack = batch_size_per_pack - if self.batch_size_per_pack is not None: - # Check that batch_size_per_pack is a divisor of batch_size_training and batch_size_inference - assert ( - self.batch_size_training % self.batch_size_per_pack == 0 - ), f"batch_size_training must be a multiple of batch_size_per_pack, provided batch_size_training={self.batch_size_training}, batch_size_per_pack={self.batch_size_per_pack}" - assert ( - self.batch_size_inference % self.batch_size_per_pack == 0 - ), f"batch_size_inference must be a multiple of batch_size_per_pack, provided batch_size_inference={self.batch_size_inference}, batch_size_per_pack={self.batch_size_per_pack}" self.num_workers = num_workers self.pin_memory = pin_memory @@ -153,7 +137,6 @@ def __init__( self._predict_ds = None self._data_is_prepared = False - self._data_is_cached = False def prepare_data(self): raise NotImplementedError() @@ -209,7 +192,7 @@ def get_collate_fn(self, collate_fn): if collate_fn is None: # Some values become `inf` when changing data type. `mask_nan` deals with that collate_fn = partial( - graphium_collate_fn, mask_nan=0, batch_size_per_pack=self.batch_size_per_pack + graphium_collate_fn, mask_nan=0, ) collate_fn.__name__ = graphium_collate_fn.__name__ @@ -498,12 +481,12 @@ def get_dataloader_kwargs(self, stage: RunningStage, shuffle: bool, **kwargs) -> """ loader_kwargs = {} - # Get batch size and IPU options for training set + # Get batch size for training set # if stage in [RunningStage.TRAINING, RunningStage.TUNING]: if stage in [RunningStage.TRAINING]: loader_kwargs["batch_size"] = self.batch_size_training - # Get batch size and IPU options for validation / testing sets + # Get batch size for validation / testing sets elif stage in [RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]: loader_kwargs["batch_size"] = self.batch_size_inference else: @@ -678,6 +661,7 @@ def __init__( idx_col: Optional[str] = None, mol_ids_col: Optional[str] = None, sample_size: Union[int, float, Type[None]] = None, + split_type: str = "random", split_val: float = 0.2, split_test: float = 0.2, seed: int = None, @@ -720,6 +704,7 @@ def __init__( self.idx_col = idx_col self.mol_ids_col = mol_ids_col self.sample_size = sample_size + self.split_type = split_type self.split_val = split_val self.split_test = split_test self.seed = seed @@ -729,83 +714,20 @@ def __init__( self.epoch_sampling_fraction = epoch_sampling_fraction -class IPUDataModuleModifier: - def __init__( - self, - ipu_inference_opts: Optional["poptorch.Options"] = None, - ipu_training_opts: Optional["poptorch.Options"] = None, - ipu_dataloader_training_opts: Optional["IPUDataloaderOptions"] = None, - ipu_dataloader_inference_opts: Optional["IPUDataloaderOptions"] = None, - *args, - **kwargs, - ) -> None: - r""" - wrapper functions from the a `DataModule` to support IPU and IPU options To be used in dual inheritance, for example: - ``` - IPUDataModule(BaseDataModule, IPUDataModuleModifier): - def __init__(self, **kwargs): - BaseDataModule.__init__(self, **kwargs) - IPUDataModuleModifier.__init__(self, **kwargs) - ``` - - Parameters: - ipu_inference_opts: Options for the IPU in inference mode. Ignore if not using IPUs - ipu_training_opts: Options for the IPU in training mode. Ignore if not using IPUs - ipu_dataloader_kwargs_train_val: Options for the dataloader for the IPU. Ignore if not using IPUs - ipu_dataloader_kwargs_test: Options for the dataloader for the IPU. Ignore if not using IPUs - args: Arguments for the `DataModule` - kwargs: Keyword arguments for the `DataModule` - """ - self.ipu_inference_opts = ipu_inference_opts - self.ipu_training_opts = ipu_training_opts - self.ipu_dataloader_training_opts = ipu_dataloader_training_opts - self.ipu_dataloader_inference_opts = ipu_dataloader_inference_opts - - def _dataloader(self, dataset: Dataset, **kwargs) -> "poptorch.DataLoader": - """ - Get a poptorch dataloader for a given dataset - Parameters: - dataset: The dataset to use - kwargs: Keyword arguments for the dataloader - Returns: - The poptorch dataloader - """ - - # Use regular Dataloader if no IPUs - if ("ipu_options" not in kwargs.keys()) or (kwargs["ipu_options"] is None): - raise ValueError(f"No IPU options provided.") - - # Initialize the IPU dataloader - from graphium.ipu.ipu_dataloader import create_ipu_dataloader - - loader = create_ipu_dataloader( - dataset=dataset, - **kwargs, - ) - - return loader - - -class MultitaskFromSmilesDataModule(BaseDataModule, IPUDataModuleModifier): +class MultitaskFromSmilesDataModule(BaseDataModule): def __init__( self, task_specific_args: Union[Dict[str, DatasetProcessingParams], Dict[str, Any]], - processed_graph_data_path: Optional[Union[str, os.PathLike]] = None, - dataloading_from: str = "ram", + processed_graph_data_path: Union[str, os.PathLike], featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, batch_size_training: int = 16, batch_size_inference: int = 16, - batch_size_per_pack: Optional[int] = None, num_workers: int = 0, pin_memory: bool = True, persistent_workers: bool = False, multiprocessing_context: Optional[str] = None, - featurization_n_jobs: int = -1, - featurization_progress: bool = False, - featurization_backend: str = "loky", - featurization_batch_size: int = 1000, collate_fn: Optional[Callable] = None, - prepare_dict_or_graph: str = "pyg:graph", + preprocessing_n_jobs: int = 0, **kwargs, ): """ @@ -821,47 +743,40 @@ def __init__( - `df_path` - `smiles_col` - `label_cols` - dataloading_from: Whether to load the data from RAM or from disk. If set to "disk", the data - must have been previously cached with `processed_graph_data_path` set. If set to "ram", the data - will be loaded in RAM and the `processed_graph_data_path` will be ignored. featurization: args to apply to the SMILES to Graph featurizer. batch_size_training: batch size for training and val dataset. batch_size_inference: batch size for test dataset. num_workers: Number of workers for the dataloader. Use -1 to use all available cores. pin_memory: Whether to pin on paginated CPU memory for the dataloader. - featurization_n_jobs: Number of cores to use for the featurization. - featurization_progress: whether to show a progress bar during featurization. - featurization_backend: The backend to use for the molecular featurization. - "multiprocessing": Found to cause less memory issues. - "loky": joblib's Default. Found to cause memory leaks. - "threading": Found to be slow. - featurization_batch_size: Batch size to use for the featurization. collate_fn: A custom torch collate function. Default is to `graphium.data.graphium_collate_fn` - prepare_dict_or_graph: Whether to preprocess all molecules as Graph dict or PyG graphs. - Possible options: + preprocessing_n_jobs: Number of threads to use during preprocessing. + Use 0 to use all available cores, or -1 to use all but one core. - - "pyg:dict": Process molecules as a `dict`. It's faster and requires less RAM during - pre-processing. It is slower during training with with `num_workers=0` since - pyg `Data` will be created during data-loading, but faster with large - `num_workers`, and less likely to cause memory issues with the parallelization. - - "pyg:graph": Process molecules as `pyg.data.Data`. + dataloading_from: Deprecated. Behaviour now always matches previous "disk" option. + featurization_n_jobs: Deprecated. + featurization_progress: Deprecated. + featurization_backend: Deprecated. + featurization_batch_size: Deprecated. + prepare_dict_or_graph: Deprecated. Behaviour now always matches previous "pyg:graph" option. """ BaseDataModule.__init__( self, batch_size_training=batch_size_training, batch_size_inference=batch_size_inference, - batch_size_per_pack=batch_size_per_pack, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, multiprocessing_context=multiprocessing_context, collate_fn=collate_fn, ) - IPUDataModuleModifier.__init__(self, **kwargs) + self._len = None self.task_specific_args = task_specific_args self.task_dataset_processing_params = {} @@ -878,26 +793,17 @@ def __init__( task: self.task_dataset_processing_params[task].epoch_sampling_fraction for task in self.task_dataset_processing_params.keys() } - - self.featurization_n_jobs = featurization_n_jobs - self.featurization_progress = featurization_progress - self.featurization_backend = featurization_backend - self.featurization_batch_size = featurization_batch_size + self.task_names = [task for task in self.task_dataset_processing_params.keys()] self.task_train_indices = None self.task_val_indices = None self.task_test_indices = None - self.single_task_datasets = None - self.train_singletask_datasets = None - self.val_singletask_datasets = None - self.test_singletask_datasets = None - self.train_ds = None self.val_ds = None self.test_ds = None - self._parse_caching_args(processed_graph_data_path, dataloading_from) + self._parse_caching_args(processed_graph_data_path) self.task_norms = {} @@ -906,42 +812,94 @@ def __init__( self.featurization = featurization - # Whether to transform the smiles into a pyg `Data` graph or a dictionary compatible with pyg - if prepare_dict_or_graph == "pyg:dict": - self.smiles_transformer = partial(mol_to_graph_dict, **featurization) - elif prepare_dict_or_graph == "pyg:graph": - self.smiles_transformer = partial(mol_to_pyggraph, **featurization) - else: - raise ValueError( - f"`prepare_dict_or_graph` should be either 'pyg:dict' or 'pyg:graph', Provided: `{prepare_dict_or_graph}`" + # Copy featurization for the representation used by graphium_cpp + encoded_featurization = deepcopy(featurization) + self.encoded_featurization = encoded_featurization + + def encode_feature_options(options, name, encoding_function): + if name not in options or options[name] is None: + options[name] = torch.tensor(data=[], dtype=torch.int64) + else: + options[name] = encoding_function(options[name]) + + encode_feature_options( + encoded_featurization, + "atom_property_list_onehot", + graphium_cpp.atom_onehot_feature_names_to_tensor, + ) + encode_feature_options( + encoded_featurization, "atom_property_list_float", graphium_cpp.atom_float_feature_names_to_tensor + ) + encode_feature_options( + encoded_featurization, "edge_property_list", graphium_cpp.bond_feature_names_to_tensor + ) + + if ( + "pos_encoding_as_features" in featurization + and featurization["pos_encoding_as_features"] is not None + and featurization["pos_encoding_as_features"]["pos_types"] is not None + ): + (pos_encoding_names, pos_encoding_tensor) = graphium_cpp.positional_feature_options_to_tensor( + featurization["pos_encoding_as_features"]["pos_types"] ) + else: + pos_encoding_names = [] + pos_encoding_tensor = torch.tensor(data=[], dtype=torch.int64) + encoded_featurization["pos_encoding_as_features"] = (pos_encoding_names, pos_encoding_tensor) + + explicit_H = featurization["explicit_H"] if "explicit_H" in featurization else False + add_self_loop = featurization["add_self_loop"] if "add_self_loop" in featurization else False + merge_equivalent_mols = ( + featurization["merge_equivalent_mols"] if "merge_equivalent_mols" in featurization else True + ) + + # Save these for calling graphium_cpp.prepare_and_save_data later + self.add_self_loop = add_self_loop + self.explicit_H = explicit_H + self.merge_equivalent_mols = merge_equivalent_mols + + self.preprocessing_n_jobs = preprocessing_n_jobs + + self.smiles_transformer = partial(mol_to_pyggraph, **encoded_featurization) self.data_hash = self.get_data_hash() - if self.processed_graph_data_path is not None: - if self._ready_to_load_all_from_file(): - self._data_is_prepared = True - self._data_is_cached = True + if self._ready_to_load_all_from_file(): + self._data_is_prepared = True + self._len = self._get_len_from_cached_file() - def _parse_caching_args(self, processed_graph_data_path, dataloading_from): + def _get_len_from_cached_file(self): + if self._ready_to_load_all_from_file(): + self._data_is_prepared = True + train_metadata = graphium_cpp.load_metadata_tensors( + self.processed_graph_data_path, "train", self.data_hash + ) + val_metadata = graphium_cpp.load_metadata_tensors( + self.processed_graph_data_path, "val", self.data_hash + ) + test_metadata = graphium_cpp.load_metadata_tensors( + self.processed_graph_data_path, "test", self.data_hash + ) + length = 0 + if len(train_metadata) > 0: + length += len(train_metadata[2]) + if len(val_metadata) > 0: + length += len(val_metadata[2]) + if len(test_metadata) > 0: + length += len(test_metadata[2]) + else: + raise ValueError("Data is not prepared. Please call prepare_data() first.") + return length + + def _parse_caching_args(self, processed_graph_data_path): """ Parse the caching arguments, and raise errors if the arguments are invalid. """ - # Whether to load the data from RAM or from disk - dataloading_from = dataloading_from.lower() - if dataloading_from not in ["disk", "ram"]: - raise ValueError( - f"`dataloading_from` should be either 'disk' or 'ram', Provided: `{dataloading_from}`" - ) - # If loading from disk, the path to the cached data must be provided - if dataloading_from == "disk" and processed_graph_data_path is None: - raise ValueError( - "When `dataloading_from` is 'disk', `processed_graph_data_path` must be provided." - ) + if processed_graph_data_path is None: + raise ValueError("`processed_graph_data_path` must be provided.") self.processed_graph_data_path = processed_graph_data_path - self.dataloading_from = dataloading_from def _get_task_key(self, task_level: str, task: str): task_prefix = f"{task_level}_" @@ -959,7 +917,27 @@ def get_task_levels(self): return task_level_map - def prepare_data(self, save_smiles_and_ids: bool = False): + @staticmethod + def concat_smiles_tensor_index(): + return 0 + + @staticmethod + def smiles_offsets_tensor_index(): + return 1 + + @staticmethod + def num_nodes_tensor_index(): + return 2 + + @staticmethod + def num_edges_tensor_index(): + return 3 + + @staticmethod + def data_offsets_tensor_index(): + return 4 + + def prepare_data(self): """Called only from a single process in distributed settings. Steps: - If each cache is set and exists, reload from cache and return. Otherwise, @@ -970,30 +948,54 @@ def prepare_data(self, save_smiles_and_ids: bool = False): - In the previous step, we were also able to get the unique smiles, which we use to compute the features - For each single-task dataframe and associated data (smiles, labels, etc.): - Filter out the data corresponding to molecules which failed featurization. - - Create a corresponding SingletaskDataset - - Split the SingletaskDataset according to the task-specific splits for train, val and test + - Split the dataset according to the task-specific splits for train, val and test """ - def has_atoms_after_h_removal(smiles): - # Remove all 'H' characters from the SMILES - smiles_without_h = re.sub("H", "", smiles) - # Check if any letters are remaining in the modified string - has_atoms = bool(re.search("[a-zA-Z]", smiles_without_h)) - if has_atoms == False: - logger.info(f"Removed Hydrogen molecule: {smiles}") - return has_atoms + # Don't log error messages from SMILES parsing in RDKit. + # Common error messages were: + # WARNING: not removing hydrogen atom without neighbors + # SMILES Parse Error: syntax error while parsing: restricted + # SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted' + RDLogger.DisableLog("rdApp.*") + + for task, args in self.task_dataset_processing_params.items(): + if args.label_normalization is None: + args.label_normalization = {} + label_normalization = LabelNormalization(**args.label_normalization) + self.task_norms[task] = label_normalization if self._data_is_prepared: logger.info("Data is already prepared.") - self.get_label_statistics(self.processed_graph_data_path, self.data_hash, dataset=None) + self.label_num_cols, self.label_dtypes = graphium_cpp.load_num_cols_and_dtypes( + self.processed_graph_data_path, self.data_hash + ) + self.stage_data = { + "train": graphium_cpp.load_metadata_tensors( + self.processed_graph_data_path, "train", self.data_hash + ), + "val": graphium_cpp.load_metadata_tensors( + self.processed_graph_data_path, "val", self.data_hash + ), + "test": graphium_cpp.load_metadata_tensors( + self.processed_graph_data_path, "test", self.data_hash + ), + } + if len(self.label_num_cols) > 0: + for task in self.task_dataset_processing_params.keys(): + stats = graphium_cpp.load_stats(self.processed_graph_data_path, self.data_hash, task) + if len(stats) < 4: + raise RuntimeError(f'Error loading cached stats for task "{task}"') + + self.task_norms[task].set_statistics(stats[0], stats[1], stats[2], stats[3]) return + task_dataset_args = {} + self.task_train_indices = {} + self.task_val_indices = {} + self.task_test_indices = {} + """Load all single-task dataframes.""" - task_df = {} for task, args in self.task_dataset_processing_params.items(): - if args.label_normalization is None: - args.label_normalization = {} - label_normalization = LabelNormalization(**args.label_normalization) logger.info(f"Reading data for task '{task}'") if args.df is None: # Only load the useful columns, as some datasets can be very large when loading all columns. @@ -1007,24 +1009,18 @@ def has_atoms_after_h_removal(smiles): + check_arg_iterator(args.weights_col, enforce_type=list) ) label_dtype = {col: np.float32 for col in label_cols} - task_df[task] = self._read_table(args.df_path, usecols=usecols, dtype=label_dtype) + df = self._read_table(args.df_path, usecols=usecols, dtype=label_dtype) else: label_cols = self._parse_label_cols( df=args.df, df_path=None, label_cols=args.label_cols, smiles_col=args.smiles_col ) - task_df[task] = args.df - task_df[task] = task_df[task] + df = args.df + args.label_cols = label_cols - self.task_norms[task] = label_normalization - logger.info("Done reading datasets") - """Subsample the data frames and extract the necessary data to create SingleTaskDatasets for each task (smiles, labels, extras).""" - task_dataset_args = {} - for task in task_df.keys(): - task_dataset_args[task] = {} + """Subsample the data frames and extract the necessary data for each task (smiles, labels, extras).""" - for task, df in task_df.items(): # Subsample all the dataframes sample_size = self.task_dataset_processing_params[task].sample_size df = self._sub_sample_df(df, sample_size, self.task_dataset_processing_params[task].seed) @@ -1036,7 +1032,7 @@ def has_atoms_after_h_removal(smiles): logger.info("Filtering done") # Extract smiles, labels, extras args = self.task_dataset_processing_params[task] - smiles, labels, sample_idx, extras = self._extract_smiles_labels( + smiles, labels, label_offsets, sample_idx, extras = self._extract_smiles_labels( df, task_level=args.task_level, smiles_col=args.smiles_col, @@ -1046,125 +1042,78 @@ def has_atoms_after_h_removal(smiles): weights_type=args.weights_type, ) - # Store the relevant information for each task's dataset - task_dataset_args[task]["smiles"] = smiles - task_dataset_args[task]["labels"] = labels - task_dataset_args[task]["sample_idx"] = sample_idx - task_dataset_args[task]["extras"] = extras - - """Convert SMILES to features (graphs, fingerprints, etc.) for the unique molecules found.""" - all_smiles = [] - all_tasks = [] - idx_per_task = {} - total_len = 0 - for task, dataset_args in task_dataset_args.items(): - all_smiles.extend(dataset_args["smiles"]) - num_smiles = len(dataset_args["smiles"]) - idx_per_task[task] = (total_len, total_len + num_smiles) - total_len += num_smiles - for count in range(len(dataset_args["smiles"])): - all_tasks.append(task) - # Get all unique mol ids - all_unique_mol_ids = smiles_to_unique_mol_ids( - all_smiles, - n_jobs=self.featurization_n_jobs, - featurization_batch_size=self.featurization_batch_size, - backend=self.featurization_backend, - ) - _, unique_ids_idx, unique_ids_inv = np.unique( - all_unique_mol_ids, return_index=True, return_inverse=True - ) + num_molecules = len(smiles) - smiles_to_featurize = [all_smiles[ii] for ii in unique_ids_idx] - - # Convert SMILES to features - features, _ = self._featurize_molecules(smiles_to_featurize) - - # Store the features (including Nones, which will be filtered in the next step) - for task in task_dataset_args.keys(): - task_dataset_args[task]["features"] = [] - task_dataset_args[task]["idx_none"] = [] - # Create a list of features matching up with the original smiles - all_features = [features[unique_idx] for unique_idx in unique_ids_inv] - - # Add the features to the task-specific data - for all_idx, task in enumerate(all_tasks): - task_dataset_args[task]["features"].append(all_features[all_idx]) - - """Filter data based on molecules which failed featurization. Create single task datasets as well.""" - self.single_task_datasets = {} - for task, args in task_dataset_args.items(): - # Find out which molecule failed featurization, and filter them out - idx_none = [] - for idx, (feat, labels, smiles) in enumerate( - zip(args["features"], args["labels"], args["smiles"]) - ): - if did_featurization_fail(feat) or found_size_mismatch(task, feat, labels, smiles): - idx_none.append(idx) - this_unique_ids = all_unique_mol_ids[idx_per_task[task][0] : idx_per_task[task][1]] - df, features, smiles, labels, sample_idx, extras, this_unique_ids = self._filter_none_molecules( - idx_none, - task_df[task], - args["features"], - args["smiles"], - args["labels"], - args["sample_idx"], - args["extras"], - this_unique_ids, - ) - task_dataset_args[task]["smiles"] = smiles - task_dataset_args[task]["labels"] = labels - task_dataset_args[task]["features"] = features - task_dataset_args[task]["sample_idx"] = sample_idx - task_dataset_args[task]["extras"] = extras - - # We have the necessary components to create single-task datasets. - self.single_task_datasets[task] = Datasets.SingleTaskDataset( - features=task_dataset_args[task]["features"], - labels=task_dataset_args[task]["labels"], - smiles=task_dataset_args[task]["smiles"], - unique_ids=this_unique_ids, - indices=task_dataset_args[task]["sample_idx"], - **task_dataset_args[task]["extras"], - ) + # Clear the reference to the DataFrame, so that Python can free up the memory. + df = None - """We split the data up to create train, val and test datasets""" - self.task_train_indices = {} - self.task_val_indices = {} - self.task_test_indices = {} + # Store the relevant information for each task's dataset + task_dataset_args[task] = { + "smiles": smiles, + "extras": extras, + } + if args.label_cols != 0: + task_dataset_args[task]["labels"] = labels + task_dataset_args[task]["label_offsets"] = label_offsets + + """We split the data up to create train, val and test datasets""" - for task, df in task_df.items(): train_indices, val_indices, test_indices = self._get_split_indices( - len(df), + num_molecules, split_val=self.task_dataset_processing_params[task].split_val, + split_type=self.task_dataset_processing_params[task].split_type, split_test=self.task_dataset_processing_params[task].split_test, + sample_idx=sample_idx, split_seed=self.task_dataset_processing_params[task].seed, splits_path=self.task_dataset_processing_params[task].splits_path, split_names=self.task_dataset_processing_params[task].split_names, - sample_idx=task_dataset_args[task]["sample_idx"], + # smiles and labels are already sub-sampled, so the split indices need to be + # relative to the sample, not the original. + # sample_idx=task_dataset_args[task]["sample_idx"], ) self.task_train_indices[task] = train_indices self.task_val_indices[task] = val_indices self.task_test_indices[task] = test_indices + logger.info("Done reading datasets") + + # The rest of the data preparation and caching is done in graphium_cpp.prepare_and_save_data + normalizations = { + task: self.task_dataset_processing_params[task].label_normalization + for task in self.task_dataset_processing_params.keys() + } ( - self.train_singletask_datasets, - self.val_singletask_datasets, - self.test_singletask_datasets, - ) = self.get_subsets_of_datasets( - self.single_task_datasets, self.task_train_indices, self.task_val_indices, self.task_test_indices + self.stage_data, + all_stats, + self.label_num_cols, + self.label_dtypes, + ) = graphium_cpp.prepare_and_save_data( + self.task_names, + task_dataset_args, + normalizations, + self.processed_graph_data_path, + self.data_hash, + self.task_train_indices, + self.task_val_indices, + self.task_test_indices, + self.add_self_loop, + self.explicit_H, + self.preprocessing_n_jobs, + self.merge_equivalent_mols, ) + self._len = self._get_len_from_cached_file() + + for task, stats in all_stats.items(): + if len(stats) < 4: + raise RuntimeError(f'Error loading cached stats for task "{task}"') - if self.processed_graph_data_path is not None: - self._save_data_to_files(save_smiles_and_ids) - self._data_is_cached = True + self.task_norms[task].set_statistics(stats[0], stats[1], stats[2], stats[3]) self._data_is_prepared = True def setup( self, stage: str = None, - save_smiles_and_ids: bool = False, ): """ Prepare the torch dataset. Called on every GPUs. Setting state here is ok. @@ -1174,54 +1123,49 @@ def setup( # Can possibly get rid of setup because a single dataset will have molecules exclusively in train, val or test # Produce the label sizes to update the collate function - labels_size = {} - labels_dtype = {} + label_num_cols = {} + label_dtypes = {} if stage == "fit" or stage is None: if self.train_ds is None: - self.train_ds = self._make_multitask_dataset( - self.dataloading_from, "train", save_smiles_and_ids=save_smiles_and_ids - ) + self.train_ds = self._make_multitask_dataset("train") - if self.val_ds is None: - self.val_ds = self._make_multitask_dataset( - self.dataloading_from, "val", save_smiles_and_ids=save_smiles_and_ids - ) + if self.val_ds is None and len(self.stage_data["val"]) >= self.num_edges_tensor_index(): + self.val_ds = self._make_multitask_dataset("val") logger.info(self.train_ds) - logger.info(self.val_ds) - labels_size.update( - self.train_ds.labels_size + label_num_cols.update( + dict(zip(self.train_ds.task_names, self.train_ds.label_num_cols)) ) # Make sure that all task label sizes are contained in here. Maybe do the update outside these if statements. - labels_size.update(self.val_ds.labels_size) - labels_dtype.update(self.train_ds.labels_dtype) - labels_dtype.update(self.val_ds.labels_dtype) + label_dtypes.update(dict(zip(self.train_ds.task_names, self.train_ds.label_dtypes))) + + if self.val_ds is not None: + logger.info(self.val_ds) + label_num_cols.update(dict(zip(self.val_ds.task_names, self.val_ds.label_num_cols))) + label_dtypes.update(dict(zip(self.val_ds.task_names, self.val_ds.label_dtypes))) if stage == "test" or stage is None: - if self.test_ds is None: - self.test_ds = self._make_multitask_dataset( - self.dataloading_from, "test", save_smiles_and_ids=save_smiles_and_ids - ) + if self.test_ds is None and len(self.stage_data["test"]) >= self.num_edges_tensor_index(): + self.test_ds = self._make_multitask_dataset("test") - logger.info(self.test_ds) + if self.test_ds is not None: + logger.info(self.test_ds) - labels_size.update(self.test_ds.labels_size) - labels_dtype.update(self.test_ds.labels_dtype) + label_num_cols.update(dict(zip(self.test_ds.task_names, self.test_ds.label_num_cols))) + label_dtypes.update(dict(zip(self.test_ds.task_names, self.test_ds.label_dtypes))) - default_labels_size_dict = self.collate_fn.keywords.get("labels_size_dict", None) + default_labels_num_cols_dict = self.collate_fn.keywords.get("labels_num_cols_dict", None) - if default_labels_size_dict is None: - self.collate_fn.keywords["labels_size_dict"] = labels_size + if default_labels_num_cols_dict is None: + self.collate_fn.keywords["labels_num_cols_dict"] = label_num_cols default_labels_dtype_dict = self.collate_fn.keywords.get("labels_dtype_dict", None) if default_labels_dtype_dict is None: - self.collate_fn.keywords["labels_dtype_dict"] = labels_dtype + self.collate_fn.keywords["labels_dtype_dict"] = label_dtypes def _make_multitask_dataset( self, - dataloading_from: Literal["disk", "ram"], stage: Literal["train", "val", "test"], - save_smiles_and_ids: bool, ) -> Datasets.MultitaskDataset: """ Create a MultitaskDataset for the given stage using single task datasets @@ -1229,7 +1173,6 @@ def _make_multitask_dataset( Parameters: stage: Stage to create multitask dataset for - save_smiles_and_ids: Whether to save SMILES strings and unique IDs processed_graph_data_path: path to save and load processed graph data from """ @@ -1237,41 +1180,35 @@ def _make_multitask_dataset( assert stage in allowed_stages, f"Multitask dataset stage `{stage}` not in {allowed_stages}" if stage == "train": - singletask_datasets = self.train_singletask_datasets about = "training set" elif stage == "val": - singletask_datasets = self.val_singletask_datasets about = "validation set" elif stage == "test": - singletask_datasets = self.test_singletask_datasets about = "test set" else: raise ValueError(f"Unknown stage {stage}") processed_graph_data_path = self.processed_graph_data_path + stage_data = self.stage_data[stage] + data_offsets = None + if self.data_offsets_tensor_index() < len(stage_data): + data_offsets = stage_data[self.data_offsets_tensor_index()] + multitask_dataset = Datasets.MultitaskDataset( - singletask_datasets, - n_jobs=self.featurization_n_jobs, - backend=self.featurization_backend, - featurization_batch_size=self.featurization_batch_size, - progress=self.featurization_progress, about=about, - save_smiles_and_ids=save_smiles_and_ids, data_path=self._path_to_load_from_file(stage) if processed_graph_data_path else None, - dataloading_from=dataloading_from, - data_is_cached=self._data_is_cached, + featurize_smiles=self.smiles_transformer, + task_names=self.task_names, + label_num_cols=self.label_num_cols, + label_dtypes=self.label_dtypes, + mol_file_data_offsets=data_offsets, + concat_smiles_tensor=stage_data[self.concat_smiles_tensor_index()], + smiles_offsets_tensor=stage_data[self.smiles_offsets_tensor_index()], + num_nodes_tensor=stage_data[self.num_nodes_tensor_index()], + num_edges_tensor=stage_data[self.num_edges_tensor_index()], ) # type: ignore - # calculate statistics for the train split and used for all splits normalization - if stage == "train": - self.get_label_statistics( - self.processed_graph_data_path, self.data_hash, multitask_dataset, train=True - ) - # Normalization has already been applied in cached data - if not self._data_is_prepared: - self.normalize_label(multitask_dataset, stage) - return multitask_dataset def _ready_to_load_all_from_file(self) -> bool: @@ -1300,174 +1237,10 @@ def _data_ready_at_path(self, path: str) -> bool: return can_load_from_file - def _save_data_to_files(self, save_smiles_and_ids: bool = False) -> None: - """ - Save data to files so that they can be loaded from file during training/validation/test - """ - - stages = ["train", "val", "test"] - - # At the moment, we need to merge the `SingleTaskDataset`'s into `MultitaskDataset`s in order to save to file - # This is because the combined labels need to be stored together. We can investigate not doing this if this is a problem - temp_datasets = { - stage: self._make_multitask_dataset( - dataloading_from="ram", stage=stage, save_smiles_and_ids=save_smiles_and_ids - ) - for stage in stages - } - for stage in stages: - self.save_featurized_data(temp_datasets[stage], self._path_to_load_from_file(stage)) - temp_datasets[stage].save_metadata(self._path_to_load_from_file(stage)) - # self.train_ds, self.val_ds, self.test_ds will be created during `setup()` - - if self.dataloading_from == "disk": - del temp_datasets - else: - self.train_ds = temp_datasets["train"] - self.val_ds = temp_datasets["val"] - self.test_ds = temp_datasets["test"] - def get_folder_size(self, path): # check if the data items are actually saved into the folders return sum(os.path.getsize(osp.join(path, f)) for f in os.listdir(path)) - def calculate_statistics(self, dataset: Datasets.MultitaskDataset, train: bool = False): - """ - Calculate the statistics of the labels for each task, and overwrites the `self.task_norms` attribute. - - Parameters: - dataset: the dataset to calculate the statistics from - train: whether the dataset is the training set - - """ - - if self.task_norms and train: - for task in dataset.labels_size.keys(): - # if the label type is graph_*, we need to stack them as the tensor shape is (num_labels, ) - if task.startswith("graph"): - labels = np.stack( - np.array([datum["labels"][task] for datum in dataset if task in datum["labels"]]), - axis=0, - ) - # for other tasks with node_ and edge_, the label shape is [num_nodes/num_edges, num_labels] - # we can concatenate them directly - else: - labels = np.concatenate( - [datum["labels"][task] for datum in dataset if task in datum["labels"]], axis=0 - ) - - self.task_norms[task].calculate_statistics(labels) - - def get_label_statistics( - self, - data_path: Union[str, os.PathLike], - data_hash: str, - dataset: Datasets.MultitaskDataset, - train: bool = False, - ): - """ - Get the label statistics from the dataset, and save them to file, if needed. - `self.task_norms` will be modified in-place with the label statistics. - - Parameters: - data_path: the path to save and load the label statistics to. If None, no saving and loading will be done. - data_hash: the hash of the dataset generated by `get_data_hash()` - dataset: the dataset to calculate the statistics from - train: whether the dataset is the training set - - """ - if data_path is None: - self.calculate_statistics(dataset, train=train) - else: - path_with_hash = os.path.join(data_path, data_hash) - os.makedirs(path_with_hash, exist_ok=True) - filename = os.path.join(path_with_hash, "task_norms.pkl") - if self.task_norms and train and not os.path.isfile(filename): - self.calculate_statistics(dataset, train=train) - torch.save(self.task_norms, filename, pickle_protocol=4) - # if any of the above three condition does not satisfy, we load from file. - else: - self.task_norms = torch.load(filename) - - def normalize_label(self, dataset: Datasets.MultitaskDataset, stage) -> Datasets.MultitaskDataset: - """ - Normalize the labels in the dataset using the statistics in `self.task_norms`. - - Parameters: - dataset: the dataset to normalize the labels from - - Returns: - the dataset with normalized labels - """ - for task in dataset.labels_size.keys(): - # we normalize the dataset if (it is train split) or (it is val/test splits and normalize_val_test is set to true) - if (stage == "train") or (stage in ["val", "test"] and self.task_norms[task].normalize_val_test): - for i in range(len(dataset)): - if task in dataset[i]["labels"]: - dataset[i]["labels"][task] = self.task_norms[task].normalize( - dataset[i]["labels"][task] - ) - return dataset - - def save_featurized_data(self, dataset: Datasets.MultitaskDataset, processed_data_path): - os.makedirs(processed_data_path) # In case the len(dataset) is 0 - for i in range(0, len(dataset), 1000): - os.makedirs(os.path.join(processed_data_path, format(i // 1000, "04d")), exist_ok=True) - process_params = [(index, datum, processed_data_path) for index, datum in enumerate(dataset)] - - # Check if "about" is in the Dataset object - about = "" - if hasattr(dataset, "about"): - about = dataset.about - for param in tqdm(process_params, desc=f"Saving featurized data {about}"): - self.process_func(param) - return - - def process_func(self, param): - index, datum, folder = param - filename = os.path.join(folder, format(index // 1000, "04d"), format(index, "07d") + ".pkl") - torch.save( - {"graph_with_features": datum["features"], "labels": datum["labels"]}, - filename, - pickle_protocol=4, - ) - return - - def get_dataloader_kwargs(self, stage: RunningStage, shuffle: bool, **kwargs) -> Dict[str, Any]: - """ - Get the options for the dataloader depending on the current stage. - - Parameters: - stage: Whether in Training, Validating, Testing, Sanity-checking, Predicting, or Tuning phase. - shuffle: set to ``True`` to have the data reshuffled at every epoch. - - Returns: - Arguments to pass to the `DataLoader` during initialization - """ - loader_kwargs = super().get_dataloader_kwargs(stage=stage, shuffle=shuffle, **kwargs) - - # Get batch size and IPU options for training set - # if stage in [RunningStage.TRAINING, RunningStage.TUNING]: - if stage in [RunningStage.TRAINING]: - loader_kwargs["ipu_dataloader_options"] = self.ipu_dataloader_training_opts - loader_kwargs["ipu_options"] = self.ipu_training_opts - - # Get batch size and IPU options for validation / testing sets - elif stage in [RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]: - loader_kwargs["ipu_dataloader_options"] = self.ipu_dataloader_inference_opts - loader_kwargs["ipu_options"] = self.ipu_inference_opts - else: - raise ValueError(f"Wrong value for `stage`. Provided `{stage}`") - - # Remove the IPU options if not available - if loader_kwargs["ipu_options"] is None: - loader_kwargs.pop("ipu_options") - if loader_kwargs["ipu_dataloader_options"] is not None: - logger.warning( - "`ipu_dataloader_options` will be ignored since it is provided without `ipu_options`." - ) - loader_kwargs.pop("ipu_dataloader_options") - return loader_kwargs def get_dataloader( self, dataset: Dataset, shuffle: bool, stage: RunningStage @@ -1494,11 +1267,8 @@ def get_dataloader( ) # turn shuffle off when sampler is used as sampler option is mutually exclusive with shuffle kwargs["shuffle"] = False - is_ipu = ("ipu_options" in kwargs.keys()) and (kwargs.get("ipu_options") is not None) - if is_ipu: - loader = IPUDataModuleModifier._dataloader(self, dataset=dataset, sampler=sampler, **kwargs) - else: - loader = BaseDataModule._dataloader(self, dataset=dataset, sampler=sampler, **kwargs) + + loader = BaseDataModule._dataloader(self, dataset=dataset, sampler=sampler, **kwargs) return loader @@ -1509,115 +1279,10 @@ def get_collate_fn(self, collate_fn): graphium_collate_fn, mask_nan=0, do_not_collate_keys=["smiles", "mol_ids"], - batch_size_per_pack=self.batch_size_per_pack, ) collate_fn.__name__ = graphium_collate_fn.__name__ return collate_fn - # Cannot be used as is for the multitask version, because sample_idx does not apply. - def _featurize_molecules(self, smiles: Iterable[str]) -> Tuple[List, List]: - """ - Precompute the features (graphs, fingerprints, etc.) from the SMILES. - Features are computed from `self.smiles_transformer`. - A warning is issued to mention which molecules failed featurization. - - Note: - (hadim): in case of very large dataset we could: - - or cache the data and read from it during `next(iter(dataloader))` - - or compute the features on-the-fly during `next(iter(dataloader))` - For now we compute in advance and hold everything in memory. - - Parameters: - smiles: A list of all the molecular SMILES to featurize - sample_idx: The indexes corresponding to the sampled SMILES. - If not provided, computed from `numpy.arange`. - - Returns: - features: A list of all the featurized molecules - idx_none: A list of the indexes that failed featurization - """ - - batch_size = BatchingSmilesTransform.parse_batch_size( - numel=len(smiles), - desired_batch_size=self.featurization_batch_size, - n_jobs=self.featurization_n_jobs, - ) - - # Loop all the smiles and compute the features - features = dm.parallelized_with_batches( - BatchingSmilesTransform(self.smiles_transformer), - smiles, - batch_size=batch_size, - progress=True, - n_jobs=self.featurization_n_jobs, - backend=self.featurization_backend, - tqdm_kwargs={"desc": f"featurizing_smiles, batch={batch_size}"}, - ) - - # Warn about None molecules - idx_none = [ii for ii, feat in enumerate(features) if did_featurization_fail(feat)] - if len(idx_none) > 0: - mols_to_msg = [ - f"idx={idx} - smiles={smiles[idx]} - Error_msg[:-200]=\n{str(features[idx])[:-200]}" - for idx in idx_none - ] - msg = "\n".join(mols_to_msg) - logger.warning( - (f"{len(idx_none)} molecules will be removed since they failed featurization:\n" + msg) - ) - - return features, idx_none - - @staticmethod - def _filter_none_molecules( - idx_none: Iterable, - *args: Union[pd.DataFrame, pd.Series, np.ndarray, torch.Tensor, list, tuple, Dict[Any, Iterable]], - ) -> List[Union[pd.DataFrame, pd.Series, np.ndarray, torch.Tensor, list, tuple, Dict[Any, Iterable]]]: - """ - Filter the molecules, labels, etc. for the molecules that failed featurization. - - Parameters: - idx_none: A list of the indexes that failed featurization - args: Any argument from which to filter the failed SMILES. - Can be a `list`, `tuple`, `Tensor`, `np.array`, `Dict`, `pd.DataFrame`, `pd.Series`. - Otherwise, it is not filtered. - WARNING: If a `pd.DataFrame` or `pd.Series` is passed, it filters by the row indexes, - NOT by the `DataFrame.index` or `Series.index`! Be careful! - - Returns: - out: All the `args` with the indexes from `idx_none` removed. - """ - if len(idx_none) == 0: - return args - idx_none = np.asarray(idx_none) - - out = [] - for arg in args: - if isinstance(arg, pd.DataFrame): - new = arg.drop(arg.index[idx_none], axis=0) - elif isinstance(arg, pd.Series): - new = arg.drop(arg.index[idx_none], axis=0) - elif isinstance(arg, np.ndarray): - new = np.delete(arg, idx_none, axis=0) - elif isinstance(arg, torch.Tensor): - not_none = torch.ones(arg.shape[0], dtype=bool) - not_none[idx_none] = False - new = arg[not_none] - elif isinstance(arg, (list, tuple)): - arg = list(arg) - new = [elem for ii, elem in enumerate(arg) if ii not in idx_none] - elif isinstance(arg, dict): - new = {} - for key, val in arg.items(): - new[key] = MultitaskFromSmilesDataModule._filter_none_molecules(idx_none, val) # Careful - else: - new = arg - out.append(new) - - out = tuple(out) if len(out) > 1 else out[0] - - return out - def _parse_label_cols( self, df: pd.DataFrame, @@ -1695,8 +1360,6 @@ def in_dims(self): """ graph = self.get_fake_graph() - if isinstance(graph, (GraphDict)): - graph = graph.data # get list of all keys corresponding to positional encoding pe_dim_dict = {} @@ -1735,14 +1398,9 @@ def get_fake_graph(self): return graph ########################## Private methods ###################################### - def _save_to_cache(self): - raise NotImplementedError() - - def _load_from_cache(self): - raise NotImplementedError() + @staticmethod def _extract_smiles_labels( - self, df: pd.DataFrame, task_level: str, smiles_col: Optional[str] = None, @@ -1752,7 +1410,11 @@ def _extract_smiles_labels( weights_col: Optional[str] = None, weights_type: Optional[str] = None, ) -> Tuple[ - np.ndarray, np.ndarray, Union[Type[None], np.ndarray], Dict[str, Union[Type[None], np.ndarray]] + np.ndarray, + np.ndarray, + np.ndarray, + Union[Type[None], np.ndarray], + Dict[str, Union[Type[None], np.ndarray]], ]: """ For a given dataframe extract the SMILES and labels columns. Smiles is returned as a list @@ -1767,7 +1429,7 @@ def _extract_smiles_labels( weights_col: Name of the column containing the weights weights_type: Type of weights to use. Returns: - smiles, labels, sample_idx, extras + smiles, labels, label_offsets, sample_idx, extras """ if smiles_col is None: # Should we specify which dataset has caused the potential issue? @@ -1788,17 +1450,18 @@ def _extract_smiles_labels( smiles = df[smiles_col].values if len(label_cols) > 0: if task_level == "graph": - labels = extract_labels(df, "graph", label_cols) + labels, label_offsets = extract_labels(df, "graph", label_cols) elif task_level == "node": - labels = extract_labels(df, "node", label_cols) + labels, label_offsets = extract_labels(df, "node", label_cols) elif task_level == "edge": - labels = extract_labels(df, "edge", label_cols) + labels, label_offsets = extract_labels(df, "edge", label_cols) elif task_level == "nodepair": - labels = extract_labels(df, "nodepair", label_cols) + labels, label_offsets = extract_labels(df, "nodepair", label_cols) else: raise ValueError(f"Unknown task level: {task_level}") else: labels = float("nan") + np.zeros([len(smiles), 0]) + label_offsets = None # Get the indices, used for sub-sampling and splitting the dataset if idx_col is not None: @@ -1837,13 +1500,14 @@ def _extract_smiles_labels( weights /= np.max(weights) # Put the max weight to 1 extras = {"weights": weights, "mol_ids": mol_ids} - return smiles, labels, sample_idx, extras + return smiles, labels, label_offsets, sample_idx, extras + @staticmethod def _get_split_indices( - self, dataset_size: int, split_val: float, split_test: float, + split_type: str = "random", sample_idx: Optional[Iterable[int]] = None, split_seed: int = None, splits_path: Union[str, os.PathLike, Dict[str, Iterable[int]]] = None, @@ -1865,38 +1529,14 @@ def _get_split_indices( if sample_idx is None: sample_idx = np.arange(dataset_size) - if splits_path is None: - # Random splitting - if split_test + split_val > 0: - train_indices, val_test_indices = train_test_split( - sample_idx, - test_size=split_val + split_test, - random_state=split_seed, - ) - sub_split_test = split_test / (split_test + split_val) - else: - train_indices = sample_idx - val_test_indices = np.array([]) - sub_split_test = 0 - - if split_test > 0: - val_indices, test_indices = train_test_split( - val_test_indices, - test_size=sub_split_test, - random_state=split_seed, - ) - else: - val_indices = val_test_indices - test_indices = np.array([]) - - else: + if splits_path is not None: train, val, test = split_names if isinstance(splits_path, (Dict, pd.DataFrame)): # Split from a dataframe splits = splits_path else: # Split from an indices file - file_type = self._get_data_file_type(splits_path) + file_type = BaseDataModule._get_data_file_type(splits_path) train, val, test = split_names @@ -1904,7 +1544,7 @@ def _get_split_indices( splits = torch.load(splits_path) elif file_type in ["csv", "tsv"]: with fsspec.open(str(splits_path)) as f: - splits = self._read_csv(splits_path) + splits = BaseDataModule._read_csv(splits_path) else: raise ValueError( f"file type `{file_type}` for `{splits_path}` not recognised, please use .pt, .csv or .tsv" @@ -1916,18 +1556,83 @@ def _get_split_indices( test_indices = np.asarray(splits[test]).astype("int") test_indices = test_indices[~np.isnan(test_indices)].tolist() + elif split_type == "scaffold" and split_test != 1.: + # Scaffold splitting + try: + import splito + except ImportError as error: + raise RuntimeError( + f"To do the splitting, `splito` needs to be installed. " + f"Please install it with `pip install splito`" + ) from error + + # Split data into scaffolds + splitter = splito.ScaffoldSplit( + smiles=self.smiles, + test_size=split_test, + random_state=split_seed, + ) + train_val_indices, test_indices = next(splitter.split(X=self.smiles)) + train_val_smiles = [self.smiles[i] for i in train_val_indices] + + sub_split_val = split_val / (1 - split_test) + + splitter = splito.ScaffoldSplit( + smiles=train_val_smiles, + test_size=sub_split_val, + random_state=split_seed, + ) + train_indices, val_indices = next(splitter.split(X=train_val_smiles)) + + else: + if split_type != "random": + logger.warning(f"Unkown split {split_type}. Defaulting to `random`.") + + # Random splitting + if split_test + split_val > 0: + if split_test == 1.: + train_indices = np.array([]) + val_test_indices = sample_idx + sub_split_test = 1. + else: + train_indices, val_test_indices = train_test_split( + sample_idx, + test_size=split_val + split_test, + random_state=split_seed, + ) + sub_split_test = split_test / (split_test + split_val) + else: + train_indices = sample_idx + val_test_indices = np.array([]) + sub_split_test = 0 + + if split_test > 0: + if split_test == 1.: + val_indices = np.array([]) + test_indices = val_test_indices + else: + val_indices, test_indices = train_test_split( + val_test_indices, + test_size=sub_split_test, + random_state=split_seed, + ) + else: + val_indices = val_test_indices + test_indices = np.array([]) + # Filter train, val and test indices _, train_idx, _ = np.intersect1d(sample_idx, train_indices, return_indices=True) train_indices = train_idx.tolist() - _, valid_idx, _ = np.intersect1d(sample_idx, val_indices, return_indices=True) - val_indices = valid_idx.tolist() + _, val_idx, _ = np.intersect1d(sample_idx, val_indices, return_indices=True) + val_indices = val_idx.tolist() _, test_idx, _ = np.intersect1d(sample_idx, test_indices, return_indices=True) test_indices = test_idx.tolist() return train_indices, val_indices, test_indices + @staticmethod def _sub_sample_df( - self, df: pd.DataFrame, sample_size: Union[int, float, None], seed: Optional[int] = None + df: pd.DataFrame, sample_size: Union[int, float, None], seed: Optional[int] = None ) -> pd.DataFrame: r""" subsample from a pandas dataframe @@ -1954,10 +1659,17 @@ def _sub_sample_df( def get_data_hash(self): """ - Get a hash specific to a dataset and smiles_transformer. + Get a hash specific to a dataset. Useful to cache the pre-processed data. - """ - args = {} + Don't include options only used at data loading time, such as + most featurization options, but include options used during + pre-processing, like merge_equivalent_mols. + """ + args = { + "add_self_loop": self.add_self_loop, + "explicit_H": self.explicit_H, + "merge_equivalent_mols": self.merge_equivalent_mols, + } # pop epoch_sampling_fraction out when creating hash # so that the data cache does not need to be regenerated # when epoch_sampling_fraction has changed. @@ -1974,129 +1686,19 @@ def get_data_hash(self): task_args.pop("epoch_sampling_fraction", None) args[task_key] = task_args - hash_dict = { - "smiles_transformer": self.smiles_transformer, - "task_specific_args": args, - } - data_hash = get_md5_hash(hash_dict) + data_hash = get_md5_hash(args) return data_hash - def get_data_cache_fullname(self, compress: bool = False) -> str: - """ - Create a hash for the dataset, and use it to generate a file name - - Parameters: - compress: Whether to compress the data - Returns: - full path to the data cache file - """ - if self.processed_graph_data_path is None: - return - ext = ".datacache" - if compress: - ext += ".gz" - data_cache_fullname = fs.join(self.processed_graph_data_path, self.data_hash + ext) - return data_cache_fullname - - def load_data_from_cache(self, verbose: bool = True, compress: bool = False) -> bool: - """ - Load the datasets from cache. First create a hash for the dataset, and verify if that - hash is available at the path given by `self.processed_graph_data_path`. - - Parameters: - verbose: Whether to print the progress - compress: Whether to compress the data - - Returns: - cache_data_exists: Whether the cache exists (if the hash matches) and the loading succeeded - """ - full_cache_data_path = self.get_data_cache_fullname(compress=compress) - - if full_cache_data_path is None: - logger.info("No cache data path specified. Skipping loading the data from cache.") - return False - - cache_data_exists = fs.exists(full_cache_data_path) - - if cache_data_exists: - try: - logger.info(f"Loading the data from cache at path `{full_cache_data_path}`") - now = time.time() - with fsspec.open(full_cache_data_path, mode="rb", compression="infer") as file: - load_params = torch.load(file) - self.__dict__.update(load_params) - ( - self.train_singletask_datasets, - self.val_singletask_datasets, - self.test_singletask_datasets, - ) = self.get_subsets_of_datasets( - self.single_task_datasets, - self.task_train_indices, - self.task_val_indices, - self.task_test_indices, - ) - elapsed = round(time.time() - now) - logger.info( - f"Successfully loaded the data from cache in {elapsed}s at path: `{full_cache_data_path}`" - ) - return True - except Exception as e: - if verbose: - logger.warning( - f"Data cache failed to load path: `{full_cache_data_path}`.\nThe data will be prepared and cache will be created for future runs." - ) - logger.warning(e.__str__()) - return False - else: - if verbose: - logger.info( - f"Data cache not found at path: `{full_cache_data_path}`.\nThe data will be prepared and cache will be created for future runs." - ) - return False - - def get_subsets_of_datasets( - self, - single_task_datasets: Dict[str, Datasets.SingleTaskDataset], - task_train_indices: Dict[str, Iterable], - task_val_indices: Dict[str, Iterable], - task_test_indices: Dict[str, Iterable], - ) -> Tuple[Subset, Subset, Subset]: - """ - From a dictionary of datasets and their associated indices, subset the train/val/test sets - - Parameters: - single_task_datasets: Dictionary of datasets - task_train_indices: Dictionary of train indices - task_val_indices: Dictionary of val indices - task_test_indices: Dictionary of test indices - Returns: - train_singletask_datasets: Dictionary of train subsets - val_singletask_datasets: Dictionary of val subsets - test_singletask_datasets: Dictionary of test subsets - """ - train_singletask_datasets = {} - val_singletask_datasets = {} - test_singletask_datasets = {} - for task in task_train_indices.keys(): - train_singletask_datasets[task] = Subset(single_task_datasets[task], task_train_indices[task]) - val_singletask_datasets[task] = Subset(single_task_datasets[task], task_val_indices[task]) - test_singletask_datasets[task] = Subset(single_task_datasets[task], task_test_indices[task]) - return train_singletask_datasets, val_singletask_datasets, test_singletask_datasets - def __len__(self) -> int: r""" - Returns the number of elements of the current DataModule, which is the combined size of all single-task datasets given. + Returns the number of smiles of the current DataModule, which depends on all the smiles from all tasks. + If `prepare_data` is not called, the length is unknown and will raise an error. Returns: num_elements: Number of elements in the current DataModule """ - num_elements = 0 - for task, args in self.task_dataset_processing_params.items(): - if args.df is None: - df = self._read_table(args.df_path, usecols=[args.smiles_col]) - num_elements += len(df) - else: - num_elements += len(args.df) - return num_elements + if self._len is None: + raise ValueError("The length of the dataset is unknown. Please call `prepare_data` first.") + return self._len def to_dict(self) -> Dict[str, Any]: """ @@ -2113,7 +1715,6 @@ def to_dict(self) -> Dict[str, Any]: obj_repr["test_size"] = len(self.test_indices) if self.test_indices is not None else None obj_repr["batch_size_training"] = self.batch_size_training obj_repr["batch_size_inference"] = self.batch_size_inference - obj_repr["batch_size_per_pack"] = self.batch_size_per_pack obj_repr["num_node_feats"] = self.num_node_feats obj_repr["num_node_feats_with_positional_encoding"] = self.num_node_feats_with_positional_encoding obj_repr["num_edge_feats"] = self.num_edge_feats @@ -2138,20 +1739,15 @@ def __init__( self, task_specific_args: Dict[str, Union[DatasetProcessingParams, Dict[str, Any]]], processed_graph_data_path: Optional[Union[str, os.PathLike]] = None, - dataloading_from: str = "ram", featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, batch_size_training: int = 16, batch_size_inference: int = 16, - batch_size_per_pack: Optional[int] = None, num_workers: int = 0, pin_memory: bool = True, persistent_workers: bool = False, multiprocessing_context: Optional[str] = None, - featurization_n_jobs: int = -1, - featurization_progress: bool = False, - featurization_backend: str = "loky", collate_fn: Optional[Callable] = None, - prepare_dict_or_graph: str = "pyg:graph", + preprocessing_n_jobs: int = 0, **kwargs, ): r""" @@ -2168,27 +1764,26 @@ def __init__( meaning that all molecules will be considered. processed_graph_data_path: Path to the processed graph data. If None, the data will be downloaded from the OGB website. - dataloading_from: Whether to load the data from RAM or disk. Default is "ram". featurization: args to apply to the SMILES to Graph featurizer. batch_size_training: batch size for training and val dataset. batch_size_inference: batch size for test dataset. num_workers: Number of workers for the dataloader. Use -1 to use all available cores. pin_memory: Whether to pin on paginated CPU memory for the dataloader. - featurization_n_jobs: Number of cores to use for the featurization. - featurization_progress: whether to show a progress bar during featurization. - featurization_backend: The backend to use for the molecular featurization. - - - "multiprocessing": Found to cause less memory issues. - - "loky": joblib's Default. Found to cause memory leaks. - - "threading": Found to be slow. - collate_fn: A custom torch collate function. Default is to `graphium.data.graphium_collate_fn` sample_size: - `int`: The maximum number of elements to take from the dataset. - `float`: Value between 0 and 1 representing the fraction of the dataset to consider - `None`: all elements are considered. + preprocessing_n_jobs: Number of threads to use during preprocessing. + Use 0 to use all available cores, or -1 to use all but one core. + + dataloading_from: Deprecated. Behaviour now always matches previous "disk" option. + featurization_n_jobs: Deprecated. + featurization_progress: Deprecated. + featurization_backend: Deprecated. + prepare_dict_or_graph: Deprecated. Behaviour now always matches previous "pyg:graph" option. """ new_task_specific_args = {} @@ -2214,21 +1809,15 @@ def __init__( dm_args = {} dm_args["task_specific_args"] = new_task_specific_args dm_args["processed_graph_data_path"] = processed_graph_data_path - dm_args["dataloading_from"] = dataloading_from - dm_args["dataloader_from"] = dataloading_from dm_args["featurization"] = featurization dm_args["batch_size_training"] = batch_size_training dm_args["batch_size_inference"] = batch_size_inference - dm_args["batch_size_per_pack"] = batch_size_per_pack dm_args["num_workers"] = num_workers dm_args["pin_memory"] = pin_memory - dm_args["featurization_n_jobs"] = featurization_n_jobs - dm_args["featurization_progress"] = featurization_progress - dm_args["featurization_backend"] = featurization_backend dm_args["persistent_workers"] = persistent_workers dm_args["multiprocessing_context"] = multiprocessing_context dm_args["collate_fn"] = collate_fn - dm_args["prepare_dict_or_graph"] = prepare_dict_or_graph + dm_args["preprocessing_n_jobs"] = preprocessing_n_jobs super().__init__(**dm_args, **kwargs) @@ -2365,7 +1954,7 @@ def _get_ogb_metadata(self): return ogb_metadata -class ADMETBenchmarkDataModule(MultitaskFromSmilesDataModule): +class TDCBenchmarkDataModule(MultitaskFromSmilesDataModule): """ Wrapper to use the ADMET benchmark group from the TDC (Therapeutics Data Commons). @@ -2400,20 +1989,15 @@ def __init__( tdc_train_val_seed: int = 0, # Inherited arguments from superclass processed_graph_data_path: Optional[Union[str, Path]] = None, - dataloading_from: str = "ram", featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, batch_size_training: int = 16, batch_size_inference: int = 16, - batch_size_per_pack: Optional[int] = None, num_workers: int = 0, pin_memory: bool = True, persistent_workers: bool = False, multiprocessing_context: Optional[str] = None, - featurization_n_jobs: int = -1, - featurization_progress: bool = False, - featurization_backend: str = "loky", collate_fn: Optional[Callable] = None, - prepare_dict_or_graph: str = "pyg:graph", + preprocessing_n_jobs: int = 0, **kwargs, ): try: @@ -2456,23 +2040,24 @@ def __init__( for t in tdc_benchmark_names } + # Create a temporary `processed_graph_data_path` to store the processed graphs and labels + if processed_graph_data_path is None: + processed_graph_data_path = fs.join(tdc_cache_dir, "processed_graph_data") + if not fs.exists(processed_graph_data_path): + fs.mkdir(processed_graph_data_path) + super().__init__( task_specific_args=task_specific_args, featurization=featurization, processed_graph_data_path=processed_graph_data_path, - dataloading_from=dataloading_from, batch_size_training=batch_size_training, batch_size_inference=batch_size_inference, - batch_size_per_pack=batch_size_per_pack, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, multiprocessing_context=multiprocessing_context, - featurization_n_jobs=featurization_n_jobs, - featurization_progress=featurization_progress, - featurization_backend=featurization_backend, collate_fn=collate_fn, - prepare_dict_or_graph=prepare_dict_or_graph, + preprocessing_n_jobs=preprocessing_n_jobs, **kwargs, ) @@ -2532,237 +2117,3 @@ def _get_task_specific_arguments(self, name: str, seed: int, cache_dir: str) -> split_names=["train", "val", "test"], task_level="graph", ) - - -class FakeDataModule(MultitaskFromSmilesDataModule): - """ - A fake datamodule that generates artificial data by mimicking the true data coming - from the provided dataset. - It is useful to test the speed and performance of the model on a dataset without - having to featurize it and wait for the workers to load it. - """ - - def __init__( - self, - task_specific_args: Dict[str, Dict[str, Any]], # TODO: Replace this with DatasetParams - featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, - batch_size_training: int = 16, - batch_size_inference: int = 16, - num_workers: int = 0, - pin_memory: bool = True, - persistent_workers: bool = False, - multiprocessing_context: Optional[str] = None, - collate_fn: Optional[Callable] = None, - prepare_dict_or_graph: str = "pyg:graph", - num_mols_to_generate: int = 1000000, - indexing_single_elem: bool = True, - **kwargs, - ): - super().__init__( - task_specific_args=task_specific_args, - featurization=featurization, - batch_size_training=batch_size_training, - batch_size_inference=batch_size_inference, - num_workers=num_workers, - pin_memory=pin_memory, - persistent_workers=persistent_workers, - multiprocessing_context=multiprocessing_context, - collate_fn=collate_fn, - prepare_dict_or_graph=prepare_dict_or_graph, - **kwargs, - ) - self.num_mols_to_generate = num_mols_to_generate - self.indexing_single_elem = indexing_single_elem - - def generate_data(self, label_cols: List[str], smiles_col: str): - """ - Parameters: - labels_cols - smiles_col - Returns: - pd.DataFrame - """ - num_generated_mols = int(1) - # Create a dummy generated dataset - singel smiles string, duplicated N times - example_molecules = dict( - smiles="C1N2C3C4C5OC13C2C45", - cxsmiles="[H]C1C2=C(NC(=O)[C@@]1([H])C1=C([H])C([H])=C(C([H])([H])[H])C([H])=C1[H])C([H])=C([H])N=C2[H] |(6.4528,-1.5789,-1.2859;5.789,-0.835,-0.8455;4.8499,-0.2104,-1.5946;3.9134,0.7241,-0.934;3.9796,1.1019,0.3172;5.0405,0.6404,1.1008;5.2985,1.1457,2.1772;5.9121,-0.5519,0.613;6.9467,-0.2303,0.8014;5.677,-1.7955,1.4745;4.7751,-2.7953,1.0929;4.2336,-2.7113,0.154;4.5521,-3.9001,1.914;3.8445,-4.6636,1.5979;5.215,-4.0391,3.1392;4.9919,-5.2514,4.0126;5.1819,-5.0262,5.0671;5.6619,-6.0746,3.7296;3.966,-5.6247,3.925;6.1051,-3.0257,3.52;6.6247,-3.101,4.4725;6.3372,-1.9217,2.7029;7.0168,-1.1395,3.0281;2.8586,1.2252,-1.7853;2.1303,1.9004,-1.3493;2.8118,0.8707,-3.0956;2.0282,1.2549,-3.7434;3.716,0.0207,-3.7371;4.6658,-0.476,-3.0127;5.3755,-1.1468,-3.5021)|", - ) - example_df_entry = {smiles_col: example_molecules[smiles_col]} - for label in label_cols: - example_df_entry[label] = np.random.random() - df = pd.DataFrame([example_df_entry]) - logger.info(f"Generating fake dataset on host... \n Generating {num_generated_mols} rows in the df.") - df = pd.concat([df] * num_generated_mols, ignore_index=True) - return df - - def prepare_data(self): - """Called only from a single process in distributed settings. Steps: - - - If each cache is set and exists, reload from cache and return. Otherwise, - - For each single-task dataset: - - Load its dataframe from a path (if provided) - - Subsample the dataframe - - Extract the smiles, labels from the dataframe - - In the previous step, we were also able to get the unique smiles, which we use to compute the features - - For each single-task dataframe and associated data (smiles, labels, etc.): - - Filter out the data corresponding to molecules which failed featurization. - - Create a corresponding SingletaskDataset - - Split the SingletaskDataset according to the task-specific splits for train, val and test - """ - - """Load all single-task dataframes.""" - if self.num_mols_to_generate is None: - num_mols = 0 - - task_df = {} - for task, args in self.task_dataset_processing_params.items(): - logger.info(f"Reading data for task '{task}'") - if args.df is None: - # Only load the useful columns, as some datasets can be very large when loading all columns. - label_cols = self._parse_label_cols( - df=None, df_path=args.df_path, label_cols=args.label_cols, smiles_col=args.smiles_col - ) - task_df[task] = self.generate_data(label_cols=args.label_cols, smiles_col=args.smiles_col) - if self.num_mols_to_generate is None: - num_mols = max(num_mols, len(task_df[task])) - task_df[task] = task_df[task].iloc[0:1] - - args.label_cols = label_cols - if self.num_mols_to_generate is None: - self.num_mols_to_generate = num_mols - logger.info("Done reading datasets") - - """Subsample the data frames and extract the necessary data to create SingleTaskDatasets for each task (smiles, labels, extras).""" - task_dataset_args = {} - for task in task_df.keys(): - task_dataset_args[task] = {} - - for task, df in task_df.items(): - logger.info(f"Prepare single-task dataset for task '{task}' with {len(df)} data points.") - # Extract smiles, labels, extras - args = self.task_dataset_processing_params[task] - smiles, labels, sample_idx, extras = self._extract_smiles_labels( - df, - task_level=args.task_level, - smiles_col=args.smiles_col, - label_cols=args.label_cols, - idx_col=args.idx_col, - weights_col=args.weights_col, - weights_type=args.weights_type, - ) - - # Store the relevant information for each task's dataset - task_dataset_args[task]["smiles"] = smiles - task_dataset_args[task]["labels"] = labels - task_dataset_args[task]["sample_idx"] = sample_idx - task_dataset_args[task]["extras"] = extras - - """Convert SMILES to features (graphs, fingerprints, etc.) for the unique molecules found.""" - all_smiles = [] - idx_per_task = {} - total_len = 0 - for task, dataset_args in task_dataset_args.items(): - all_smiles.extend(dataset_args["smiles"]) - num_smiles = len(dataset_args["smiles"]) - idx_per_task[task] = (total_len, total_len + num_smiles) - total_len += num_smiles - # Get all unique mol ids - all_unique_mol_ids = smiles_to_unique_mol_ids( - all_smiles, - n_jobs=self.featurization_n_jobs, - featurization_batch_size=self.featurization_batch_size, - backend=self.featurization_backend, - ) - # Convert SMILES to features - features, _ = self._featurize_molecules(all_smiles) - task_dataset_args[task]["features"] = features - """Filter data based on molecules which failed featurization. Create single task datasets as well.""" - self.single_task_datasets = {} - for task, args in task_dataset_args.items(): - self.single_task_datasets[task] = Datasets.SingleTaskDataset( - features=task_dataset_args[task]["features"], - labels=task_dataset_args[task]["labels"], - smiles=task_dataset_args[task]["smiles"], - indices=task_dataset_args[task]["sample_idx"], - unique_ids=all_unique_mol_ids[idx_per_task[task][0] : idx_per_task[task][1]], - **task_dataset_args[task]["extras"], - ) - - """We split the data up to create train, val and test datasets""" - self.train_singletask_datasets = {} - self.val_singletask_datasets = {} - self.test_singletask_datasets = {} - for task, df in task_df.items(): - self.train_singletask_datasets[task] = Subset(self.single_task_datasets[task], [0]) - self.val_singletask_datasets[task] = Subset(self.single_task_datasets[task], [0]) - self.test_singletask_datasets[task] = Subset(self.single_task_datasets[task], [0]) - - def setup(self, stage=None): - # TODO - """ - Prepare the torch dataset. Called on every GPUs. Setting state here is ok. - Parameters: - stage (str): Either 'fit', 'test', or None. - """ - labels_size = {} - - if stage == "fit" or stage is None: - self.train_ds = Datasets.FakeDataset(self.train_singletask_datasets, num_mols=self.num_mols_to_generate, indexing_same_elem=self.indexing_single_elem) # type: ignore - self.val_ds = Datasets.FakeDataset(self.val_singletask_datasets, num_mols=self.num_mols_to_generate, indexing_same_elem=self.indexing_single_elem) # type: ignore - print(self.train_ds) - print(self.val_ds) - - labels_size.update( - self.train_ds.labels_size - ) # Make sure that all task label sizes are contained in here. Maybe do the update outside these if statements. - labels_size.update(self.val_ds.labels_size) - - if stage == "test" or stage is None: - self.test_ds = Datasets.FakeDataset(self.test_singletask_datasets, num_mols=self.num_mols_to_generate, indexing_same_elem=self.indexing_single_elem) # type: ignore - print(self.test_ds) - labels_size.update(self.test_ds.labels_size) - - default_labels_size_dict = self.collate_fn.keywords.get("labels_size_dict", None) - - if default_labels_size_dict is None: - self.collate_fn.keywords["labels_size_dict"] = labels_size - - def get_fake_graph(self): - """ - Low memory footprint method to get the first datapoint DGL graph. - The first 10 rows of the data are read in case the first one has a featurization - error. If all 20 first element, then `None` is returned, otherwise the first - graph to not fail is returned. - """ - keys = list(self.task_dataset_processing_params.keys()) - task = keys[0] - args = self.task_dataset_processing_params[task] - if args.df is None: - df = self._read_csv(args.df_path, nrows=20) - else: - df = args.df.iloc[0:20, :] - - df = df.iloc[0:20, :] - label_cols = self._parse_label_cols( - df, df_path=None, label_cols=args.label_cols, smiles_col=args.smiles_col - ) - - smiles, labels, sample_idx, extras = self._extract_smiles_labels( - df, - task_level=args.task_level, - smiles_col=args.smiles_col, - label_cols=label_cols, - idx_col=args.idx_col, - weights_col=args.weights_col, - weights_type=args.weights_type, - ) - - graph = None - for s in smiles: - graph = self.smiles_transformer(s, mask_nan=0.0) - num_nodes = graph.num_nodes - num_edges = graph.num_edges - if (graph is not None) and (num_edges > 0) and (num_nodes > 0): - break - return graph diff --git a/graphium/data/dataset.py b/graphium/data/dataset.py index 34c1b30aa..bf55e0418 100644 --- a/graphium/data/dataset.py +++ b/graphium/data/dataset.py @@ -1,12 +1,12 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ @@ -16,7 +16,7 @@ from copy import deepcopy from functools import lru_cache from multiprocessing import Manager -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import fsspec import numpy as np @@ -26,125 +26,7 @@ from torch.utils.data.dataloader import Dataset from torch_geometric.data import Batch, Data -from graphium.data.smiles_transform import smiles_to_unique_mol_ids -from graphium.features import GraphDict - - -class SingleTaskDataset(Dataset): - def __init__( - self, - labels: List[Union[torch.Tensor, np.ndarray]], - features: Optional[List[Union[Data, "GraphDict"]]] = None, - smiles: Optional[List[str]] = None, - indices: Optional[List[int]] = None, - weights: Optional[Union[torch.Tensor, np.ndarray]] = None, - unique_ids: Optional[List[str]] = None, - mol_ids: Optional[List[str]] = None, - ): - r""" - dataset for a single task - Parameters: - labels: A list of labels for the given task (one per graph) - features: A list of graphs - smiles: A list of smiles - indices: A list of indices - weights: A list of weights - unique_ids: A list of unique ids for each molecule generated from `datamol.unique_id` - mol_ids: A list of ids coming from the original dataset. Useful to identify the molecule in the original dataset. - """ - - # Verify that all lists are the same length - numel = len(labels) - - def _check_if_same_length(to_check, label): - """Simple utility method to throw an error if the length is not as expected.""" - if to_check is not None and len(to_check) != numel: - raise ValueError( - f"{label} must be the same length as `labels`, got {len(to_check)} and {numel}" - ) - - _check_if_same_length(features, "features") - _check_if_same_length(indices, "indices") - _check_if_same_length(weights, "weights") - _check_if_same_length(unique_ids, "unique_ids") - _check_if_same_length(mol_ids, "mol_ids") - - self.labels = labels - if smiles is not None: - manager = Manager() # Avoid memory leaks with `num_workers > 0` by using the Manager - self.smiles = manager.list(smiles) - else: - self.smiles = None - self.features = features - self.indices = indices - if self.indices is not None: - self.indices = np.array( - self.indices - ) # Avoid memory leaks with `num_workers > 0` by using numpy array - self.weights = weights - self.unique_ids = unique_ids - self.mol_ids = mol_ids - - def __len__(self): - r""" - return the size of the dataset - Returns: - size: the size of the dataset - """ - return len(self.labels) - - def __getitem__(self, idx): - """ - get the data at the given index - Parameters: - idx: the index to get the data at - Returns: - datum: a dictionary containing the data at the given index, with keys "features", "labels", "smiles", "indices", "weights", "unique_ids" - """ - datum = {} - - if self.features is not None: - datum["features"] = self.features[idx] - - if self.labels is not None: - datum["labels"] = self.labels[idx] - - if self.smiles is not None: - datum["smiles"] = self.smiles[idx] - - if self.indices is not None: - datum["indices"] = self.indices[idx] - - if self.weights is not None: - datum["weights"] = self.weights[idx] - - if self.unique_ids is not None: - datum["unique_ids"] = self.unique_ids[idx] - - if self.mol_ids is not None: - datum["mol_ids"] = self.mol_ids[idx] - - return datum - - def __getstate__(self): - """Serialize the class for pickling.""" - state = {} - state["labels"] = self.labels - state["smiles"] = list(self.smiles) if self.smiles is not None else None - state["features"] = self.features - state["indices"] = self.indices - state["weights"] = self.weights - state["unique_ids"] = self.unique_ids - state["mol_ids"] = self.mol_ids - return state - - def __setstate__(self, state: dict): - """Reload the class from pickling.""" - if state["smiles"] is not None: - manager = Manager() - state["smiles"] = manager.list(state["smiles"]) - - self.__dict__.update(state) +import graphium_cpp class MultitaskDataset(Dataset): @@ -152,178 +34,51 @@ class MultitaskDataset(Dataset): def __init__( self, - datasets: Dict[str, SingleTaskDataset], - n_jobs=-1, - backend: str = "loky", - featurization_batch_size=1000, - progress: bool = True, - save_smiles_and_ids: bool = False, + featurize_smiles: Callable[[str], dict], + task_names: List[str] = None, + label_num_cols: List[int] = None, + label_dtypes: List[int] = None, + mol_file_data_offsets=None, + concat_smiles_tensor=None, + smiles_offsets_tensor=None, + num_nodes_tensor=None, + num_edges_tensor=None, about: str = "", data_path: Optional[Union[str, os.PathLike]] = None, - dataloading_from: str = "ram", - data_is_cached: bool = False, + return_smiles: bool = False, ): r""" This class holds the information for the multitask dataset. - Several single-task datasets can be merged to create a multi-task dataset. After merging the dictionary of single-task datasets. we will have a multitask dataset of the following form: - - self.mol_ids will be a list to contain the unique molecular IDs to identify the molecules - - self.smiles will be a list to contain the corresponding smiles for that molecular ID across all single-task datasets - - self.labels will be a list of dictionaries where the key is the task name and the value is the label(s) for that task. - At this point, any particular molecule will only have entries for tasks for which it has a label. Later, in the collate - function, we fill up the missing task labels with NaNs. - - self.features will be a list of featurized graphs corresponding to that particular unique molecule. - However, for testing purposes we may not require features so that we can make sure that this merge function works. + - self.mol_file_data_offsets will be a Tensor representing where to find + label data about each molecule in the corresponding file + - self.smiles_tensor will be a Tensor containing all smiles strings concatenated, with null terminators + - self.smiles_offsets_tensor will be a Tensor indicating where smiles strings start in smiles_tensor + - self.num_nodes_tensor will be a Tensor of the number of nodes in each graph + - self.num_edges_tensor will be a Tensor of the number of edges in each graph Parameters: - datasets: A dictionary of single-task datasets - n_jobs: Number of jobs to run in parallel - backend: Parallelization backend - featurization_batch_size: The batch size to use for the parallelization of the featurization - progress: Whether to display the progress bar - save_smiles_and_ids: Whether to save the smiles and ids for the dataset. If `False`, `mol_ids` and `smiles` are set to `None` about: A description of the dataset data_path: The location of the data if saved on disk - dataloading_from: Whether to load the data from `"disk"` or `"ram"` - data_is_cached: Whether the data is already cached on `"disk"` """ super().__init__() - self.n_jobs = n_jobs - self.backend = backend - self.featurization_batch_size = featurization_batch_size - self.progress = progress + self.about = about - self.save_smiles_and_ids = save_smiles_and_ids self.data_path = data_path - self.dataloading_from = dataloading_from - - logger.info(f"Dataloading from {dataloading_from.upper()}") - - if data_is_cached: - self._load_metadata() - - if dataloading_from == "disk": - self.features = None - self.labels = None - elif dataloading_from == "ram": - logger.info(f"Transferring {about} from DISK to RAM...") - self.transfer_from_disk_to_ram() - - else: - task = next(iter(datasets)) - self.features = None - if (len(datasets[task]) > 0) and ("features" in datasets[task][0]): - self.mol_ids, self.smiles, self.labels, self.features = self.merge(datasets) - else: - self.mol_ids, self.smiles, self.labels = self.merge(datasets) - # Set mol_ids and smiles to None to save memory as they are not needed. - if not save_smiles_and_ids: - self.mol_ids = None - self.smiles = None - self.labels_size = self.set_label_size_dict(datasets) - self.labels_dtype = self.set_label_dtype_dict(datasets) - self.dataset_length = len(self.labels) - self._num_nodes_list = None - self._num_edges_list = None - if self.features is not None: - self._num_nodes_list = get_num_nodes_per_graph(self.features) - self._num_edges_list = get_num_edges_per_graph(self.features) - - def transfer_from_disk_to_ram(self, parallel_with_batches: bool = False): - """ - Function parallelizing transfer from DISK to RAM - """ - - def transfer_mol_from_disk_to_ram(idx): - """ - Function transferring single mol from DISK to RAM - """ - data_dict = self.load_graph_from_index(idx) - mol_in_ram = { - "features": data_dict["graph_with_features"], - "labels": data_dict["labels"], - } + self.featurize_smiles = featurize_smiles + self.task_names = task_names + self.label_num_cols = label_num_cols + self.label_dtypes = label_dtypes + self.mol_file_data_offsets = mol_file_data_offsets + self.smiles_tensor = concat_smiles_tensor + self.smiles_offsets_tensor = smiles_offsets_tensor + self.num_nodes_tensor = num_nodes_tensor + self.num_edges_tensor = num_edges_tensor + self.dataset_length = num_nodes_tensor.size(dim=0) - return mol_in_ram - - if parallel_with_batches and self.featurization_batch_size: - data_in_ram = parallelized_with_batches( - transfer_mol_from_disk_to_ram, - range(self.dataset_length), - batch_size=self.featurization_batch_size, - n_jobs=0, - backend=self.backend, - progress=self.progress, - tqdm_kwargs={"desc": "Transfer from DISK to RAM"}, - ) - else: - data_in_ram = parallelized( - transfer_mol_from_disk_to_ram, - range(self.dataset_length), - n_jobs=0, - backend=self.backend, - progress=self.progress, - tqdm_kwargs={"desc": "Transfer from DISK to RAM"}, - ) - - self.features = [sample["features"] for sample in data_in_ram] - self.labels = [sample["labels"] for sample in data_in_ram] - - def save_metadata(self, directory: str): - """ - Save everything other than features/labels - """ - attrs_to_save = [ - "mol_ids", - "smiles", - "labels_size", - "labels_dtype", - "dataset_length", - "_num_nodes_list", - "_num_edges_list", - ] - attrs = {attr: getattr(self, attr) for attr in attrs_to_save} - - path = os.path.join(directory, "multitask_metadata.pkl") - - torch.save(attrs, path, pickle_protocol=4) - - def _load_metadata(self): - """ - Load everything other than features/labels - """ - attrs_to_load = [ - "mol_ids", - "smiles", - "labels_size", - "labels_dtype", - "dataset_length", - "_num_nodes_list", - "_num_edges_list", - ] - path = os.path.join(self.data_path, "multitask_metadata.pkl") - with fsspec.open(path, "rb") as f: - attrs = torch.load(path) - - if not set(attrs_to_load).issubset(set(attrs.keys())): - raise ValueError( - f"The metadata in the cache at {self.data_path} does not contain the right information. " - f"This may be because the cache was prepared using an earlier version of Graphium. " - f"You can try deleting the cache and running the data preparation again. " - f"\nMetadata keys found: {attrs.keys()}" - f"\nMetadata keys required: {attrs_to_load}" - ) - - for attr, value in attrs.items(): - setattr(self, attr, value) + self.return_smiles = return_smiles - if self.save_smiles_and_ids: - if self.smiles is None or self.mol_ids is None: - logger.warning( - f"Argument `save_smiles_and_ids` is set to {self.save_smiles_and_ids} but metadata in the cache at {self.data_path} does not contain smiles and mol_ids. " - f"This may be because `Datamodule.prepare_data(save_smiles_and_ids=False)` was run followed by `Datamodule.setup(save_smiles_and_ids=True)`. " - f"When loading from cached files, the `save_smiles_and_ids` argument of `Datamodule.setup()` is superseeded by the `Datamodule.prepare_data()`. " - ) + logger.info(f"Dataloading from DISK") def __len__(self): r""" @@ -336,24 +91,14 @@ def num_nodes_list(self): """ The number of nodes per graph """ - if self._num_nodes_list is None: - if len(self) == 0: - self._num_nodes_list = [] - else: - self._num_nodes_list = get_num_nodes_per_graph(self.features) - return self._num_nodes_list + return self.num_nodes_tensor @property def num_edges_list(self): """ The number of edges per graph """ - if self._num_edges_list is None: - if len(self) == 0: - self._num_edges_list = [] - else: - self._num_edges_list = get_num_edges_per_graph(self.features) - return self._num_edges_list + return self.num_edges_tensor @property def num_graphs_total(self): @@ -367,28 +112,30 @@ def num_nodes_total(self): """Total number of nodes for all graphs""" if len(self) == 0: return - return sum(self.num_nodes_list) + return torch.sum(self.num_nodes_list, dtype=torch.int64).item() @property def max_num_nodes_per_graph(self): """Maximum number of nodes per graph""" if len(self) == 0: return - return max(self.num_nodes_list) + return torch.max(self.num_nodes_list).item() @property def std_num_nodes_per_graph(self): """Standard deviation of number of nodes per graph""" if len(self) == 0: return - return np.std(self.num_nodes_list) + # correction is zero to match previous default behaviour of numpy.std + # Consider changing it to 1 (the torch.std default) + return torch.std(self.num_nodes_list.to(torch.float64), correction=0).item() @property def min_num_nodes_per_graph(self): """Minimum number of nodes per graph""" if len(self) == 0: return - return min(self.num_nodes_list) + return torch.min(self.num_nodes_list).item() @property def mean_num_nodes_per_graph(self): @@ -402,28 +149,30 @@ def num_edges_total(self): """Total number of edges for all graphs""" if len(self) == 0: return - return sum(self.num_edges_list) + return torch.sum(self.num_edges_list, dtype=torch.int64).item() @property def max_num_edges_per_graph(self): """Maximum number of edges per graph""" if len(self) == 0: return - return max(self.num_edges_list) + return torch.max(self.num_edges_list).item() @property def min_num_edges_per_graph(self): """Minimum number of edges per graph""" if len(self) == 0: return - return min(self.num_edges_list) + return torch.min(self.num_edges_list).item() @property def std_num_edges_per_graph(self): """Standard deviation of number of nodes per graph""" if len(self) == 0: return - return np.std(self.num_edges_list) + # correction is zero to match previous default behaviour of numpy.std + # Consider changing it to 1 (the torch.std default) + return torch.std(self.num_edges_list.to(torch.float64), correction=0).item() @property def mean_num_edges_per_graph(self): @@ -438,27 +187,29 @@ def __getitem__(self, idx): Parameters: idx: The index of the data to retrieve Returns: - A dictionary containing the data for the specified index with keys "mol_ids", "smiles", "labels", and "features" + A dictionary containing the data for the specified index with keys "labels", "num_nodes", "num_edges", and "features" """ - datum = {} - if self.dataloading_from == "disk": - data_dict = self.load_graph_from_index(idx) - datum["features"] = data_dict["graph_with_features"] - datum["labels"] = data_dict["labels"] - if "smiles" in data_dict.keys(): - datum["smiles"] = data_dict["smiles"] - else: - if self.mol_ids is not None: - datum["mol_ids"] = self.mol_ids[idx] + if self.smiles_tensor is None or self.smiles_offsets_tensor is None: + raise ValueError("Missing smiles in MultitaskDataset.__getitem__") + + smiles_str = graphium_cpp.extract_string(self.smiles_tensor, self.smiles_offsets_tensor, idx) - if self.smiles is not None: - datum["smiles"] = self.smiles[idx] + if self.mol_file_data_offsets is None: + datum = {"features": self.featurize_smiles(smiles_str)} + else: + datum = { + "labels": self.load_graph_from_index(idx), + "features": self.featurize_smiles(smiles_str), + } - if self.labels is not None: - datum["labels"] = self.labels[idx] + if self.return_smiles: + datum["smiles"] = smiles_str - if self.features is not None: - datum["features"] = self.features[idx] + # One of the featurization error handling options returns a string on error, + # instead of throwing an exception, so assume that the intention is to just skip, + # instead of crashing. + if isinstance(datum["features"], str): + datum = None return datum @@ -468,165 +219,23 @@ def load_graph_from_index(self, data_idx): Parameters: data_idx: The index of the data to retrieve Returns: - A dictionary containing the data for the specified index with keys "graph_with_features", "labels" and "smiles" (optional). + A Data object containing the data for the specified index with keys corresponding to the tasks. """ - filename = os.path.join( - self.data_path, format(data_idx // 1000, "04d"), format(data_idx, "07d") + ".pkl" + labels = {} + graphium_cpp.load_labels_from_index( + self.data_path, + data_idx, + self.mol_file_data_offsets, + self.task_names, + self.label_num_cols, + self.label_dtypes, + labels, ) - with fsspec.open(filename, "rb") as f: - data_dict = torch.load(f) - return data_dict - - def merge( - self, datasets: Dict[str, SingleTaskDataset] - ) -> Tuple[List[str], List[str], List[Dict[str, Any]], List[Any]]: - r"""This function merges several single task datasets into a multitask dataset. - - The idea: for each of the smiles, labels, features and tasks, we create a corresponding list that concatenates these items across all tasks. - In particular, for any index, the elements in the smiles, labels, features and task lists at that index will correspond to each other (i.e. match up). - Over this list of all smiles (which we created by concatenating the smiles across all tasks), we compute their molecular ID using functions from Datamol. - Once again, we will have a list of molecular IDs which is the same size as the list of smiles, labels, features and tasks. - We then use numpy's `unique` function to find the exact list of unique molecular IDs as these will identify the molecules in our dataset. We also get the - inverse from numpy's `unique`, which will allow us to index in addition to the list of all molecular IDs, the list of all smiles, labels, features and tasks. - Finally, we use this inverse to construct the list of list of smiles, list of label dictionaries (indexed by task) and the list of features such that - the indices match up. This is what is needed for the `get_item` function to work. - - Parameters: - datasets: A dictionary of single-task datasets - Returns: - A tuple of (list of molecular IDs, list of smiles, list of label dictionaries, list of features) - """ + data_dict = Data() + for task, values in labels.items(): + data_dict[task] = values - # Get all the smiles, labels, features and tasks. - all_lists = self._get_all_lists_ids(datasets=datasets) - mol_ids, inv = self._get_inv_of_mol_ids(all_mol_ids=all_lists["mol_ids"]) - - # Store the smiles. - smiles = [[] for _ in range(len(mol_ids))] - for all_idx, unique_idx in enumerate(inv): - smiles[unique_idx].append(all_lists["smiles"][all_idx]) - - # Store the labels. - labels = [Data() for _ in range(len(mol_ids))] - for all_idx, unique_idx in enumerate(inv): - task: str = all_lists["tasks"][all_idx] - label = all_lists["labels"][all_idx] - labels[unique_idx][task] = label - - if all_idx < len(all_lists["features"]): - features = all_lists["features"][all_idx] - labels[unique_idx]["x"] = torch.empty( - (features.num_nodes, 1) - ) # IPU is not happy with zero-sized tensors, so use shape (features.num_nodes, 1) here - labels[unique_idx]["edge_index"] = torch.empty((2, features.num_edges)) - - # Store the features - if len(all_lists["features"]) > 0: - features = [-1 for i in range(len(mol_ids))] - for all_idx, unique_idx in enumerate(inv): - features[unique_idx] = all_lists["features"][all_idx] - return mol_ids, smiles, labels, features - else: - return mol_ids, smiles, labels - - def _get_all_lists_ids(self, datasets: Dict[str, SingleTaskDataset]) -> Dict[str, Any]: - all_smiles = [] - all_features = [] - all_labels = [] - all_mol_ids = [] - all_tasks = [] - - for task, ds in datasets.items(): - if len(ds) == 0: - continue - # Get data from single task dataset - ds_smiles = [ds[i]["smiles"] for i in range(len(ds))] - ds_labels = [ds[i]["labels"] for i in range(len(ds))] - if "unique_ids" in ds[0].keys(): - ds_mol_ids = [ds[i]["unique_ids"] for i in range(len(ds))] - else: - ds_mol_ids = smiles_to_unique_mol_ids( - ds_smiles, - n_jobs=self.n_jobs, - featurization_batch_size=self.featurization_batch_size, - backend=self.backend, - progress=self.progress, - progress_desc=f"{task}: mol to ids", - ) - if "features" in ds[0]: - ds_features = [ds[i]["features"] for i in range(len(ds))] - else: - ds_features = None - all_smiles.extend(ds_smiles) - all_labels.extend(ds_labels) - all_mol_ids.extend(ds_mol_ids) - if ds_features is not None: - all_features.extend(ds_features) - - task_list = [task] * ds.__len__() - all_tasks.extend(task_list) - - all_lists = { - "smiles": all_smiles, - "features": all_features, - "labels": all_labels, - "mol_ids": all_mol_ids, - "tasks": all_tasks, - } - - return all_lists - - def _get_inv_of_mol_ids(self, all_mol_ids): - mol_ids, inv = np.unique(all_mol_ids, return_inverse=True) - return mol_ids, inv - - def _find_valid_label(self, task, ds): - r""" - For a given dataset, find a genuine label for that dataset - """ - valid_label = None - for i in range(len(ds)): - if ds[i] is not None: - valid_label = ds[i]["labels"] - break - - if valid_label is None: - raise ValueError(f"Dataset for task {task} has no valid labels.") - - return valid_label - - def set_label_size_dict(self, datasets: Dict[str, SingleTaskDataset]): - r""" - This gives the number of labels to predict for a given task. - """ - task_labels_size = {} - for task, ds in datasets.items(): - if len(ds) == 0: - continue - - valid_label = self._find_valid_label(task, ds) - - # Assume for a fixed task, the label dimension is the same across data points - torch_label = torch.as_tensor(valid_label) - - # First dimension is graph-specific - task_labels_size[task] = torch_label.size() - return task_labels_size - - def set_label_dtype_dict(self, datasets: Dict[str, SingleTaskDataset]): - r""" - Gets correct dtype for a given label - """ - task_labels_dtype = {} - for task, ds in datasets.items(): - if len(ds) == 0: - continue - - valid_label = self._find_valid_label(task, ds) - - torch_label = torch.as_tensor(valid_label) - task_labels_dtype[task] = torch_label.dtype - return task_labels_dtype + return data_dict def __repr__(self) -> str: """ @@ -643,11 +252,6 @@ def __repr__(self) -> str: ) return out_str - # Faster to compute the statistics if we unbatch first. - features = self.features - if isinstance(self.features, Batch): - self.features = self.features.to_data_list() - out_str = ( f"-------------------\n{self.__class__.__name__}\n" + f"\tabout = {self.about}\n" @@ -665,111 +269,33 @@ def __repr__(self) -> str: + f"-------------------\n" ) - # Restore the original features. - self.features = features - return out_str -class FakeDataset(MultitaskDataset): - """ - A dataset to hold the fake data. - """ - - def __init__( - self, datasets: Dict[str, SingleTaskDataset], num_mols: int = 1234, indexing_same_elem: bool = False - ): - """ - Parameters: - datasets: - A dictionary of datasets. The keys are the task names and the values are the datasets. - num_mols: - The number of molecules to generate. In reality, it is the same molecule, - but `num_mols` will change the length of the dataset. - indexing_same_elem: - If True, the same molecule is used for all samples. - Otherwise, a deepcopied molecule is used for each sample. - """ - self.indexing_same_elem = indexing_same_elem - self.num_mols = num_mols - self.num_datasets = len(datasets) - - self.about = "FakeDatasets" - task = next(iter(datasets)) - if "features" in datasets[task][0]: - self.mol_ids, self.smiles, self.labels, self.features = self.merge(datasets) - if self.indexing_same_elem is False: - self.mol_ids, self.smiles, self.labels, self.features = self.deepcopy_mol( - self.mol_ids, self.smiles, self.labels, self.features - ) - else: - self.mol_ids, self.smiles, self.labels = self.merge(datasets) - if self.indexing_same_elem is False: - self.mol_ids, self.smiles, self.labels, _ = self.deepcopy_mol( - self.mol_ids, self.smiles, self.labels - ) - - self.labels_size = self.set_label_size_dict(datasets) - self.labels_dtype = self.set_label_dtype_dict(datasets) - self.features = self.features - - def _get_inv_of_mol_ids(self, all_mol_ids): - # The generated data is a single molecule duplicated - mol_ids = np.array(all_mol_ids) - inv = [_ for _ in range(len(mol_ids) // self.num_datasets)] * self.num_datasets - mol_ids = np.unique(inv) - return mol_ids, inv - - def deepcopy_mol(self, mol_ids, labels, smiles, features=None): - """ - Create a deepcopy of the single molecule num_mols times - - Args: - mol_ids (array): The single value for the mol ID - labels (List[Dict]): List containing one dict with the label name-value pairs - smiles (List[List[str]]): List of list containing SMILE sting - features (List[Data], optional): list containing Data object. Defaults to None. - - Returns: - The deep copy of the inputs - """ - logger.info("Duplicating the single dataset element...") - mol_ids = [deepcopy(mol_ids[0]) for _ in range(self.num_mols)] - logger.info("Finished `mol_ids`") - labels = [deepcopy(labels[0]) for _ in range(self.num_mols)] - logger.info("Finished `labels`") - smiles = [deepcopy(smiles[0]) for _ in range(self.num_mols)] - logger.info("Finished `smiles`") - if features is not None: - features = [deepcopy(features[0]) for _ in range(self.num_mols)] - logger.info("Finished `features`") - return mol_ids, labels, smiles, features - - def __len__(self): - r""" - Returns the number of molecules - """ - return self.num_mols - - def __getitem__(self, idx): - r""" - get the data for at the specified index - Parameters: - idx: The index of the data to retrieve - Returns: - A dictionary containing the data for the specified index with keys "mol_ids", "smiles", "labels", and "features" - """ - datum = {} - if self.indexing_same_elem is True: - # If using a single memory location override the idx value passed - idx = 0 - if self.labels is not None: - datum["labels"] = self.labels[idx] - - if self.features is not None: - datum["features"] = self.features[idx] - - return datum +def torch_enum_to_dtype(v: Union[int, torch.dtype]): + if isinstance(v, torch.dtype): + return v + + mapping = [ + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.float64, + torch.complex32, + torch.complex64, + torch.complex128, + torch.bool, + torch.qint8, + torch.quint8, + torch.qint32, + torch.bfloat16, + torch.quint4x2, + ] + return mapping[v] if (v >= 0 and v < len(mapping)) else None def get_num_nodes_per_graph(graphs): diff --git a/graphium/data/multilevel_utils.py b/graphium/data/multilevel_utils.py index 7f9ed5813..a096979dd 100644 --- a/graphium/data/multilevel_utils.py +++ b/graphium/data/multilevel_utils.py @@ -1,12 +1,12 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ @@ -21,45 +21,110 @@ def extract_labels(df: pd.DataFrame, task_level: str, label_cols: List[str]): - """Extracts labels in label_cols from dataframe df for a given task_level. - Returns a list of numpy arrays converted to the correct shape. Multiple - targets are concatenated for each graph. + """Extracts the labels specified by label_cols from dataframe df. + If task_level is "graph", each entry in df must be a single numeric value, + and this function returns a single, 2D numpy array containing the data. + If task_level is something else, each entry in df must be a numpy array, + python list, or single numeric value, and this function returns both a 2D + numpy array of data and a 1D numpy array of integers indicating the row + number in the first array where each molecule's data starts, with an extra + integer at the end that should equal the total number of rows in the first + array. The first array can have type float16, float32, or float64, + depending on the largest precision of input data, and arrays of varying + sizes across columns are padded with nan values, so that a single molecule + occupies a fixed number of rows and len(label_cols) columns. """ - def unpack(graph_data): - graph_data = pd.to_numeric(graph_data, errors="coerce") - if isinstance(graph_data, str): - graph_data_list = ast.literal_eval(graph_data) - return np.array(graph_data_list) - elif isinstance(graph_data, (int, float)): - return np.array([graph_data]) - elif isinstance(graph_data, list): - return np.array(graph_data) - elif isinstance(graph_data, np.ndarray): - if len(graph_data.shape) == 0: - graph_data = np.expand_dims(graph_data, 0) - if graph_data.shape[0] == 0: - graph_data = np.array([np.nan]) - # TODO: Warning - return graph_data - else: - raise ValueError( - f"Graph data should be one of str, float, int, list, np.ndarray, got {type(graph_data)}" - ) - - def unpack_column(data: pd.Series): - return data.apply(unpack) - - def merge_columns(data: pd.Series): - data = data.to_list() - data = [np.array([np.nan]) if not isinstance(d, np.ndarray) and math.isnan(d) else d for d in data] - padded_data = itertools.zip_longest(*data, fillvalue=np.nan) - data = np.stack(list(padded_data), 1).T - return data - - unpacked_df: pd.DataFrame = df[label_cols].apply(unpack_column) - output = unpacked_df.apply(merge_columns, axis="columns").to_list() + num_rows = df.shape[0] + num_cols = len(label_cols) if task_level == "graph": - return np.concatenate(output) - return output + output = np.empty((num_rows, num_cols), dtype=np.float64) + + for col_index, col in enumerate(label_cols): + for i, v in enumerate(df[col]): + if isinstance(v, float): + output[i, col_index] = v + continue + + v = pd.to_numeric(v, errors="coerce") + + if isinstance(v, (int, float)): + output[i, col_index] = v + + else: + raise ValueError(f"Graph data should be one of float or int, got {type(v)}") + + return output, None + + # First, find the max length of each row (likely the number of nodes or edges) + # +1 is for the cumulative sum below + begin_offsets = np.zeros((num_rows + 1,), dtype=np.int64) + max_type = np.float16 + for col in label_cols: + for i, v in enumerate(df[col]): + if not isinstance(v, np.ndarray) and not isinstance(v, (int, float, list)): + v = pd.to_numeric(v, errors="coerce") + length = 0 + if isinstance(v, np.ndarray): + if len(v.shape) == 1: + length = v.shape[0] + elif len(v.shape) == 0: + length = 0 + else: + raise ValueError( + f"Graph data should be 1D np.ndarray, got ndarray with {len(v.shape)} dimensions" + ) + dtype = v.dtype + if dtype == np.float64: + max_type = np.float64 + elif dtype == np.float32 and max_type == np.float16: + max_type = np.float32 + elif isinstance(v, (int, float)): + length = 1 + max_type = np.float64 + elif isinstance(v, list): + length = len(v) + max_type = np.float64 + else: + raise ValueError(f"Graph data should be one of float, int, list, np.ndarray, got {type(v)}") + # The +1 is so that the cumulative sum below gives the beginning offsets + begin_offsets[i + 1] = max(begin_offsets[i + 1], length) + + begin_offsets = np.cumsum(begin_offsets) + full_num_rows = begin_offsets[-1] + + output = np.empty((full_num_rows, num_cols), dtype=max_type) + + # Now, fill in the values + for col_index, col in enumerate(label_cols): + for i, v in enumerate(df[col]): + full_row = begin_offsets[i] + end_row = begin_offsets[i + 1] + + if not isinstance(v, np.ndarray): + v = pd.to_numeric(v, errors="coerce") + + if isinstance(v, np.ndarray): + length = v.shape[0] if len(v.shape) == 1 else 0 + for j in range(length): + output[full_row + j, col_index] = v[j] + + elif isinstance(v, (int, float)): + length = 1 + output[full_row, col_index] = v + + elif isinstance(v, list): + length = len(v) + for j in range(length): + output[full_row + j, col_index] = v[j] + + else: + raise ValueError(f"Graph data should be one of float, int, list, np.ndarray, got {type(v)}") + + # Fill the rest of the rows in the column with nan + if full_row + length != end_row: + for row in range(full_row + length, end_row): + output[row, col_index] = np.nan + + return output, begin_offsets diff --git a/graphium/data/normalization.py b/graphium/data/normalization.py index 994e8939b..d2e8444b0 100644 --- a/graphium/data/normalization.py +++ b/graphium/data/normalization.py @@ -57,6 +57,12 @@ def __init__( self.data_mean = None self.data_std = None + def set_statistics(self, data_min, data_max, data_mean, data_std): + self.data_min = data_min + self.data_max = data_max + self.data_mean = data_mean + self.data_std = data_std + def calculate_statistics(self, array): """ Saves the normalization parameters (e.g. mean and variance) to the object. @@ -106,13 +112,11 @@ def denormalize(self, input): return input elif self.method == "normal": mean, std = torch.tensor(self.data_mean), torch.tensor(self.data_std) - if input.device.type != "ipu": # Cast to device if not on IPU - mean, std = mean.to(input.device), std.to(input.device) + mean, std = mean.to(input.device), std.to(input.device) return (input * std) + mean elif self.method == "unit": dmax, dmin = torch.tensor(self.data_max), torch.tensor(self.data_min) - if input.device.type != "ipu": # Cast to device if not on IPU - dmax, dmin = dmax.to(input.device), dmin.to(input.device) + dmax, dmin = dmax.to(input.device), dmin.to(input.device) return input * (dmax - dmin) + dmin else: raise ValueError(f"normalization method {self.method} not recognised.") diff --git a/graphium/data/utils.py b/graphium/data/utils.py index aa5151a90..5136ce60e 100644 --- a/graphium/data/utils.py +++ b/graphium/data/utils.py @@ -25,7 +25,6 @@ import graphium from torch_geometric.data import Data -from graphium.features.featurizer import GraphDict GRAPHIUM_DATASETS_BASE_URL = "gs://graphium-public/datasets" GRAPHIUM_DATASETS = { @@ -129,7 +128,7 @@ def get_keys(pyg_data): return pyg_data.keys() -def found_size_mismatch(task: str, features: Union[Data, GraphDict], labels: np.ndarray, smiles: str) -> bool: +def found_size_mismatch(task: str, features: Data, labels: np.ndarray, smiles: str) -> bool: """Check if a size mismatch exists between features and labels with respect to node/edge/nodepair. Args: diff --git a/graphium/features/README.md b/graphium/features/README.md index 4188948fe..14b123106 100644 --- a/graphium/features/README.md +++ b/graphium/features/README.md @@ -7,8 +7,5 @@ ## What is in this folder? - ✅ `featurizer.py`: featurization code for the molecules, adding node, edge and graph features to the mol object -- `nmp.py`: check if a string can be converted to float, helper function for featurization -- `positional_encoding.py`: code for computing all raw positional and structural encoding of the graph, see `graph_positional_encoder` function -- `properties.py`: code for computing properties of the molecule -- `rw.py`: code for computing random walk positional encoding -- `spectral.py`: code for computing the spectral positional encoding such as the Laplacian eigenvalues and eigenvectors \ No newline at end of file + +Positional encodings, and atom/bond features (`nmp.py`) have been moved to the `/graphium_cpp` folder. \ No newline at end of file diff --git a/graphium/features/__init__.py b/graphium/features/__init__.py index 40984a2a4..e9cb41d1f 100644 --- a/graphium/features/__init__.py +++ b/graphium/features/__init__.py @@ -1,9 +1,2 @@ -from .featurizer import get_mol_atomic_features_onehot -from .featurizer import get_mol_atomic_features_float -from .featurizer import get_mol_edge_features -from .featurizer import mol_to_adj_and_features -from .featurizer import mol_to_graph_dict from .featurizer import mol_to_graph_signature -from .featurizer import GraphDict from .featurizer import mol_to_pyggraph -from .featurizer import to_dense_array diff --git a/graphium/features/commute.py b/graphium/features/commute.py deleted file mode 100644 index a7cea768c..000000000 --- a/graphium/features/commute.py +++ /dev/null @@ -1,69 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, Dict, Any - -import numpy as np - -from scipy.sparse import spmatrix, issparse -from scipy.linalg import pinv - - -def compute_commute_distances( - adj: Union[np.ndarray, spmatrix], num_nodes: int, cache: Dict[str, Any] -) -> Tuple[np.ndarray, str, Dict[str, Any]]: - """ - Compute avg. commute time/distance between nodepairs. This is the avg. number of steps a random walker, starting - at node i, will take before reaching a given node j for the first time, and then return to node i. - - Reference: Saerens et al. "The principal components analysis of a graph, and its relationships to spectral clustering." ECML. 2004. - - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix - num_nodes: Number of nodes in the graph - cache: Dictionary of cached objects - Returns: - dist [num_nodes, num_nodes]: 2D array with avg. commute distances between nodepairs - base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here nodepair - cache: Updated dictionary of cached objects - """ - - base_level = "nodepair" - - if "commute" in cache: - dist = cache["commute"] - - else: - if issparse(adj): - adj = adj.toarray() - - volG = adj.sum() - - if "pinvL" in cache: - pinvL = cache["pinvL"] - - else: - L = np.diagflat(np.sum(adj, axis=1)) - adj - pinvL = pinv(L) - cache["pinvL"] = pinvL - - dist = volG * np.asarray( - [ - [pinvL[i, i] + pinvL[j, j] - 2 * pinvL[i, j] for j in range(num_nodes)] - for i in range(num_nodes) - ] - ) - cache["commute"] = dist - - return dist, base_level, cache diff --git a/graphium/features/electrostatic.py b/graphium/features/electrostatic.py deleted file mode 100644 index 58dc115f7..000000000 --- a/graphium/features/electrostatic.py +++ /dev/null @@ -1,58 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, Dict, Any - -import numpy as np - -from scipy.linalg import pinv -from scipy.sparse import spmatrix, issparse - - -def compute_electrostatic_interactions( - adj: Union[np.ndarray, spmatrix], cache: Dict[str, Any] -) -> Tuple[np.ndarray, str, Dict[str, Any]]: - """ - Compute electrostatic interaction of nodepairs. - - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix - cache: Dictionary of cached objects - Returns: - electrostatic [num_nodes, num_nodes]: 2D array with electrostatic interactions of node nodepairs - base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here nodepair - cache: Updated dictionary of cached objects - """ - - base_level = "nodepair" - - if "electrostatic" in cache: - electrostatic = cache["electrostatic"] - - else: - if "pinvL" in cache: - pinvL = cache["pinvL"] - - else: - if issparse(adj): - adj = adj.toarray() - - L = np.diagflat(np.sum(adj, axis=1)) - adj - pinvL = pinv(L) - cache["pinvL"] = pinvL - - electrostatic = pinvL - np.diag(pinvL) # This means that the "ground" is set to any given atom - cache["electrostatic"] = electrostatic - - return electrostatic, base_level, cache diff --git a/graphium/features/featurizer.py b/graphium/features/featurizer.py index 8d8e18159..21d874de1 100644 --- a/graphium/features/featurizer.py +++ b/graphium/features/featurizer.py @@ -1,12 +1,12 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ @@ -23,966 +23,27 @@ from torch_geometric.data import Data -from rdkit import Chem -import datamol as dm +import graphium_cpp -from graphium.features import nmp -from graphium.utils.tensor import one_of_k_encoding -from graphium.features.positional_encoding import get_all_positional_encodings +# These are the integers that correspond with the torch data types in C++ +NP_DTYPE_TO_TORCH_INT = {np.float16: 5, np.float32: 6, np.float64: 7} -def to_dense_array(array: np.ndarray, dtype: str = None) -> np.ndarray: - r""" - Assign the node data - Parameters: - array: The array to convert to dense - dtype: The dtype of the array - Returns: - The dense array - """ - if array is not None: - if issparse(array): - if array.dtype == np.float16: # float16 doesn't support `todense` - array = array.astype(np.float32) - array = array.todense() - - if dtype is not None: - array = array.astype(dtype) - return array - - -def to_dense_tensor(tensor: Tensor, dtype: str = None) -> Tensor: - r""" - Assign the node data - Parameters: - array: The array to convert to dense - dtype: The dtype of the array - Returns: - The dense array - """ - if tensor is not None: - if tensor.is_sparse: - tensor = tensor.todense() - if dtype is not None: - tensor = tensor.to(dtype) - return tensor - - -def _mask_nans_inf(mask_nan: Optional[str], array: np.ndarray, array_name: str) -> np.ndarray: - r""" - mask the NaNs in the array - Parameters: - mask_nan: How to mask the NaNs - array: The array to mask - array_name: The name of the array - Returns: - The masked array - """ - if (mask_nan is None) or (array is None): - return array - - new_array = array - if issparse(new_array): - new_array = new_array.data - nans = ~np.isfinite(new_array) - - # Mask the NaNs - if nans.any(): - msg = f"There are {np.sum(nans)} NaNs in `{array_name}`" - if mask_nan == "raise": - raise ValueError(msg) - elif mask_nan == "warn": - logger.warning(msg) - else: - new_array[nans] = mask_nan - if issparse(array): - array.data = new_array - new_array = array - return new_array - - -def get_mol_atomic_features_onehot(mol: dm.Mol, property_list: List[str]) -> Dict[str, Tensor]: - r""" - Get the following set of features for any given atom - - * One-hot representation of the atom - * One-hot representation of the atom degree - * One-hot representation of the atom implicit valence - * One-hot representation of the the atom hybridization - * Whether the atom is aromatic - * The atom's formal charge - * The atom's number of radical electrons - - Additionally, the following features can be set, depending on the value of input Parameters - - * One-hot representation of the number of hydrogen atom in the the current atom neighborhood if `explicit_H` is false - * One-hot encoding of the atom chirality, and whether such configuration is even possible - - Parameters: - - mol: - molecule from which to extract the properties - - property_list: - A list of integer atomic properties to get from the molecule. - The integer values are converted to a one-hot vector. - Callables are not supported by this function. - - Accepted properties are: - - - "atomic-number" - - "degree" - - "valence", "total-valence" - - "implicit-valence" - - "hybridization" - - "chirality" - - "phase" - - "type" - - "group" - - "period" - - Returns: - prop_dict: - A dictionnary where the element of ``property_list`` are the keys - and the values are np.ndarray of shape (N, OH). N is the number of atoms - in ``mol`` and OH the lenght of the one-hot encoding. - - """ - - prop_dict = {} - - for prop in property_list: - prop = prop.lower() - prop_name = prop - - property_array = [] - for ii, atom in enumerate(mol.GetAtoms()): - if prop in ["atomic-number"]: - one_hot = one_of_k_encoding(atom.GetSymbol(), nmp.ATOM_LIST) - elif prop in ["degree"]: - one_hot = one_of_k_encoding(atom.GetDegree(), nmp.ATOM_DEGREE_LIST) - elif prop in ["valence", "total-valence"]: - prop_name = "valence" - one_hot = one_of_k_encoding(atom.GetTotalValence(), nmp.VALENCE) - elif prop in ["implicit-valence"]: - one_hot = one_of_k_encoding(atom.GetImplicitValence(), nmp.VALENCE) - elif prop in ["hybridization"]: - one_hot = one_of_k_encoding(atom.GetHybridization(), nmp.HYBRIDIZATION_LIST) - elif prop in ["chirality"]: - try: - one_hot = one_of_k_encoding(atom.GetProp("_CIPCode"), nmp.CHIRALITY_LIST) - one_hot.append(int(atom.HasProp("_ChiralityPossible"))) - except: - one_hot = [0, 0, int(atom.HasProp("_ChiralityPossible"))] - elif prop in "phase": - one_hot = one_of_k_encoding(nmp.PHASE[atom.GetAtomicNum() - 1], nmp.PHASE_SET) - elif prop in "type": - one_hot = one_of_k_encoding(nmp.TYPE[atom.GetAtomicNum() - 1], nmp.TYPE_SET) - elif prop in "group": - one_hot = one_of_k_encoding(nmp.GROUP[atom.GetAtomicNum() - 1], nmp.GROUP_SET) - elif prop in "period": - one_hot = one_of_k_encoding(nmp.PERIOD[atom.GetAtomicNum() - 1], nmp.PERIOD_SET) - else: - raise ValueError(f"Unsupported property `{prop}`") - - property_array.append(np.asarray(one_hot, dtype=np.float16)) - - prop_dict[prop_name] = np.stack(property_array, axis=0) - - return prop_dict - - -def get_mol_conformer_features( - mol: dm.Mol, - property_list: Union[List[str], List[Callable]], - mask_nan: Optional[Union[float, str]] = None, -) -> Dict[str, np.ndarray]: - r"""obtain the conformer features of a molecule - Parameters: - - mol: - molecule from which to extract the properties - - property_list: - A list of conformer property to get from the molecule - Accepted properties are: - - "positions_3d" - - Returns: - prop_dict: a dictionary where the element of ``property_list`` are the keys - """ - prop_dict = {} - has_conf = True - - try: - mol.GetConformer() - except: - has_conf = False - # * currently only accepts "positions_3d", raise errors otherwise - for prop in property_list: - if isinstance(prop, str): - if prop in ["positions_3d"]: # locating 3d conformer coordinates - if not has_conf: - positions = np.full((mol.GetNumAtoms(), 3), float("nan"), dtype=np.float16) - else: - positions = [[], [], []] - for i in range(mol.GetNumAtoms()): - pos = mol.GetConformer().GetAtomPosition(i) - positions[0].append(pos.x) - positions[1].append(pos.y) - positions[2].append(pos.z) - positions = np.asarray(positions, dtype=np.float16).T - prop_dict[prop] = positions - else: - raise ValueError( - str(prop) + " is not currently supported as a conformer property in `property_list`" - ) - else: - raise ValueError(f"Elements in `property_list` must be str or callable, provided `{type(prop)}`") - - prop_dict[prop] = _mask_nans_inf(mask_nan, prop_dict[prop], prop) - - return prop_dict - - -def get_mol_atomic_features_float( - mol: dm.Mol, - property_list: Union[List[str], List[Callable]], - offset_carbon: bool = True, - mask_nan: Union[str, float, type(None)] = "raise", -) -> Dict[str, np.ndarray]: - r""" - Get a dictionary of floating-point arrays of atomic properties. - To ensure all properties are at a similar scale, some of the properties - are divided by a constant. - - There is also the possibility of offseting by the carbon value using - the `offset_carbon` parameter. - - Parameters: - - mol: - molecule from which to extract the properties - - property_list: - A list of atomic properties to get from the molecule, such as 'atomic-number', - 'mass', 'valence', 'degree', 'electronegativity'. - Some elements are divided by a factor to avoid feature explosion. - - Accepted properties are: - - - "atomic-number" - - "mass", "weight" - - "valence", "total-valence" - - "implicit-valence" - - "hybridization" - - "chirality" - - "hybridization" - - "aromatic" - - "ring", "in-ring" - - "min-ring" - - "max-ring" - - "num-ring" - - "degree" - - "radical-electron" - - "formal-charge" - - "vdw-radius" - - "covalent-radius" - - "electronegativity" - - "ionization", "first-ionization" - - "melting-point" - - "metal" - - "single-bond" - - "aromatic-bond" - - "double-bond" - - "triple-bond" - - "is-carbon" - - "group" - - "period" - - offset_carbon: - Whether to subract the Carbon property from the desired atomic property. - For example, if we want the mass of the Lithium (6.941), the mass of the - Carbon (12.0107) will be subracted, resulting in a value of -5.0697 - - mask_nan: - Deal with molecules that fail a part of the featurization. - NaNs can happen when taking the of a noble gas, - or other properties that are not measured for specific atoms. - - - "raise": Raise an error when there is a nan or inf in the featurization - - "warn": Raise a warning when there is a nan or inf in the featurization - - "None": DEFAULT. Don't do anything - - "Floating value": Replace nans or inf by the specified value - - Returns: - - prop_dict: - A dictionnary where the element of ``property_list`` are the keys - and the values are np.ndarray of shape (N,). N is the number of atoms - in ``mol``. - - """ - - periodic_table = Chem.GetPeriodicTable() - prop_dict = {} - C = Chem.Atom("C") - C_num = C.GetAtomicNum() - offC = bool(offset_carbon) - atom_list = list(mol.GetAtoms()) - - for prop in property_list: - prop_name = None - - property_array = np.zeros(mol.GetNumAtoms(), dtype=np.float16) - for ii, atom in enumerate(atom_list): - val = None - atomic_num = atom.GetAtomicNum() - - if isinstance(prop, str): - prop = prop.lower() - prop_name = prop - - if prop in ["atomic-number"]: - val = (atomic_num - (offC * C_num)) / 5 - elif prop in ["mass", "weight"]: - prop_name = "mass" - val = (atom.GetMass() - (offC * C.GetMass())) / 10 - elif prop in ["valence", "total-valence"]: - prop_name = "valence" - val = atom.GetTotalValence() - (offC * 4) - elif prop in ["implicit-valence"]: - val = atom.GetImplicitValence() - elif prop in ["hybridization"]: - val = atom.GetHybridization() - elif prop in ["chirality"]: - val = (atom.GetProp("_CIPCode") == "R") if atom.HasProp("_CIPCode") else 2 - elif prop in ["hybridization"]: - val = atom.GetHybridization() - elif prop in ["aromatic"]: - val = atom.GetIsAromatic() - elif prop in ["ring", "in-ring"]: - prop_name = "in-ring" - val = atom.IsInRing() - elif prop in ["min-ring"]: - ring_info = mol.GetRingInfo() - val = ring_info.MinAtomRingSize(atom.GetIdx()) - elif prop in ["max-ring"]: - rings = mol.GetRingInfo().AtomRings() - val = 0 - for ring in rings: - if atom.GetIdx() in ring: - if len(ring) > val: - val = len(ring) - elif prop in ["num-ring"]: - ring_info = mol.GetRingInfo() - val = ring_info.NumAtomRings(atom.GetIdx()) - elif prop in ["degree"]: - val = atom.GetTotalDegree() - (offC * 2) - elif prop in ["radical-electron"]: - val = atom.GetNumRadicalElectrons() - elif prop in ["formal-charge"]: - val = atom.GetFormalCharge() - elif prop in ["vdw-radius"]: - val = periodic_table.GetRvdw(atom.GetAtomicNum()) - offC * periodic_table.GetRvdw(C_num) - elif prop in ["covalent-radius"]: - val = periodic_table.GetRcovalent(atomic_num) - offC * periodic_table.GetRcovalent(C_num) - elif prop in ["electronegativity"]: - val = ( - nmp.ELECTRONEGATIVITY[atom.GetAtomicNum() - 1] - - offC * nmp.ELECTRONEGATIVITY[C_num - 1] - ) - elif prop in ["ionization", "first-ionization"]: - prop_name = "ionization" - val = (nmp.FIRST_IONIZATION[atomic_num - 1] - offC * nmp.FIRST_IONIZATION[C_num - 1]) / 5 - elif prop in ["melting-point"]: - val = (nmp.MELTING_POINT[atomic_num - 1] - offC * nmp.MELTING_POINT[C_num - 1]) / 200 - elif prop in ["metal"]: - val = nmp.METAL[atomic_num - 1] - elif prop in "group": - val = float(nmp.GROUP[atomic_num - 1]) - offC * float(nmp.GROUP[C_num - 1]) - elif prop in "period": - val = float(nmp.PERIOD[atomic_num - 1]) - offC * float(nmp.PERIOD[C_num - 1]) - elif "-bond" in prop: - bonds = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()] - if prop in ["single-bond"]: - val = len([bond == 1 for bond in bonds]) - elif prop in ["aromatic-bond"]: - val = len([bond == 1.5 for bond in bonds]) - elif prop in ["double-bond"]: - val = len([bond == 2 for bond in bonds]) - elif prop in ["triple-bond"]: - val = len([bond == 3 for bond in bonds]) - else: - raise ValueError(f"{prop} is not a correct bond.") - val -= offC * 1 - elif prop in ["is-carbon"]: - val = atom.GetAtomicNum() == 6 - val -= offC * 1 - else: - raise ValueError(f"Unsupported property `{prop}`") - - elif callable(prop): - prop_name = str(prop) - val = prop(atom) - else: - ValueError(f"Elements in `property_list` must be str or callable, provided `{type(prop)}`") - - if val is None: - raise ValueError("val is undefined.") - - property_array[ii] = val - - if prop_name is None: - raise ValueError("prop_name is undefined.") - - # Mask the NaNs - prop_dict[prop_name] = _mask_nans_inf(mask_nan, property_array, "atom featurization") - - return prop_dict - - -def get_simple_mol_conformer(mol: dm.Mol) -> Union[Chem.rdchem.Conformer, None]: - r""" - If the molecule has a conformer, then it will return the conformer at idx `0`. - Otherwise, it generates a simple molecule conformer using `rdkit.Chem.rdDistGeom.EmbedMolecule` - and returns it. This is meant to be used in simple functions like `GetBondLength`, - not in functions requiring complex 3D structure. - - Parameters: - - mol: Rdkit Molecule - - Returns: - conf: A conformer of the molecule, or `None` if it fails - """ - - val = 0 - if mol.GetNumConformers() == 0: - val = Chem.rdDistGeom.EmbedMolecule(mol) - if val == -1: - val = Chem.rdDistGeom.EmbedMolecule( - mol, - enforceChirality=False, - ignoreSmoothingFailures=True, - useBasicKnowledge=True, - useExpTorsionAnglePrefs=True, - forceTol=0.1, - ) - - if val == -1: - conf = None - logger.warn("Couldn't compute conformer for molecule `{}`".format(Chem.MolToSmiles(mol))) - else: - conf = mol.GetConformer(0) - - return conf - - -def get_estimated_bond_length(bond: Chem.rdchem.Bond, mol: dm.Mol) -> float: - r""" - Estimate the bond length between atoms by looking at the estimated atomic radius - that depends both on the atom type and the bond type. The resulting bond-length is - then the sum of the radius. - - Keep in mind that this function only provides an estimate of the bond length and not - the true one based on a conformer. The vast majority od estimated bond lengths will - have an error below 5% while some bonds can have an error up to 20%. This function - is mostly useful when conformer generation fails for some molecules, or for - increased computation speed. - - Parameters: - bond: The bond to measure its lenght - mol: The molecule containing the bond (used to get neighbouring atoms) - - Returns: - bond_length: The bond length in Angstrom, typically a value around 1-2. - - """ - - # Get the atoms connected by the bond - idx1 = bond.GetBeginAtomIdx() - idx2 = bond.GetEndAtomIdx() - atom1 = mol.GetAtomWithIdx(idx1).GetAtomicNum() - atom2 = mol.GetAtomWithIdx(idx2).GetAtomicNum() - bond_type = bond.GetBondType() - - # Get single bond atomic radius - if bond_type == Chem.rdchem.BondType.SINGLE: - rad1 = [nmp.BOND_RADIUS_SINGLE[atom1 - 1]] - rad2 = [nmp.BOND_RADIUS_SINGLE[atom2 - 1]] - # Get double bond atomic radius - elif bond_type == Chem.rdchem.BondType.DOUBLE: - rad1 = [nmp.BOND_RADIUS_DOUBLE[atom1 - 1]] - rad2 = [nmp.BOND_RADIUS_DOUBLE[atom2 - 1]] - # Get triple bond atomic radius - elif bond_type == Chem.rdchem.BondType.TRIPLE: - rad1 = [nmp.BOND_RADIUS_TRIPLE[atom1 - 1]] - rad2 = [nmp.BOND_RADIUS_TRIPLE[atom2 - 1]] - # Get average of single bond and double bond atomic radius - elif bond_type == Chem.rdchem.BondType.AROMATIC: - rad1 = [nmp.BOND_RADIUS_SINGLE[atom1 - 1], nmp.BOND_RADIUS_DOUBLE[atom1 - 1]] - rad2 = [nmp.BOND_RADIUS_SINGLE[atom2 - 1], nmp.BOND_RADIUS_DOUBLE[atom2 - 1]] - - # Average the bond lengths, while ignoring nans in case some missing value - rad1_float = [elem for elem in rad1 if elem is not None] - rad2_float = [elem for elem in rad2 if elem is not None] - - if len(rad1_float) > 0: - rad1_float = sum(rad1_float) / len(rad1_float) - else: - rad1_float = float(nmp.BOND_RADIUS_SINGLE[atom1 - 1]) - - if len(rad2_float) > 0: - rad2_float = sum(rad2_float) / len(rad2_float) - else: - rad2_float = float(nmp.BOND_RADIUS_SINGLE[atom2 - 1]) - - bond_length = rad1_float + rad2_float - return bond_length - - -def get_mol_edge_features( - mol: dm.Mol, property_list: List[str], mask_nan: Union[str, float, type(None)] = "raise" -) -> Dict[str, np.ndarray]: - r""" - Get the following set of features for any given bond - See `graphium.features.nmp` for allowed values in one hot encoding - - * One-hot representation of the bond type. Note that you should not kekulize your - molecules, if you expect this to take aromatic bond into account. - * Bond stereo type, following CIP classification - * Whether the bond is conjugated - * Whether the bond is in a ring - - Parameters: - mol: rdkit.Chem.Molecule - the molecule of interest - - property_list: - A list of edge properties to return for the given molecule. - Accepted properties are: - - - "bond-type-onehot" - - "bond-type-float" - - "stereo" - - "in-ring" - - "conjugated" - - "conformer-bond-length" (might cause problems with complex molecules) - - "estimated-bond-length" - - Returns: - prop_dict: - A dictionnary where the element of ``property_list`` are the keys - and the values are np.ndarray of shape (N,). N is the number of atoms - in ``mol``. - - """ - - prop_dict = {} - - # Compute features for each bond - num_bonds = mol.GetNumBonds() - for prop in property_list: - property_array = [] - for ii in range(num_bonds): - prop = prop.lower() - bond = mol.GetBondWithIdx(ii) - - if prop in ["bond-type-onehot"]: - encoding = one_of_k_encoding(bond.GetBondType(), nmp.BOND_TYPES) - elif prop in ["bond-type-float"]: - encoding = [bond.GetBondTypeAsDouble()] - elif prop in ["stereo"]: - encoding = one_of_k_encoding(bond.GetStereo(), nmp.BOND_STEREO) - elif prop in ["in-ring"]: - encoding = [bond.IsInRing()] - elif prop in ["conjugated"]: - encoding = [bond.GetIsConjugated()] - elif prop in ["conformer-bond-length"]: - conf = get_simple_mol_conformer(mol) - if conf is not None: - idx1 = bond.GetBeginAtomIdx() - idx2 = bond.GetEndAtomIdx() - encoding = [Chem.rdMolTransforms.GetBondLength(conf, idx1, idx2)] - else: - encoding = [0] - elif prop in ["estimated-bond-length"]: - encoding = [get_estimated_bond_length(bond, mol)] - - else: - raise ValueError(f"Unsupported property `{prop}`") - - property_array.append(np.asarray(encoding, dtype=np.float16)) - - if num_bonds > 0: - property_array = np.stack(property_array, axis=0) - # Mask the NaNs - prop_dict[prop] = _mask_nans_inf(mask_nan, property_array, "edge property") - else: - # Add an empty vector with the right shape - arr_len = 1 - if prop in ["bond-type-onehot"]: - arr_len = len(nmp.BOND_TYPES) + 1 - elif prop in ["stereo"]: - arr_len = len(nmp.BOND_STEREO) + 1 - - prop_dict[prop] = np.zeros((0, arr_len)) - - return prop_dict - - -def mol_to_adj_and_features( - mol: Union[str, dm.Mol], - atom_property_list_onehot: List[str] = [], - atom_property_list_float: List[Union[str, Callable]] = [], - conformer_property_list: List[str] = [], - edge_property_list: List[str] = [], - add_self_loop: bool = False, - explicit_H: bool = False, - use_bonds_weights: bool = False, - pos_encoding_as_features: Dict[str, Any] = None, - dtype: np.dtype = np.float16, - mask_nan: Union[str, float, type(None)] = "raise", -) -> Union[ - coo_matrix, - Union[Tensor, None], - Union[Tensor, None], - Dict[str, Tensor], - Union[Tensor, None], - Dict[str, Tensor], -]: - r""" - Transforms a molecule into an adjacency matrix representing the molecular graph - and a set of atom and bond features. - - It also returns the positional encodings associated to the graph. - - Parameters: - - mol: - The molecule to be converted - - atom_property_list_onehot: - List of the properties used to get one-hot encoding of the atom type, - such as the atom index represented as a one-hot vector. - See function `get_mol_atomic_features_onehot` - - atom_property_list_float: - List of the properties used to get floating-point encoding of the atom type, - such as the atomic mass or electronegativity. - See function `get_mol_atomic_features_float` - - conformer_property_list: - list of properties used to encode the conformer information, outside of atom properties, currently support "positions_3d" - - edge_property_list: - List of the properties used to encode the edges, such as the edge type - and the stereo type. - - add_self_loop: - Whether to add a value of `1` on the diagonal of the adjacency matrix. - - explicit_H: - Whether to consider the Hydrogens explicitely. If `False`, the hydrogens - are implicit. - - use_bonds_weights: - Whether to use the floating-point value of the bonds in the adjacency matrix, - such that single bonds are represented by 1, double bonds 2, triple 3, aromatic 1.5 - - pos_encoding_as_features: keyword arguments for function `graph_positional_encoder` - to generate positional encoding for node features. - - dtype: - The torch data type used to build the graph - - mask_nan: - Deal with molecules that fail a part of the featurization. - NaNs can happen when taking the of a noble gas, - or other properties that are not measured for specific atoms. - - - "raise": Raise an error when there is a nan or inf in the featurization - - "warn": Raise a warning when there is a nan or inf in the featurization - - "None": DEFAULT. Don't do anything - - "Floating value": Replace nans or inf by the specified value - Returns: - - adj: - torch coo sparse adjacency matrix of the molecule - - ndata: - Concatenated node data of the atoms, based on the properties from - `atom_property_list_onehot` and `atom_property_list_float`. - If no properties are given, it returns `None` - - edata: - Concatenated node edge of the molecule, based on the properties from - `edge_property_list`. - If no properties are given, it returns `None` - - pe_dict: - Dictionary of all positional encodings. Current supported keys: - - - "pos_enc_feats_sign_flip": - Node positional encoding that requires augmentation via sign-flip. - For example, eigenvectors of the Laplacian are ambiguous to the - sign and are returned here. - - - "pos_enc_feats_no_flip": - Node positional encoding that requires does not use sign-flip. - For example, distance from centroid are returned here. - - - "rwse": - Node structural encoding corresponding to the diagonal of the random - walk matrix - - conf_dict: - contains the 3d positions of a conformer of the molecule or 0s if none is found - - """ - - if isinstance(mol, str): - mol = dm.to_mol(mol, ordered=True) - - # Add or remove explicit hydrogens - if explicit_H: - mol = Chem.AddHs(mol) - else: - mol = Chem.RemoveHs(mol) - - num_nodes = mol.GetNumAtoms() - - adj = mol_to_adjacency_matrix( - mol, use_bonds_weights=use_bonds_weights, add_self_loop=add_self_loop, dtype=dtype - ) - - # Get the node features - atom_features_onehot = get_mol_atomic_features_onehot(mol, atom_property_list_onehot) - atom_features_float = get_mol_atomic_features_float(mol, atom_property_list_float, mask_nan=mask_nan) - conf_dict = get_mol_conformer_features(mol, conformer_property_list, mask_nan=mask_nan) - ndata = list(atom_features_float.values()) + list(atom_features_onehot.values()) - ndata = [d[:, np.newaxis] if d.ndim == 1 else d for d in ndata] - - if len(ndata) > 0: - ndata = np.concatenate(ndata, axis=1).astype(dtype=dtype) - else: - ndata = None - - # Get the edge features - edge_features = get_mol_edge_features(mol, edge_property_list, mask_nan=mask_nan) - edata = list(edge_features.values()) - edata = [np.expand_dims(d, axis=1) if d.ndim == 1 else d for d in edata] - if len(edata) > 0: - edata = np.concatenate(edata, axis=1).astype(dtype=dtype) - else: - edata = None - - # Get all positional encodings - pe_dict = get_all_positional_encodings(adj, num_nodes, pos_encoding_as_features) - - # Mask the NaNs - for pe_key, pe_val in pe_dict.items(): - pe_val = np.asarray(pe_val, dtype=dtype) - pe_dict[pe_key] = _mask_nans_inf(mask_nan, pe_val, pe_key) - - return adj, ndata, edata, pe_dict, conf_dict - - -def mol_to_adjacency_matrix( - mol: dm.Mol, - use_bonds_weights: bool = False, - add_self_loop: bool = False, - dtype: np.dtype = np.float32, -) -> coo_matrix: - r""" - Convert a molecule to a sparse adjacency matrix, as a torch Tensor. - Instead of using the Rdkit `GetAdjacencyMatrix()` method, this method - uses the bond ordering from the molecule object, which is the same as - the bond ordering in the bond features. - - Warning: - Do not use `Tensor.coalesce()` on the returned adjacency matrix, as it - will change the ordering of the bonds. - - Args: - mol: A molecule in the form of a SMILES string or an RDKit molecule object. - - use_bonds_weights: - If `True`, the adjacency matrix will contain the bond type as the - value of the edge. If `False`, the adjacency matrix will contain - `1` as the value of the edge. - - add_self_loop: - If `True`, the adjacency matrix will contain a self-loop for each - node. - - dtype: - The data type used to build the graph - - Returns: - adj: - coo sparse adjacency matrix of the molecule - """ - - # Get the indices for the adjacency matrix, and the bond value - adj_idx, adj_val = [], [] - for bond in mol.GetBonds(): - adj_idx.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) - adj_idx.append([bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]) - if use_bonds_weights: - val = nmp.BOND_TYPES[bond.GetBondType()] - else: - val = 1 - adj_val.extend([val, val]) - - # Convert to torch coo sparse tensor - if len(adj_val) > 0: # ensure tensor is not empty - adj = coo_matrix( - (torch.as_tensor(adj_val), torch.as_tensor(adj_idx).T.reshape(2, -1)), - shape=(mol.GetNumAtoms(), mol.GetNumAtoms()), - dtype=dtype, - ) - else: - # Special case for molecules with one atom - adj = coo_matrix(([], np.array([[], []])), shape=(mol.GetNumAtoms(), mol.GetNumAtoms()), dtype=dtype) - - # Add self loops - if add_self_loop: - arange = np.arange(adj.shape[0], dtype=int) - adj[arange, arange] = 1 - return adj - - -class GraphDict(dict): - def __init__( - self, - dic: Dict, - ): - r""" - Store the parameters required to initialize a `pyg.data.Data`, but - as a dictionary to reduce memory consumption. - - Possible keys for the dictionary: - - - adj: A sparse Tensor containing the adjacency matrix - - - ndata: A dictionnary containing different keys and Tensors - associated to the node features. - - - edata: A dictionnary containing different keys and Tensors - associated to the edge features. - - - dtype: The dtype for the floating data. - - - mask_nan: - Deal with molecules that fail a part of the featurization. - NaNs can happen when taking the of a noble gas, - or other properties that are not measured for specific atoms. - - - "raise": Raise an error when there is a nan or inf in the featurization - - "warn": Raise a warning when there is a nan or inf in the featurization - - "None": DEFAULT. Don't do anything - - "Floating value": Replace nans or inf by the specified value - """ - default_dic = { - "dtype": np.float16, - "mask_nan": "raise", - } - data = dic.pop("data", {}) - # ndata = dic.pop("ndata", {}) - # edata = dic.pop("edata", {}) - # for key in edata.keys(): - # assert key.startswith("edge_"), f"Edge features must start with 'edge_' but got {key}" - default_dic.update(dic) - default_dic.update(data) - # default_dic.update(ndata) - # default_dic.update(edata) - super().__init__(default_dic) - - @property - def keys(self): - return list(super().keys()) - - @property - def values(self): - return list(super().self.values()) - - def make_pyg_graph(self, **kwargs) -> Data: - """ - Convert the current dictionary of parameters, containing an adjacency matrix with node/edge data - into a `pyg.data.Data` of torch Tensors. - - `**kwargs` can be used to overwrite any parameter from the current dictionary. See `GraphDict.__init__` - for a list of parameters - """ - - num_nodes = self.adj.shape[0] - data_dict = {} - - # Convert the numpy and numpy sparse data to torch - for key, val in self.items(): - if key in ["adj", "dtype", "mask_nan"]: # Skip the parameters - continue - elif isinstance(val, np.ndarray): - # Convert the data to the specified dtype in torch format - val = val.astype(self.dtype) - data_dict[key] = torch.as_tensor(val) - elif issparse(val): - data_dict[key] = torch.as_tensor(val.astype(np.float32).todense()) - # `torch.sparse_coo_tensor` is too slow. Slows down the multiprocessing of features by >3x on 32 cores. - # indices = torch.from_numpy(np.vstack((val.row, val.col)).astype(np.int64)) - # data_dict[key] = torch.sparse_coo_tensor(indices=indices, values=val.data, size=val.shape) - elif isinstance(val, torch.Tensor): - data_dict[key] = val - else: - pass # Skip the other parameters - - # Create the PyG graph object `Data` - edge_index = torch.as_tensor(np.vstack((self.adj.row, self.adj.col))) - edge_weight = torch.as_tensor(self.adj.data) - data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes, **data_dict) - return data - - @property - def adj(self): - return self["adj"] - - @property - def dtype(self): - return self["dtype"] - - @property - def mask_nan(self): - return self["mask_nan"] - - @property - def num_nodes(self) -> int: - return self.adj.shape[0] - - @property - def num_edges(self) -> int: - if issparse(self.adj): - return self.adj.nnz - else: - return np.count_nonzero(self.adj) # No division by 2 because edges are counted twice - - -def mol_to_graph_dict( - mol: dm.Mol, - atom_property_list_onehot: List[str] = [], - atom_property_list_float: List[Union[str, Callable]] = [], +def mol_to_pyggraph( + mol: str, + atom_property_list_onehot: torch.Tensor = torch.tensor(data=[], dtype=torch.int64), + atom_property_list_float: torch.Tensor = torch.tensor(data=[], dtype=torch.int64), conformer_property_list: List[str] = [], - edge_property_list: List[str] = [], + edge_property_list: torch.Tensor = torch.tensor(data=[], dtype=torch.int64), add_self_loop: bool = False, explicit_H: bool = False, use_bonds_weights: bool = False, - pos_encoding_as_features: Dict[str, Any] = None, + pos_encoding_as_features: Tuple[List[str], torch.Tensor] = ([], torch.tensor(data=[], dtype=torch.int64)), dtype: np.dtype = np.float16, on_error: str = "ignore", mask_nan: Union[str, float, type(None)] = "raise", max_num_atoms: Optional[int] = None, -) -> Union[GraphDict, str]: +) -> Union[Data, str]: r""" Transforms a molecule into an adjacency matrix representing the molecular graph and a set of atom and bond features, and re-organizes them into a dictionary @@ -999,12 +60,10 @@ def mol_to_graph_dict( atom_property_list_onehot: List of the properties used to get one-hot encoding of the atom type, such as the atom index represented as a one-hot vector. - See function `get_mol_atomic_features_onehot` atom_property_list_float: List of the properties used to get floating-point encoding of the atom type, such as the atomic mass or electronegativity. - See function `get_mol_atomic_features_float` conformer_property_list: list of properties used to encode the conformer information, outside of atom properties, currently support "positions_3d" @@ -1068,191 +127,83 @@ def mol_to_graph_dict( - "dtype": The numpy dtype for the floating data. """ - input_mol = mol + if not isinstance(mol, str): + raise ValueError( + f"mol_to_pyggraph requires that molecule be received as a string, not type " + str(type(mol)) + ) + try: - if isinstance(mol, str): - mol = dm.to_mol(mol, ordered=True) - if explicit_H: - mol = Chem.AddHs(mol) + has_conformer = "positions_3d" in conformer_property_list + pe_index = 4 + if has_conformer: + pe_index = 5 + mask_nan_value = 0.0 + if mask_nan is None: + mask_nan_style_int = 0 + elif mask_nan == "raise" or mask_nan == "warn": + mask_nan_style_int = 1 else: - mol = Chem.RemoveHs(mol) - num_atoms = mol.GetNumAtoms() - if (max_num_atoms is not None) and (num_atoms > max_num_atoms): - raise ValueError(f"Maximum number of atoms greater than permitted {num_atoms}>{max_num_atoms}") - ( - adj, - ndata, - edata, - pe_dict, - conf_dict, - ) = mol_to_adj_and_features( - mol=mol, - atom_property_list_onehot=atom_property_list_onehot, - atom_property_list_float=atom_property_list_float, - conformer_property_list=conformer_property_list, - edge_property_list=edge_property_list, - add_self_loop=add_self_loop, - explicit_H=explicit_H, - use_bonds_weights=use_bonds_weights, - pos_encoding_as_features=pos_encoding_as_features, - mask_nan=mask_nan, + mask_nan_style_int = 2 + mask_nan_value = float(mask_nan) + tensors, num_nans, nan_tensor_index = graphium_cpp.featurize_smiles( + mol, + atom_property_list_onehot, + atom_property_list_float, + has_conformer, + edge_property_list, + pos_encoding_as_features[1], + True, # duplicate_edges, so that we don't have to duplicate below + add_self_loop, + explicit_H, + use_bonds_weights, + True, # offset_carbon + NP_DTYPE_TO_TORCH_INT[dtype], + mask_nan_style_int, + mask_nan_value, ) + + if num_nans > 0: + if nan_tensor_index == 2: + array_name = "atom featurization" + elif nan_tensor_index == 3: + array_name = "edge property" + elif nan_tensor_index == 4 and has_conformer: + array_name = "positions_3d" + else: + array_name = pos_encoding_as_features[0][nan_tensor_index - pe_index] + msg = f"There are {num_nans} NaNs in `{array_name}`" + if mask_nan == "raise": + raise ValueError(msg) + elif mask_nan == "warn": + logger.warning(msg) + + num_atoms = tensors[2].size(0) + data_dict = {"feat": tensors[2], "edge_feat": tensors[3]} + if has_conformer: + data_dict["positions_3d"] = tensors[4] + for i in range(len(tensors) - pe_index): + data_dict[pos_encoding_as_features[0][i]] = tensors[i + pe_index] + # Create the PyG graph object `Data` + data = Data(edge_index=tensors[0], edge_weight=tensors[1], num_nodes=num_atoms, **data_dict) + return data + except Exception as e: if on_error.lower() == "raise": raise e elif on_error.lower() == "warn": - smiles = input_mol - if isinstance(smiles, dm.Mol): - smiles = Chem.MolToSmiles(input_mol) - msg = str(e) + "\nIgnoring following molecule:" + smiles + msg = str(e) + "\nIgnoring following molecule:" + mol logger.warning(msg) return str(e) elif on_error.lower() == "ignore": return str(e) - - graph_dict = {"adj": adj, "data": {}, "dtype": dtype} - - # Assign the node data - if ndata is not None: - graph_dict["data"]["feat"] = ndata - - # Assign the edge data - if edata is not None: - if issparse(edata): - edata = to_dense_array(edata, dtype=dtype) - hetero_edata = edata.repeat(2, axis=0) - graph_dict["data"]["edge_feat"] = hetero_edata - - # Put the positional encodings as node features - # TODO: add support for PE on edges - for key, pe in pe_dict.items(): - graph_dict["data"][key] = pe - - # put the conformer positions here - for key, val in conf_dict.items(): - graph_dict["data"][key] = val - - graph_dict = GraphDict(graph_dict) - return graph_dict - - -def mol_to_pyggraph( - mol: dm.Mol, - atom_property_list_onehot: List[str] = [], - atom_property_list_float: List[Union[str, Callable]] = [], - conformer_property_list: List[str] = [], - edge_property_list: List[str] = [], - add_self_loop: bool = False, - explicit_H: bool = False, - use_bonds_weights: bool = False, - pos_encoding_as_features: Dict[str, Any] = None, - dtype: np.dtype = np.float16, - on_error: str = "ignore", - mask_nan: Union[str, float, type(None)] = "raise", - max_num_atoms: Optional[int] = None, -) -> Union[Data, str]: - r""" - Transforms a molecule into an adjacency matrix representing the molecular graph - and a set of atom and bond features. - - Then, the adjacency matrix and node/edge features are used to build a - `pyg.data.Data` with pytorch Tensors. - - Parameters: - - mol: - The molecule to be converted - - atom_property_list_onehot: - List of the properties used to get one-hot encoding of the atom type, - such as the atom index represented as a one-hot vector. - See function `get_mol_atomic_features_onehot` - - atom_property_list_float: - List of the properties used to get floating-point encoding of the atom type, - such as the atomic mass or electronegativity. - See function `get_mol_atomic_features_float` - - conformer_property_list: - list of properties used to encode the conformer information, outside of atom properties, currently support "positions_3d" - - edge_property_list: - List of the properties used to encode the edges, such as the edge type - and the stereo type. - - add_self_loop: - Whether to add a value of `1` on the diagonal of the adjacency matrix. - - explicit_H: - Whether to consider the Hydrogens explicitely. If `False`, the hydrogens - are implicit. - - use_bonds_weights: - Whether to use the floating-point value of the bonds in the adjacency matrix, - such that single bonds are represented by 1, double bonds 2, triple 3, aromatic 1.5 - - pos_encoding_as_features: keyword arguments for function `graph_positional_encoder` - to generate positional encoding for node features. - - dtype: - The numpy data type used to build the graph - - on_error: - What to do when the featurization fails. This can change the - behavior of `mask_nan`. - - - "raise": Raise an error - - "warn": Raise a warning and return a string of the error - - "ignore": Ignore the error and return a string of the error - - mask_nan: - Deal with molecules that fail a part of the featurization. - NaNs can happen when taking the of a noble gas, - or other properties that are not measured for specific atoms. - - - "raise": Raise an error when there is a nan in the featurization - - "warn": Raise a warning when there is a nan in the featurization - - "None": DEFAULT. Don't do anything - - "Floating value": Replace nans by the specified value - - max_num_atoms: - Maximum number of atoms for a given molecule. If a molecule with more atoms - is give, an error is raised, but catpured according to the rules of - `on_error`. - Returns: - - graph: - Pyg graph, with `graph['feat']` corresponding to the concatenated - node data from `atom_property_list_onehot` and `atom_property_list_float`, - `graph['edge_feat']` corresponding to the concatenated edge data from `edge_property_list`. - There are also additional entries for the positional encodings. - - """ - graph_dict = mol_to_graph_dict( - mol=mol, - atom_property_list_onehot=atom_property_list_onehot, - atom_property_list_float=atom_property_list_float, - conformer_property_list=conformer_property_list, - edge_property_list=edge_property_list, - add_self_loop=add_self_loop, - explicit_H=explicit_H, - use_bonds_weights=use_bonds_weights, - pos_encoding_as_features=pos_encoding_as_features, - dtype=dtype, - on_error=on_error, - mask_nan=mask_nan, - max_num_atoms=max_num_atoms, - ) - - if (graph_dict is not None) and not isinstance(graph_dict, str): - return graph_dict.make_pyg_graph() - else: - return graph_dict + else: + # Invalid on_error value, so default to raising an exception. + raise e def mol_to_graph_signature(featurizer_args: Dict[str, Any] = None) -> Dict[str, Any]: """ - Get the default arguments of `mol_to_graph_dict` and update it + Get the default arguments of `mol_to_pyggraph` and update it with a provided dict of arguments in order to get a fulle signature of the featurizer args actually used for the features computation. @@ -1262,8 +213,8 @@ def mol_to_graph_signature(featurizer_args: Dict[str, Any] = None) -> Dict[str, A dictionary of featurizer arguments """ - # Get the signature of `mol_to_graph_dict` - signature = inspect.signature(mol_to_graph_dict) + # Get the signature of `mol_to_pyggraph` + signature = inspect.signature(mol_to_pyggraph) # Filter out empty arguments (without default value) parameters = list(filter(lambda param: param.default is not param.empty, signature.parameters.values())) diff --git a/graphium/features/graphormer.py b/graphium/features/graphormer.py deleted file mode 100644 index d62010801..000000000 --- a/graphium/features/graphormer.py +++ /dev/null @@ -1,55 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, Dict, Any - -import numpy as np -import networkx as nx - -from scipy.sparse import spmatrix, issparse - - -def compute_graphormer_distances( - adj: Union[np.ndarray, spmatrix], num_nodes: int, cache: Dict[str, Any] -) -> Tuple[np.ndarray, str, Dict[str, Any]]: - """ - Compute Graphormer distance between nodepairs. - - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix - num_nodes: Number of nodes in the graph - cache: Dictionary of cached objects - Returns: - dist [num_nodes, num_nodes]: 2D array with Graphormer distances between nodepairs - base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here nodepair - cache: Updated dictionary of cached objects - """ - - base_level = "nodepair" - - if "graphormer" in cache: - dist = cache["graphormer"] - - else: - if issparse(adj): - adj = adj.toarray() - - G = nx.from_numpy_array(adj) - paths = nx.all_pairs_shortest_path(G) - - dist_dict = {i: {j: len(path) - 1 for j, path in paths_from_i.items()} for i, paths_from_i in paths} - dist = np.asarray([[dist_dict[i][j] for j in range(num_nodes)] for i in range(num_nodes)]) - cache["graphormer"] = dist - - return dist, base_level, cache diff --git a/graphium/features/positional_encoding.py b/graphium/features/positional_encoding.py deleted file mode 100644 index 8acc231d8..000000000 --- a/graphium/features/positional_encoding.py +++ /dev/null @@ -1,181 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, Optional, Dict, Any, OrderedDict -from copy import deepcopy -import numpy as np -import torch -from scipy.sparse import spmatrix -from collections import OrderedDict as OderedDictClass - -from graphium.features.spectral import compute_laplacian_pe -from graphium.features.rw import compute_rwse -from graphium.features.electrostatic import compute_electrostatic_interactions -from graphium.features.commute import compute_commute_distances -from graphium.features.graphormer import compute_graphormer_distances -from graphium.features.transfer_pos_level import transfer_pos_level - - -def get_all_positional_encodings( - adj: Union[np.ndarray, spmatrix], - num_nodes: int, - pos_kwargs: Optional[Dict] = None, -) -> Tuple["OrderedDict[str, np.ndarray]"]: - r""" - Get features positional encoding. - - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - num_nodes: Number of nodes in the graph - pos_encoding_as_features: keyword arguments for function `graph_positional_encoder` - to generate positional encoding for node features. - - Returns: - pe_dict: Dictionary of positional and structural encodings - """ - - pos_kwargs = {} if pos_kwargs is None else pos_kwargs - - pe_dict = OderedDictClass() - - # Initialize cache - cache = {} - - # Get the positional encoding for the features - if len(pos_kwargs) > 0: - for pos_name, this_pos_kwargs in pos_kwargs["pos_types"].items(): - this_pos_kwargs = deepcopy(this_pos_kwargs) - pos_type = this_pos_kwargs.pop("pos_type", None) - pos_level = this_pos_kwargs.pop("pos_level", None) - this_pe, cache = graph_positional_encoder( - deepcopy(adj), - num_nodes, - pos_type=pos_type, - pos_level=pos_level, - pos_kwargs=this_pos_kwargs, - cache=cache, - ) - if pos_level == "node": - pe_dict.update({f"{pos_type}": this_pe}) - else: - pe_dict.update({f"{pos_level}_{pos_type}": this_pe}) - - return pe_dict - - -def graph_positional_encoder( - adj: Union[np.ndarray, spmatrix], - num_nodes: int, - pos_type: Optional[str] = None, - pos_level: Optional[str] = None, - pos_kwargs: Optional[Dict[str, Any]] = None, - cache: Optional[Dict[str, Any]] = None, -) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: - r""" - Get a positional encoding that depends on the parameters. - - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - num_nodes: Number of nodes in the graph - pos_type: The type of positional encoding to use. If None, it must be provided by `pos_kwargs["pos_type"]`. Supported types are: - - laplacian_eigvec \ - - laplacian_eigval \ -> cache connected comps. & eigendecomp. - - rwse - - electrostatic \ - - commute \ -> cache pinvL - - graphormer - pos_level: Positional level to output. If None, it must be provided by `pos_kwargs["pos_level"]`. - - node - - edge - - nodepair - - graph - pos_kwargs: Extra keyword arguments for the positional encoding. Can include the keys pos_type and pos_level. - cache: Dictionary of cached objects - - Returns: - pe: Positional or structural encoding - cache: Updated dictionary of cached objects - """ - - pos_kwargs = deepcopy(pos_kwargs) - if pos_kwargs is None: - pos_kwargs = {} - if cache is None: - cache = {} - - # Get the positional type - pos_type2 = pos_kwargs.pop("pos_type", None) - if pos_type is None: - pos_type = pos_type2 - if pos_type2 is not None: - assert ( - pos_type == pos_type2 - ), f"The positional type must be the same in `pos_type` and `pos_kwargs['pos_type']`. Provided: {pos_type} and {pos_type2}" - assert pos_type is not None, "Either `pos_type` or `pos_kwargs['pos_type']` must be provided." - - # Get the positional level - pos_level2 = pos_kwargs.pop("pos_level", None) - if pos_level is None: - pos_level = pos_level2 - if pos_level2 is not None: - assert ( - pos_level == pos_level2 - ), f"The positional level must be the same in `pos_level` and `pos_kwargs['pos_level']`. Provided: {pos_level} and {pos_level2}" - assert pos_level is not None, "Either `pos_level` or `pos_kwargs['pos_level']` must be provided." - - # Convert to numpy array - if isinstance(adj, torch.sparse.Tensor): - adj = adj.to_dense().numpy() - elif isinstance(adj, torch.Tensor): - adj = adj.numpy() - adj = adj.astype(np.float64) - - # Calculate positional encoding - if pos_type == "laplacian_eigvec": - _, pe, base_level, cache = compute_laplacian_pe(adj, cache=cache, **pos_kwargs) - - elif pos_type == "laplacian_eigval": - pe, _, base_level, cache = compute_laplacian_pe(adj, cache=cache, **pos_kwargs) - - elif pos_type == "rw_return_probs": - pe, base_level, cache = compute_rwse( - adj.astype(np.float32), num_nodes=num_nodes, cache=cache, pos_type=pos_type, **pos_kwargs - ) - - elif pos_type == "rw_transition_probs": - pe, base_level, cache = compute_rwse( - adj.astype(np.float32), num_nodes=num_nodes, cache=cache, pos_type=pos_type, **pos_kwargs - ) - - elif pos_type == "electrostatic": - pe, base_level, cache = compute_electrostatic_interactions(adj, cache, **pos_kwargs) - - elif pos_type == "commute": - pe, base_level, cache = compute_commute_distances(adj, num_nodes, cache, **pos_kwargs) - - elif pos_type == "graphormer": - pe, base_level, cache = compute_graphormer_distances(adj, num_nodes, cache, **pos_kwargs) - - else: - raise ValueError(f"Unknown `pos_type`: {pos_type}") - - # Convert to float32 and Convert between different pos levels - if isinstance(pe, (list, tuple)): - pe = [this_pe.astype(np.float32) for this_pe in pe] - pe = [transfer_pos_level(this_pe, base_level, pos_level, adj, num_nodes, cache) for this_pe in pe] - else: - pe = np.real(pe).astype(np.float32) - pe = transfer_pos_level(pe, base_level, pos_level, adj, num_nodes, cache) - - return pe, cache diff --git a/graphium/features/properties.py b/graphium/features/properties.py deleted file mode 100644 index 89a90ffee..000000000 --- a/graphium/features/properties.py +++ /dev/null @@ -1,127 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Union, List, Callable - -import numpy as np -import datamol as dm - -from rdkit.Chem import rdMolDescriptors as rdMD -from loguru import logger - - -def get_prop_or_none( - prop: Callable, n: int, *args: Union[dm.Mol, str], **kwargs: Union[dm.Mol, str] -) -> Union[List[float], List[None]]: - r""" - return properties. If error, return list of `None` with lenght `n`. - Parameters: - prop: The property to compute. - n: The number of elements in the property. - *args: The arguments to pass to the property. - **kwargs: The keyword arguments to pass to the property. - Returns: - The property or a list of `None` with lenght `n`. - """ - logger.warning("get_prop_or_none is deprecated. Use `datamol.to_fp` instead.") - try: - return prop(*args, **kwargs) - except RuntimeError: - return [None] * n - - -def get_props_from_mol( - mol: Union[dm.Mol, str], - properties: Union[List[str], str] = "autocorr3d", -) -> np.ndarray: - r""" - Function to get a given set of desired properties from a molecule, - and output a property list. - - Parameters: - mol: The molecule from which to compute the properties. - properties: - The list of properties to compute for each molecule. It can be the following: - - - 'descriptors' - - 'autocorr3d' - - 'rdf' - - 'morse' - - 'whim' - - 'all' - - Returns: - props: np.array(float) - The array of properties for the desired molecule - classes_start_idx: list(int) - The list of index specifying the start of each new class of - descriptor or property. For example, if props has 20 elements, - the first 5 are rotatable bonds, the next 8 are morse, and - the rest are whim, then ``classes_start_idx = [0, 5, 13]``. - This will mainly be useful to normalize the features of - each class. - classes_names: list(str) - The name of the classes associated to each starting index. - Will be usefull to understand what property is the network learning. - - """ - - logger.warning("get_props_from_mol is deprecated. Use `datamol.to_fp` instead.") - - if isinstance(mol, str): - mol = dm.to_mol( - mol - ) # Doesn't need `ordered=True` because the fingerprints don't depend on the atom order - - if isinstance(properties, str): - properties = [properties] - - properties = [p.lower() for p in properties] - - # Initialize arrays - props = [] # Property vector for the features - classes_start_idx = [] # The starting index for each property class - classes_names = [] - - # Generate a 3D structure for the molecule - mol = dm.add_hs(mol) - - if ("autocorr3d" in properties) or ("all" in properties): - # Some kind of 3D description of the molecule - classes_names.append("autocorr3d") - classes_start_idx.append(len(props)) - props.extend(get_prop_or_none(rdMD.CalcAUTOCORR3D, 80, mol)) - - if ("rdf" in properties) or ("all" in properties): - # The radial distribution function (better than the inertia) - # https://en.wikipedia.org/wiki/Radial_distribution_function - classes_names.append("rdf") - classes_start_idx.append(len(props)) - props.extend(get_prop_or_none(rdMD.CalcRDF, 210, mol)) - - if ("morse" in properties) or ("all" in properties): - # Molecule Representation of Structures based on Electron diffraction descriptors - classes_names.append("morse") - classes_start_idx.append(len(props)) - props.extend(get_prop_or_none(rdMD.CalcMORSE, 224, mol)) - - if ("whim" in properties) or ("all" in properties): - # WHIM descriptors are 3D structural descriptors obtained from the - # (x,y,z)‐atomic coordinates of a molecular conformation of a chemical, - # and are used successfully in QSAR modelling. - classes_names.append("whim") - classes_start_idx.append(len(props)) - props.extend(get_prop_or_none(rdMD.CalcWHIM, 114, mol)) - - return np.array(props), classes_start_idx, classes_names diff --git a/graphium/features/rw.py b/graphium/features/rw.py deleted file mode 100644 index c7eada2ba..000000000 --- a/graphium/features/rw.py +++ /dev/null @@ -1,169 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, Optional, List, Dict, Any, Iterable - -from scipy.sparse import issparse, spmatrix, coo_matrix -import numpy as np -import torch - -from torch_geometric.utils import to_dense_adj, from_scipy_sparse_matrix -from torch_scatter import scatter_add -from torch_geometric.utils.num_nodes import maybe_num_nodes - - -def compute_rwse( - adj: Union[np.ndarray, spmatrix], - ksteps: Union[int, List[int]], - num_nodes: int, - cache: Dict[str, Any], - pos_type: str = "rw_return_probs" or "rw_transition_probs", - space_dim: int = 0, -) -> Tuple[np.ndarray, str, Dict[str, Any]]: - """ - Compute Random Walk Spectral Embedding (RWSE) for given list of K steps. - - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix - ksteps: List of numbers of steps for the random walks. If int, a list is generated from 1 to ksteps. - num_nodes: Number of nodes in the graph - cache: Dictionary of cached objects - pos_type: Desired output - space_dim: Estimated dimensionality of the space. Used to - correct the random-walk diagonal by a factor `k^(space_dim/2)`. - In euclidean space, this correction means that the height of - the gaussian distribution stays almost constant across the number of - steps, if `space_dim` is the dimension of the euclidean space. - Returns: - Two possible outputs: - rw_return_probs [num_nodes, len(ksteps)]: Random-Walk k-step landing probabilities - rw_transition_probs [num_nodes, num_nodes, len(ksteps)]: Random-Walk k-step transition probabilities - base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here either node or nodepair - cache: Updated dictionary of cached objects - """ - - base_level = "node" if pos_type == "rw_return_probs" else "nodepair" - - # Manually handles edge case of 1 atom molecules here - if not isinstance(ksteps, Iterable): - ksteps = list(range(1, ksteps + 1)) - if num_nodes == 1: - if pos_type == "rw_return_probs": - return np.ones((1, len(ksteps))), base_level, cache - else: - return np.ones((1, 1, len(ksteps))), base_level, cache - - # Get the edge indices from the adjacency matrix - if not issparse(adj): - if "coo_adj" in cache: - adj = cache["coo_adj"] - elif "csr_adj" in cache: - adj = cache["csr_adj"] - else: - adj = coo_matrix(adj, dtype=np.float64) - cache["coo_adj"] = adj - - edge_index, edge_weight = from_scipy_sparse_matrix(adj) - - # Compute the random-walk transition probabilities - if "ksteps" in cache: - cached_k = cache["ksteps"] - missing_k = [k for k in ksteps if k not in cached_k] - if missing_k == []: - pass - elif min(missing_k) < min(cached_k): - Pk_dict = get_Pks(missing_k, edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes) - cache["ksteps"] = sorted(missing_k + cache["ksteps"]) - for k in missing_k: - cache["Pk"][k] = Pk_dict[k] - else: - start_k = min([max(cached_k), min(missing_k)]) - start_Pk = cache["Pk"][start_k] - Pk_dict = get_Pks( - missing_k, - edge_index=edge_index, - edge_weight=edge_weight, - num_nodes=num_nodes, - start_Pk=start_Pk, - start_k=start_k, - ) - cache["ksteps"] = sorted(cache["ksteps"] + missing_k) - for k in missing_k: - cache["Pk"][k] = Pk_dict[k] - else: - Pk_dict = get_Pks(ksteps, edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes) - - cache["ksteps"] = list(Pk_dict.keys()) - cache["Pk"] = Pk_dict - - pe_list = [] - if pos_type == "rw_return_probs": - for k in ksteps: - pe_list.append(torch.diagonal(cache["Pk"][k], dim1=-2, dim2=-1) * (k ** (space_dim / 2))) - else: - for k in ksteps: - pe_list.append(cache["Pk"][k]) - - pe = torch.stack(pe_list, dim=-1).numpy() - - return pe, base_level, cache - - -def get_Pks( - ksteps: List[int], - edge_index: Tuple[torch.Tensor, torch.Tensor], - edge_weight: Optional[torch.Tensor] = None, - num_nodes: Optional[int] = None, - start_Pk: Optional[torch.Tensor] = None, - start_k: Optional[int] = None, -) -> Dict[int, np.ndarray]: - """ - Compute Random Walk landing probabilities for given list of K steps. - - Parameters: - ksteps: List of numbers of k-steps for which to compute the RW landings - edge_index: PyG sparse representation of the graph - edge_weight: Edge weights - num_nodes: Number of nodes in the graph - - Returns: - 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs - """ - if edge_weight is None: - edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) - num_nodes = maybe_num_nodes(edge_index, num_nodes) - src = edge_index[0] - deg = scatter_add(edge_weight, src, dim=0, dim_size=num_nodes) # Out degrees. - deg_inv = deg.pow(-1.0) - deg_inv.masked_fill_(deg_inv == float("inf"), 0) - - if edge_index.numel() == 0: - P = edge_index.new_zeros((1, num_nodes, num_nodes)) - else: - # P = D^-1 * A - P = torch.diag(deg_inv).float() @ to_dense_adj( - edge_index, max_num_nodes=num_nodes - ) # 1 x (Num nodes) x (Num nodes) - - if start_Pk is not None: - Pk = start_Pk @ P.clone().detach().matrix_power(min(ksteps) - start_k) - else: - Pk = P.clone().detach().matrix_power(min(ksteps)) - - Pk_dict = {} - for k in range(min(ksteps), max(ksteps) + 1): - Pk_dict[k] = Pk.squeeze(0) - Pk = Pk @ P - - return Pk_dict diff --git a/graphium/features/spectral.py b/graphium/features/spectral.py deleted file mode 100644 index 55d8527a4..000000000 --- a/graphium/features/spectral.py +++ /dev/null @@ -1,218 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, Dict, Any -from scipy.linalg import eig -from scipy.sparse import csr_matrix, diags, issparse, spmatrix -import numpy as np -import torch -import networkx as nx - -from graphium.utils.tensor import is_dtype_torch_tensor, is_dtype_numpy_array - - -def compute_laplacian_pe( - adj: Union[np.ndarray, spmatrix], - num_pos: int, - cache: Dict[str, Any], - disconnected_comp: bool = True, - normalization: str = "none", -) -> Tuple[np.ndarray, str, Dict[str, Any]]: - r""" - Compute the Laplacian eigenvalues and eigenvectors of the Laplacian of the graph. - - Parameters: - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - num_pos: Number of Laplacian eigenvectors to compute - cache: Dictionary of cached objects - disconnected_comp: Whether to compute the eigenvectors for each connected component - normalization: Normalization to apply to the Laplacian - - Returns: - Two possible outputs: - eigvals [num_nodes, num_pos]: Eigenvalues of the Laplacian repeated for each node. - This repetition is necessary in case of disconnected components, where - the eigenvalues of the Laplacian are not the same for each node. - eigvecs [num_nodes, num_pos]: Eigenvectors of the Laplacian - base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here node - cache: Updated dictionary of cached objects - """ - - base_level = "node" - - # Sparsify the adjacency patrix - if not issparse(adj): - if "csr_adj" not in cache: - adj = csr_matrix(adj, dtype=np.float64) - cache["csr_adj"] = adj - else: - adj = cache["csr_adj"] - - # Compute the Laplacian, and normalize it - if f"L_{normalization}_sp" not in cache: - D = np.array(np.sum(adj, axis=1)).flatten() - D_mat = diags(D) - L = -adj + D_mat - L_norm = normalize_matrix(L, degree_vector=D, normalization=normalization) - cache[f"L_{normalization}_sp"] = L_norm - else: - L_norm = cache[f"L_{normalization}_sp"] - - components = [] - - if disconnected_comp: - if "components" not in cache: - # Get the list of connected components - components = list(nx.connected_components(nx.from_scipy_sparse_array(adj))) - cache["components"] = components - - else: - components = cache["components"] - - # Compute the eigenvectors for each connected component, and stack them together - if len(components) > 1: - if "lap_eig_comp" not in cache: - eigvals = np.zeros((adj.shape[0], num_pos), dtype=np.complex64) - eigvecs = np.zeros((adj.shape[0], num_pos), dtype=np.complex64) - for component in components: - comp = list(component) - this_L = L_norm[comp][:, comp] - this_eigvals, this_eigvecs = _get_positional_eigvecs(this_L, num_pos=num_pos) - - # Eigenvalues previously set to infinity are now set to 0 - # Any NaN in the eigvals or eigvecs will be set to 0 - this_eigvecs[~np.isfinite(this_eigvecs)] = 0.0 - this_eigvals[~np.isfinite(this_eigvals)] = 0.0 - - eigvals[comp, :] = np.expand_dims(this_eigvals, axis=0) - eigvecs[comp, :] = this_eigvecs - cache["lap_eig_comp"] = (eigvals, eigvecs) - - else: - eigvals, eigvecs = cache["lap_eig_comp"] - - else: - if "lap_eig" not in cache: - eigvals, eigvecs = _get_positional_eigvecs(L, num_pos=num_pos) - - # Eigenvalues previously set to infinity are now set to 0 - # Any NaN in the eigvals or eigvecs will be set to 0 - eigvecs[~np.isfinite(eigvecs)] = 0.0 - eigvals[~np.isfinite(eigvals)] = 0.0 - eigvals = np.repeat(np.expand_dims(eigvals, axis=0), adj.shape[0], axis=0) - - cache["lap_eig"] = (eigvals, eigvecs) - - else: - eigvals, eigvecs = cache["lap_eig"] - - return eigvals, eigvecs, base_level, cache - - -def _get_positional_eigvecs( - matrix: Union[np.ndarray, spmatrix], - num_pos: int, -) -> Tuple[np.ndarray, np.ndarray]: - r""" - compute the eigenvalues and eigenvectors of a matrix - Parameters: - matrix: Matrix to compute the eigenvalues and eigenvectors of - num_pos: Number of eigenvalues and eigenvectors to compute - Returns: - eigvals: Eigenvalues of the matrix - eigvecs: Eigenvectors of the matrix - """ - mat_len = matrix.shape[0] - eigvals, eigvecs = eig(matrix.todense()) - - # Pad with non-sense eigenvectors if required - if num_pos > mat_len: - temp_EigVal = np.ones(num_pos - mat_len, dtype=np.float64) + float("inf") - temp_EigVec = np.zeros((mat_len, num_pos - mat_len), dtype=np.float64) - eigvals = np.concatenate([eigvals, temp_EigVal], axis=0) - eigvecs = np.concatenate([eigvecs, temp_EigVec], axis=1) - - # Sort and keep only the first `num_pos` elements - sort_idx = eigvals.argsort() - eigvals = eigvals[sort_idx] - eigvals = eigvals[:num_pos] - eigvecs = eigvecs[:, sort_idx] - eigvecs = eigvecs[:, :num_pos] - - # Normalize the eigvecs - eigvecs = eigvecs / np.maximum(np.sqrt(np.sum(eigvecs**2, axis=0, keepdims=True)), 1e-4) - - return eigvals, eigvecs - - -def normalize_matrix( - matrix: Union[np.ndarray, spmatrix], - degree_vector=None, - normalization: str = None, -) -> Union[np.ndarray, spmatrix]: - r""" - Normalize a given matrix using its degree vector - - Parameters - --------------- - - matrix: torch.tensor(N, N) or scipy.sparse.spmatrix(N, N) - A square matrix representing either an Adjacency matrix or a Laplacian. - - degree_vector: torch.tensor(N) or np.ndarray(N) or None - A vector representing the degree of ``matrix``. - ``None`` is only accepted if ``normalization==None`` - - normalization: str or None, Default='none' - Normalization to use on the eig_matrix - - - 'none' or ``None``: no normalization - - - 'sym': Symmetric normalization ``D^-0.5 L D^-0.5`` - - - 'inv': Inverse normalization ``D^-1 L`` - - Returns - ----------- - matrix: torch.tensor(N, N) or scipy.sparse.spmatrix(N, N) - The normalized matrix - - """ - - # Transform the degree vector into a matrix - if degree_vector is None: - if not ((normalization is None) or (normalization.lower() == "none")): - raise ValueError("`degree_vector` cannot be `None` if `normalization` is not `None`") - else: - if is_dtype_numpy_array(matrix.dtype): - with np.errstate(divide="ignore", invalid="ignore"): - degree_inv = np.expand_dims(degree_vector**-0.5, axis=1) - degree_inv[np.isinf(degree_inv)] = 0 - elif is_dtype_torch_tensor(matrix.dtype): - degree_inv = torch.unsqueeze(degree_vector**-0.5, dim=1) - degree_inv[torch.isinf(degree_inv)] = 0 - - # Compute the normalized matrix - if (normalization is None) or (normalization.lower() == "none"): - pass - elif normalization.lower() == "sym": - matrix = degree_inv * matrix * degree_inv.T - elif normalization.lower() == "inv": - matrix = (degree_inv**2) * matrix - else: - raise ValueError( - f'`normalization` should be `None`, `"None"`, `"sym"` or `"inv"`, but `{normalization}` was provided' - ) - - return matrix diff --git a/graphium/features/transfer_pos_level.py b/graphium/features/transfer_pos_level.py deleted file mode 100644 index 4bb70e160..000000000 --- a/graphium/features/transfer_pos_level.py +++ /dev/null @@ -1,376 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Tuple, Union, List, Dict, Any, Optional - -import numpy as np - -from scipy.sparse import spmatrix, issparse, coo_matrix - -from torch_geometric.utils import from_scipy_sparse_matrix - - -def transfer_pos_level( - pe: np.ndarray, - in_level: str, - out_level: str, - adj: Union[np.ndarray, spmatrix], - num_nodes: int, - cache: Optional[Dict[str, Any]] = None, -) -> np.ndarray: - r""" - Transfer positional encoding between different positional levels (node, edge, nodepair, graph) - - Parameters: - pe: Input pe with pos_level defined by in_level - in_level: pos_level of input pe - out_level: desired pos_level of output pe - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - num_nodes: Number of nodes in the graph - cache: Dictionary of cached objects - - Returns: - pe: Output pe with pos_level defined by out_level - """ - - if cache is None: - cache = {} - - if in_level == "node": - if out_level == "node": - pass - - elif out_level == "edge": - pe, cache = node_to_edge(pe, adj, cache) - - elif out_level == "nodepair": - pe = node_to_nodepair(pe, num_nodes) - - elif out_level == "graph": - raise NotImplementedError("Transfer function (node -> graph) not yet implemented.") - - else: - raise ValueError(f"Unknown `pos_level`: {out_level}") - - elif in_level == "edge": - raise NotImplementedError("Transfer function (edge -> *) not yet implemented.") - - elif in_level == "nodepair": - if len(pe.shape) == 2: - pe = np.expand_dims(pe, -1) - - if out_level == "node": - pe = nodepair_to_node(pe) - - elif out_level == "edge": - pe, cache = nodepair_to_edge(pe, adj, cache) - - elif out_level == "nodepair": - pass - - elif out_level == "graph": - raise NotImplementedError("Transfer function (nodepair -> graph) not yet implemented.") - - else: - raise ValueError(f"Unknown `pos_level`: {out_level}") - - elif in_level == "graph": - if out_level == "node": - pe = graph_to_node(pe, num_nodes, cache) - - elif out_level in ["edge", "nodepair"]: - raise NotImplementedError("Transfer function (graph -> edge/nodepair) not yet implemented.") - - else: - raise ValueError(f"Unknown `pos_level`: {out_level}") - - else: - raise ValueError(f"Unknown `pos_level`: {in_level}") - - return pe - - -# Transfer functions between different levels, i.e., node, edge, nodepair and graph level. - -# TODO: -# - Implement missing transfer functions below -# - Are transfer functions graph -> edge/nodepair and edge -> graph needed? - - -def node_to_edge( - pe: np.ndarray, adj: Union[np.ndarray, spmatrix], cache: Optional[Dict[str, Any]] = None -) -> Tuple[np.ndarray, Dict[str, Any]]: - r""" - Get an edge-level positional encoding from a node-level positional encoding. - -> For each edge, concatenate the sum and absolute difference of pe of source and destination node. - - Parameters: - pe [num_nodes, num_feat]: Node-level positional encoding - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - cache: Dictionary of cached objects - - Returns: - edge_pe [2 * num_edges, 2 * num_feat]: Edge-level positional encoding - cache: Updated dictionary of cached objects - """ - - if cache is None: - cache = {} - - if not issparse(adj): - if "coo_adj" in cache: - adj = cache["coo_adj"] - elif "csr_adj" in cache: - adj = cache["csr_adj"] - else: - adj = coo_matrix(adj, dtype=np.float64) - cache["coo_adj"] = adj - - edge_index, _ = from_scipy_sparse_matrix(adj) - src, dst = edge_index[0], edge_index[1] - - pe_sum = pe[src] + pe[dst] - pe_abs_diff = np.abs(pe[src] - pe[dst]) - - edge_pe = np.concatenate((pe_sum, pe_abs_diff), axis=-1) - - return edge_pe, cache - - -def node_to_nodepair(pe: np.ndarray, num_nodes: int) -> np.ndarray: - r""" - Get a nodepair-level positional encoding from a node-level positional encoding. - -> For each nodepair (i,j) concatenate the sum and absolute difference of pe at node i and j. - - Parameters: - pe [num_nodes, num_feat]: Node-level positional encoding - num_nodes: Number of nodes in the graph - - Returns: - nodepair_pe [num_nodes, num_nodes, 2 * num_feat]: Nodepair-level positional encoding - """ - - expanded_pe = np.expand_dims(pe, axis=1) - expanded_pe = np.repeat(expanded_pe, repeats=num_nodes, axis=1) - - pe_sum = expanded_pe + expanded_pe.transpose([1, 0, 2]) - pe_abs_diff = np.abs(expanded_pe - expanded_pe.transpose([1, 0, 2])) - - nodepair_pe = np.concatenate((pe_sum, pe_abs_diff), axis=-1) - - return nodepair_pe - - -def node_to_graph(pe: np.ndarray, num_nodes: int) -> np.ndarray: - r""" - Get a graph-level positional encoding from a node-level positional encoding. - -> E.g., min/max/mean-pooling of node features. - - Parameters: - pe [num_nodes, num_feat]: Node-level positional encoding - num_nodes: Number of nodes in the graph - - Returns: - graph_pe [1, num_feat]: Graph-level positional encoding - """ - - raise NotImplementedError("Transfer function (node -> graph) not yet implemented.") - - -def edge_to_node(pe: np.ndarray, adj: Union[np.ndarray, spmatrix]) -> np.ndarray: - r""" - Get a node-level positional encoding from an edge-level positional encoding. - -> E.g., min/max/mean-pooling of information from edges (i,j) that contain node i - - Parameters: - pe [num_edges, num_feat]: Edge-level positional encoding - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - - Returns: - node_pe [num_edges, num_feat]: Node-level positional encoding - """ - - raise NotImplementedError("Transfer function (edge -> node) not yet implemented.") - - -def edge_to_nodepair( - pe: np.ndarray, adj: Union[np.ndarray, spmatrix], num_nodes: int, cache: Optional[Dict[str, Any]] = None -) -> np.ndarray: - r""" - Get a nodepair-level positional encoding from an edge-level positional encoding. - -> Zero-padding of non-existing edges. - - Parameters: - pe [num_edges, num_feat]: Edge-level positional encoding - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - num_nodes: Number of nodes in the graph - cache: Dictionary of cached objects - - Returns: - nodepair_pe [num_edges, num_edges, num_feat]: Nodepair-level positional encoding - cache: Updated dictionary of cached objects - """ - - if cache is None: - cache = {} - - num_feat = pe.shape[-1] - - if not isinstance(adj, coo_matrix): - if "coo_adj" in cache: - adj = cache["coo_adj"] - else: - adj = coo_matrix(adj, dtype=np.float64) - cache["coo_adj"] = adj - - dst, src = adj.row, adj.col - - nodepair_pe = np.zeros((num_nodes, num_nodes, num_feat)) - - for i in range(len(dst)): - nodepair_pe[dst[i], src[i], ...] = pe[i, ...] - - return nodepair_pe, cache - - -def edge_to_graph(pe: np.ndarray) -> np.ndarray: - r""" - Get a graph-level positional encoding from an edge-level positional encoding. - - Parameters: - pe [num_edges, num_feat]: Edge-level positional encoding - - Returns: - graph_pe [1, num_feat]: Graph-level positional encoding - """ - - raise NotImplementedError("Transfer function (edge -> graph) not yet implemented.") - - -def nodepair_to_node(pe: np.ndarray, stats_list: List = [np.min, np.mean, np.std]) -> np.ndarray: - r""" - Get a node-level positional encoding from a graph-level positional encoding. - -> Calculate statistics over rows & cols of input positional encoding - - Parameters: - pe [num_nodes, num_nodes, num_feat]: Nodepair-level positional encoding - stats_list: List of statistics to calculate per row/col of nodepair-level pe - - Returns: - node_pe [num_nodes, 2 * len(stats_list) * num_feat]: Node-level positional encoding - """ - - num_feat = pe.shape[-1] - - node_pe_list = [] - - for stat in stats_list: - for i in range(num_feat): - node_pe_list.append(stat(pe[..., i], axis=0)) - node_pe_list.append(stat(pe[..., i], axis=1)) - node_pe = np.stack(node_pe_list, axis=-1) - - return node_pe - - -def nodepair_to_edge( - pe: np.ndarray, adj: Union[np.ndarray, spmatrix], cache: Optional[Dict[str, Any]] = None -) -> np.ndarray: - r""" - Get a edge-level positional encoding from a nodepair-level positional encoding. - -> Mask and sparsify nodepair-level positional encoding - - Parameters: - pe [num_nodes, num_nodes, num_feat]: Nodepair-level positional encoding - adj [num_nodes, num_nodes]: Adjacency matrix of the graph - cache: Dictionary of cached objects - - Returns: - edge_pe [num_edges, num_feat]: Edge-level positional encoding - cache: Updated dictionary of cached objects - """ - - if cache is None: - cache = {} - - num_feat = pe.shape[-1] - - if not isinstance(adj, coo_matrix): - if "coo_adj" in cache: - adj = cache["coo_adj"] - else: - adj = coo_matrix(adj, dtype=np.float64) - cache["coo_adj"] = adj - - dst, src = adj.row, adj.col - - edge_pe = np.zeros((len(dst), num_feat)) - - for i in range(len(src)): - edge_pe[i, ...] = pe[dst[i], src[i]] - - return edge_pe, cache - - -def nodepair_to_graph(pe: np.ndarray, num_nodes: int) -> np.ndarray: - r""" - Get a graph-level positional encoding from a nodepair-level positional encoding. - -> E.g., min/max/mean-pooling of entries of input pe - - Parameters: - pe [num_nodes, num_nodes, num_feat]: Nodepair-level positional encoding - num_nodes: Number of nodes in the graph - - Returns: - graph_pe [1, num_feat]: Graph-level positional encoding - """ - - raise NotImplementedError("Transfer function (nodepair -> graph) not yet implemented.") - - -def graph_to_node( - pe: Union[np.ndarray, List], num_nodes: int, cache: Optional[Dict[str, Any]] = None -) -> np.ndarray: - r""" - Get a node-level positional encoding from a nodepair-level positional encoding. - -> E.g., expand dimension of graph-level pe - - Parameters: - pe [num_feat]: Nodepair-level positional encoding (or list of them if graph disconnected) - num_nodes: Number of nodes in the graph - cache: Dictionary of cached objects - - Returns: - node_pe [num_nodes, num_feat]: Node-level positional encoding - """ - - if cache is None: - cache = {} - - node_pe = None - - # The key 'components' is only in cache if disconnected_comp == True when computing base pe - if "components" in cache: - if len(cache["components"]) > 1: - node_pe = np.zeros((num_nodes, len(pe))) - components = cache["components"] - - for i, component in enumerate(components): - comp = list(component) - node_pe[comp, :] = np.real(pe[i]) - - if node_pe is None: - node_pe = np.tile(pe, (num_nodes, 1)) - - return node_pe diff --git a/graphium/finetuning/finetuning.py b/graphium/finetuning/finetuning.py index 97d6d7fc7..4100218c5 100644 --- a/graphium/finetuning/finetuning.py +++ b/graphium/finetuning/finetuning.py @@ -29,7 +29,8 @@ def __init__( finetuning_module: str, added_depth: int = 0, unfreeze_pretrained_depth: Optional[int] = None, - epoch_unfreeze_all: int = 0, + epoch_unfreeze_all: Optional[int] = 0, + always_freeze_modules: Optional[Union[List, str]] = None, train_bn: bool = False, ): """ @@ -41,6 +42,7 @@ def __init__( added_depth: Number of layers of finetuning module that have been modified rel. to pretrained model unfreeze_pretrained_depth: Number of additional layers to unfreeze before layers modified rel. to pretrained model epoch_unfreeze_all: Epoch to unfreeze entire model + always_freeze_modules: Module that always stay frozen while finetuning train_bn: Boolean value indicating if batchnorm layers stay in training mode """ @@ -51,6 +53,11 @@ def __init__( if unfreeze_pretrained_depth is not None: self.training_depth += unfreeze_pretrained_depth self.epoch_unfreeze_all = epoch_unfreeze_all + self.always_freeze_modules = always_freeze_modules + if self.always_freeze_modules == 'none': + self.always_freeze_modules = None + if isinstance(self.always_freeze_modules, str): + self.always_freeze_modules = [self.always_freeze_modules] self.train_bn = train_bn def freeze_before_training(self, pl_module: pl.LightningModule): @@ -105,3 +112,7 @@ def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer """ if epoch == self.epoch_unfreeze_all: self.unfreeze_and_add_param_group(modules=pl_module, optimizer=optimizer, train_bn=self.train_bn) + + if self.always_freeze_modules is not None: + for module_name in self.always_freeze_modules: + self.freeze_module(pl_module, module_name, pl_module.model.pretrained_model.net._module_map) \ No newline at end of file diff --git a/graphium/finetuning/finetuning_architecture.py b/graphium/finetuning/finetuning_architecture.py index 864016141..e0b3fbe92 100644 --- a/graphium/finetuning/finetuning_architecture.py +++ b/graphium/finetuning/finetuning_architecture.py @@ -345,4 +345,4 @@ def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool = """ # For the post-nn network, all the dimension are divided - return self.net.make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=factor_in_dim) + return self.net.make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=factor_in_dim) \ No newline at end of file diff --git a/graphium/finetuning/utils.py b/graphium/finetuning/utils.py index 7b9f7df74..2ef440bc8 100644 --- a/graphium/finetuning/utils.py +++ b/graphium/finetuning/utils.py @@ -22,16 +22,34 @@ import graphium +def filter_cfg_for_custom_task(config: Dict[str, Any], task: str, task_type: str): + """ + Filter a base config for the task type (regression vs. classification) + """ + + cfg = deepcopy(config) + + # Filter the relevant config sections + if "predictor" in cfg and "metrics_on_progress_bar" in cfg["predictor"]: + cfg["predictor"]["metrics_on_progress_bar"] = {task: cfg["predictor"]["metrics_on_progress_bar"][task_type]} + if "predictor" in cfg and "loss_fun" in cfg["predictor"]: + cfg["predictor"]["loss_fun"] = {task: cfg["predictor"]["loss_fun"][task_type]} + if "metrics" in cfg: + cfg["metrics"] = {task: cfg["metrics"][task_type]} + + return cfg + + def filter_cfg_based_on_admet_benchmark_name(config: Dict[str, Any], names: Union[List[str], str]): """ Filter a base config for the full TDC ADMET benchmarking group to only have settings related to a subset of the endpoints """ - if config["datamodule"]["module_type"] != "ADMETBenchmarkDataModule": + if config["datamodule"]["module_type"] != "TDCBenchmarkDataModule": # NOTE (cwognum): For now, this implies we only support the ADMET benchmark from TDC. # It is easy to extend this in the future to support more datasets. - raise ValueError("You can only use this method for the `ADMETBenchmarkDataModule`") + raise ValueError("You can only use this method for the `TDCBenchmarkDataModule`") if isinstance(names, str): names = [names] @@ -61,17 +79,45 @@ def modify_cfg_for_finetuning(cfg: Dict[str, Any]): """ Function combining information from configuration and pretrained model for finetuning. """ - task = cfg["finetuning"]["task"] - # Filter the config based on the task name - # NOTE (cwognum): This prevents the need for having many different files for each of the tasks - # with lots and lots of config repetition. - cfg = filter_cfg_based_on_admet_benchmark_name(cfg, task) + benchmark = cfg["constants"].pop("benchmark", None) + task_type = cfg["constants"].pop("task_type", None) + task = cfg["finetuning"].get("task", "task") + + if benchmark == "custom" and task_type is not None: + cfg = filter_cfg_for_custom_task(cfg, task, task_type) + else: + # Filter the config based on the task name + # NOTE (cwognum): This prevents the need for having many different files for each of the tasks + # with lots and lots of config repetition. + cfg = filter_cfg_based_on_admet_benchmark_name(cfg, task) + cfg_finetune = cfg["finetuning"] # Load pretrained model pretrained_model = cfg_finetune["pretrained_model"] - pretrained_predictor = PredictorModule.load_pretrained_model(pretrained_model, device="cpu") + if isinstance(pretrained_model, dict): + mode = pretrained_model.get('mode') + size = pretrained_model.get('size') + model = pretrained_model.get('model') + pretraining_seed = pretrained_model.get('pretraining_seed') + if mode == 'width': + size = size[:4] + elif mode == 'depth': + size = size[:2] + elif mode == 'molecule': + size = size[:3] + elif mode == 'label': + size = size[:3] + elif mode == 'ablation': + size = f"_{size}" + pretrained_model_name = f"{mode}{size}_{model}_s{pretraining_seed}" + + else: + pretrained_model_name = pretrained_model + + cfg_finetune["pretrained_model"] = pretrained_model_name + pretrained_predictor = PredictorModule.load_pretrained_model(pretrained_model_name, device="cpu") # Inherit shared configuration from pretrained # Architecture @@ -123,6 +169,10 @@ def modify_cfg_for_finetuning(cfg: Dict[str, Any]): # Update config new_module_kwargs.update(upd_kwargs) + depth, hidden_dims = new_module_kwargs['depth'], new_module_kwargs['hidden_dims'] + if isinstance(hidden_dims, list): + if len(hidden_dims) != depth: + new_module_kwargs['hidden_dims'] = (depth - 1) * [hidden_dims[0]] if sub_module_from_pretrained is None: cfg_arch[finetuning_module] = new_module_kwargs @@ -151,6 +201,13 @@ def modify_cfg_for_finetuning(cfg: Dict[str, Any]): # Change architecture to FullGraphFinetuningNetwork cfg_arch["model_type"] = "FullGraphFinetuningNetwork" + def _change_dropout(c): + if isinstance(c, dict): + return {k: (_change_dropout(v) if k != 'dropout' else cfg['constants']['model_dropout']) for k, v in c.items()} + return c + + cfg_arch = _change_dropout(cfg_arch) + cfg["architecture"] = cfg_arch pretrained_overwriting_kwargs = deepcopy(cfg["finetuning"]) @@ -160,6 +217,7 @@ def modify_cfg_for_finetuning(cfg: Dict[str, Any]): "finetuning_head", "unfreeze_pretrained_depth", "epoch_unfreeze_all", + "always_freeze_modules", ] for key in drop_keys: @@ -213,4 +271,4 @@ def update_cfg_arch_for_module( ) cfg_arch[module_name].update({new_sub_module: cfg_arch_from_pretrained[module_name][sub_module]}) - return cfg_arch + return cfg_arch \ No newline at end of file diff --git a/graphium/ipu/__init__.py b/graphium/fingerprinting/__init__.py similarity index 100% rename from graphium/ipu/__init__.py rename to graphium/fingerprinting/__init__.py diff --git a/graphium/fingerprinting/data.py b/graphium/fingerprinting/data.py new file mode 100644 index 000000000..76856a6a9 --- /dev/null +++ b/graphium/fingerprinting/data.py @@ -0,0 +1,252 @@ +from typing import Any, List, Dict, Literal, Union + +import os + +import torch + +import pandas as pd +import numpy as np + +from pytorch_lightning import LightningDataModule + +from torch.utils.data import Dataset, DataLoader + +from graphium.data.datamodule import BaseDataModule, MultitaskFromSmilesDataModule, TDCBenchmarkDataModule, DatasetProcessingParams +from graphium.trainer.predictor import PredictorModule +from graphium.fingerprinting.fingerprinter import Fingerprinter + + +class FingerprintDataset(Dataset): + """ + Dataset class for fingerprints useful for probing experiments. + + Parameters: + labels: Labels for the dataset. + fingerprints: Dictionary of fingerprints, where keys specify model and layer of extraction. + smiles: List of SMILES strings. + """ + def __init__( + self, + labels: torch.Tensor, + fingerprints: Dict[str, torch.Tensor], + smiles: List[str] = None, + ): + self.labels = labels + self.fingerprints = fingerprints + self.smiles = smiles + + def __len__(self): + return len(self.labels) + + def __getitem__(self, index): + fp_list = [] + for val in self.fingerprints.values(): + fp_list.append(val[index]) + + if self.smiles is not None: + return fp_list, self.labels[index], self.smiles[index] + else: + return fp_list, self.labels[index] + + +class FingerprintDatamodule(LightningDataModule): + """ + DataModule class for extracting fingerprints from one (or multiple) pretrained model(s). + + Parameters: + pretrained_models: Dictionary of pretrained models (keys) and list of layers (values), repectively to use. + task: Task to extract fingerprints for. + benchmark: Benchmark to extract fingerprints for. + df_path: Path to the DataFrame containing the SMILES strings. + batch_size: Batch size for fingerprint extraction (i.e., the forward passes of the pretrained models). + split_type: Type of split to use for the dataset. + splits_path: Path to the splits file. + split_val: Fraction of validation data. + split_test: Fraction of test data. + data_seed: Seed for data splitting. + num_workers: Number of workers for data loading. + device: Device to use for fingerprint extraction. + mol_cache_dir: Directory to cache the molecules in. + fps_cache_dir: Directory to cache the fingerprints in + + """ + def __init__( + self, + pretrained_models: Dict[str, List[str]], + task: str = "herg", + benchmark: Literal["tdc", None] = "tdc", + df_path: str = None, + batch_size: int = 64, + split_type: Literal["random", "scaffold"] = "random", + splits_path: str = None, + split_val: float = 0.1, + split_test: float = 0.1, + data_seed: int = 42, + num_workers: int = 0, + device: str = "cpu", + mol_cache_dir: str = "./expts/data/cache", + fps_cache_dir: str = "./expts/data/cache", + ): + super().__init__() + + assert benchmark is not None or df_path is not None, "Either benchmark or df_path must be provided" + + self.pretrained_models = pretrained_models + self.task = task + self.benchmark = benchmark + self.df_path = df_path + self.batch_size = batch_size + self.split_type = split_type + self.splits_path = splits_path + self.split_val = split_val + self.split_test = split_test + self.data_seed = data_seed + self.num_workers = num_workers + self.device = device + self.mol_cache_dir = mol_cache_dir + self.fps_cache_dir = fps_cache_dir + if benchmark is not None: + # Check if benchmark naming is already implied in config + if f"{benchmark}/{task}" not in mol_cache_dir: + self.mol_cache_dir = f"{mol_cache_dir}/{benchmark}/{task}" + if f"{benchmark}/{task}" not in fps_cache_dir: + self.fps_cache_dir = f"{fps_cache_dir}/{benchmark}/{task}" + + self.train_dataset = None + self.valid_dataset = None + self.test_dataset = None + + self.splits = [] + + def prepare_data(self) -> None: + if self.fps_cache_dir is not None and os.path.exists(f"{self.fps_cache_dir}/fps.pt"): + self.smiles, self.labels, self.fps_dict = torch.load(f"{self.fps_cache_dir}/fps.pt").values() + self.splits = list(self.smiles.keys()) + + else: + # Check which splits are needed + self.splits = [] + add_all = self.benchmark is not None or self.splits_path is not None + if add_all or self.split_val + self.split_test < 1: + self.splits.append("train") + if add_all or self.split_val > 0: + self.splits.append("valid") + if add_all or self.split_test > 0: + self.splits.append("test") + + self.data = { + "smiles": {split: [] for split in self.splits}, + "labels": {split: [] for split in self.splits}, + "fps": {split: {} for split in self.splits}, + } + + for model, layers in self.pretrained_models.items(): + predictor = PredictorModule.load_pretrained_model(model, device=self.device) + predictor.featurization.pop("max_num_atoms", None) + + # Featurization + if self.benchmark is None: + assert self.df_path is not None, "df_path must be provided if not using an integrated benchmark" + + # Add a dummy task column (filled with NaN values) in case no such column is provided + base_datamodule = BaseDataModule() + smiles_df = base_datamodule._read_table(self.df_path) + task_cols = [col for col in smiles_df if col.startswith("task_")] + if len(task_cols) == 0: + df_path, file_type = ".".join(self.df_path.split(".")[:-1]), self.df_path.split(".")[-1] + + smiles_df["task_dummy"] = np.nan + + if file_type == "parquet": + smiles_df.to_parquet(f"{df_path}_with_dummy_task_col.{file_type}", index=False) + else: + smiles_df.to_csv(f"{df_path}_with_dummy_task_col.{file_type}", index=False) + + self.df_path = f"{df_path}_with_dummy_task_col.{file_type}" + + task_specific_args = { + "fingerprinting": DatasetProcessingParams( + df_path=self.df_path, + smiles_col="smiles", + label_cols="task_*", + task_level="graph", + splits_path=self.splits_path, + split_type=self.split_type, + split_val=self.split_val, + split_test=self.split_test, + seed=self.data_seed, + ) + } + label_key = "graph_fingerprinting" + + datamodule = MultitaskFromSmilesDataModule( + task_specific_args=task_specific_args, + batch_size_inference=128, + featurization=predictor.featurization, + featurization_n_jobs=0, + processed_graph_data_path=f"{self.mol_cache_dir}/mols/", + ) + + elif self.benchmark == "tdc": + datamodule = TDCBenchmarkDataModule( + tdc_benchmark_names=[self.task], + tdc_train_val_seed=self.data_seed, + batch_size_inference=128, + featurization=predictor.featurization, + featurization_n_jobs=self.num_workers, + processed_graph_data_path=f"{self.mol_cache_dir}/mols/", + ) + label_key = f"graph_{self.task}" + + else: + raise ValueError(f"Invalid benchmark: {self.benchmark}") + + datamodule.prepare_data() + datamodule.setup() + + loader_dict = {} + if "train" in self.splits: + datamodule.train_ds.return_smiles = True + loader_dict["train"] = datamodule.get_dataloader(datamodule.train_ds, shuffle=False, stage="predict") + if "valid" in self.splits: + datamodule.val_ds.return_smiles = True + loader_dict["valid"] = datamodule.get_dataloader(datamodule.val_ds, shuffle=False, stage="predict") + if "test" in self.splits: + datamodule.test_ds.return_smiles = True + loader_dict["test"] = datamodule.get_dataloader(datamodule.test_ds, shuffle=False, stage="predict") + + for split, loader in loader_dict.items(): + if len(self.data["smiles"][split]) == 0: + for batch in loader: + self.data["smiles"][split] += [item for item in batch["smiles"]] + self.data["labels"][split] += batch["labels"][label_key] + + with Fingerprinter(predictor, layers, out_type="torch") as fp: + fps = fp.get_fingerprints_for_dataset(loader, store_dict=True) + for fp_name, fp in fps.items(): + self.data["fps"][split][f"{model}/{fp_name}"] = fp + + os.makedirs(self.fps_cache_dir, exist_ok=True) + torch.save(self.data, f"{self.fps_cache_dir}/fps.pt") + + def setup(self, stage: str) -> None: + # Creating datasets + if stage == "fit": + self.train_dataset = FingerprintDataset(self.labels["train"], self.fps_dict["train"]) + self.valid_dataset = FingerprintDataset(self.labels["valid"], self.fps_dict["valid"]) + else: + self.test_dataset = FingerprintDataset(self.labels["test"], self.fps_dict["test"]) + + def get_fp_dims(self): + fp_dict = next(iter(self.fps_dict.values())) + + return [fp.size(1) for fp in fp_dict.values()] + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True) + + def val_dataloader(self): + return DataLoader(self.valid_dataset, batch_size=len(self.valid_dataset), shuffle=False) + + def test_dataloader(self): + return DataLoader(self.test_dataset, batch_size=len(self.test_dataset), shuffle=False) \ No newline at end of file diff --git a/graphium/finetuning/fingerprinting.py b/graphium/fingerprinting/fingerprinter.py similarity index 79% rename from graphium/finetuning/fingerprinting.py rename to graphium/fingerprinting/fingerprinter.py index 8bfdb5d94..f29023145 100644 --- a/graphium/finetuning/fingerprinting.py +++ b/graphium/fingerprinting/fingerprinter.py @@ -138,8 +138,9 @@ def setup(self): self.network._enable_readout_cache(list(self._spec.keys())) return self - def get_fingerprints_for_batch(self, batch): + def get_fingerprints_for_batch(self, batch, store_dict: bool=False): """Get the fingerprints for a single batch""" + self.network.eval() if not self.network._cache_readouts: raise RuntimeError( @@ -152,18 +153,31 @@ def get_fingerprints_for_batch(self, batch): with torch.inference_mode(): if self.predictor is not None: batch["features"] = self.predictor._convert_features_dtype(batch["features"]) + device = next(iter(self.network.parameters())).device + for key, val in batch["features"].items(): + if isinstance(val, torch.Tensor): + batch["features"][key] = val.to(device) self.network(batch["features"]) - readout_list = [] - for module_name, layers in self._spec.items(): - readout_list.extend( - [self.network._module_map[module_name]._readout_cache[layer].cpu() for layer in layers] - ) - - feats = torch.cat(readout_list, dim=-1) - return self._convert_output_type(feats) + if store_dict: + readout_dict = {} + for module_name, layers in self._spec.items(): + for layer in layers: + readout_dict[f"{module_name}:{layer}"] = self._convert_output_type(self.network._module_map[module_name]._readout_cache[layer].cpu()) - def get_fingerprints_for_dataset(self, dataloader): + return readout_dict + + else: + readout_list = [] + for module_name, layers in self._spec.items(): + readout_list.extend( + [self.network._module_map[module_name]._readout_cache[layer].cpu() for layer in layers] + ) + + feats = torch.cat(readout_list, dim=-1) + return self._convert_output_type(feats) + + def get_fingerprints_for_dataset(self, dataloader, store_dict: bool=False): """Return the fingerprints for an entire dataset""" original_out_type = self._out_type @@ -171,13 +185,29 @@ def get_fingerprints_for_dataset(self, dataloader): fps = [] for batch in tqdm.tqdm(dataloader, desc="Fingerprinting batches"): - feats = self.get_fingerprints_for_batch(batch) + feats = self.get_fingerprints_for_batch(batch, store_dict=store_dict) fps.append(feats) - fps = torch.cat(fps, dim=0) - self._out_type = original_out_type - return self._convert_output_type(fps) + + if store_dict: + fps_dict = fps[0] + for key, value in fps_dict.items(): + fps_dict[key] = [value] + for item in fps[1:]: + for key, value in item.items(): + fps_dict[key].extend([value]) + + self._out_type = original_out_type + for key, value in fps_dict.items(): + fps_dict[key] = self._convert_output_type(torch.cat(value, dim=0)) + + return fps_dict + + else: + fps = torch.cat(fps, dim=0) + + return self._convert_output_type(fps) def teardown(self): """Restore the network to its original state""" @@ -202,4 +232,4 @@ def _convert_output_type(self, feats: torch.Tensor): """Small utility function to convert output types""" if self._out_type == "numpy": feats = feats.numpy() - return feats + return feats \ No newline at end of file diff --git a/graphium/graphium_cpp/commute.cpp b/graphium/graphium_cpp/commute.cpp new file mode 100644 index 000000000..8dc13ec67 --- /dev/null +++ b/graphium/graphium_cpp/commute.cpp @@ -0,0 +1,73 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This file defines and instantiates the `compute_commute_distances` +//! function, declared in commute.h and called from features.cpp + +#include "commute.h" + +#include "electrostatic.h" +#include "spectral.h" + +#include +#include + +// Computes the "commute distance" between each pair of nodes, outputting to `matrix`. +// See the declaration in commute.h for more details. +template +void compute_commute_distances( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const T* weights) { + + if (laplacian_pseudoinverse.size() == 0) { + compute_laplacian_pseudoinverse(n, row_starts, neighbors, data, laplacian_pseudoinverse, weights); + } + + T full_sum = T(0); + if (weights != nullptr) { + for (size_t i = 0, weights_size = row_starts[n]; i < weights_size; ++i) { + full_sum += weights[i]; + } + } + else { + // Unweighted, so just twice the unique edge count + // (each edge appears twice in neighbors) + full_sum = T(row_starts[n]); + } + + // Allocate the memory for the output + matrix.resize(n * n); + + // Compute the distances from the pseudoinverse + for (size_t row = 0, row_diag_index = 0, i = 0; row < n; ++row, row_diag_index += (n + 1)) { + for (size_t col = 0, col_diag_index = 0; col < n; ++col, ++i, col_diag_index += (n + 1)) { + matrix[i] = full_sum * ( + laplacian_pseudoinverse[row_diag_index] + + laplacian_pseudoinverse[col_diag_index] + - 2 * laplacian_pseudoinverse[row*n + col]); + } + } +} + +// Explicit instantiations of `compute_commute_distances` for `float` and `double` +template void compute_commute_distances( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const float* weights); +template void compute_commute_distances( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const double* weights); diff --git a/graphium/graphium_cpp/commute.h b/graphium/graphium_cpp/commute.h new file mode 100644 index 000000000..9d8b2d871 --- /dev/null +++ b/graphium/graphium_cpp/commute.h @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This header file declares the `compute_commute_distances` function, +//! defined in commute.cpp and called from features.cpp + +#pragma once + +#include "spectral.h" + +#include +#include + +//! Computes the "commute distance", `2*E*(P_ii + P_jj - 2*P_ij)`, for each node pair `ij`, +//! where P is the Laplacian pseudoinverse and E is the total number of unique edges. +//! Template type `T` can be `float` or `double`. Implementation is in commute.cpp +//! +//! @param n Number of nodes +//! @param row_starts Array of `n+1` indices into `neighbors`, indicating where each node's +//! neighbors start, plus one at the end to indicate the full length of +//! `neighbors` +//! @param neighbors Concatenated array of all neighbors of all nodes, in order +//! @param data Cache for the eigendecomposition of the graph Laplacian matrix +//! @param laplacian_pseudoinverse If empty, this will be filled with the pseudoinverse of the +//! graph Laplacian matrix, else its contents will be assumed to +//! contain the cached pseudoinverse of the graph Laplacian +//! @param matrix The output commute distances for all `n^2` node pairs +//! @param weights Optional array of edge weights, in the order corresponding with neighbors. +//! If non-null, the distances will be scaled by the sum of all weights, instead +//! of `2*E`. +template +void compute_commute_distances( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const T* weights = nullptr); + +// Instantiation declarations of `compute_commute_distances` for `float` and `double` +// The explicit instantiations are in commute.cpp +extern template void compute_commute_distances( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const float* weights); +extern template void compute_commute_distances( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const double* weights); diff --git a/graphium/graphium_cpp/electrostatic.cpp b/graphium/graphium_cpp/electrostatic.cpp new file mode 100644 index 000000000..2271f1a1c --- /dev/null +++ b/graphium/graphium_cpp/electrostatic.cpp @@ -0,0 +1,119 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This file defines and instantiates the `compute_electrostatic_interactions` +//! and `compute_laplacian_pseudoinverse` functions, +//! declared in electrostatic.h and called from features.cpp and commute.cpp + +#include "electrostatic.h" + +#include "spectral.h" + +#include +#include +#include + +// Computes the pseudoinverse of the graph Laplacian, outputting to `matrix`. +// See the declaration in electrostatic.h for more details. +template +void compute_laplacian_pseudoinverse( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& matrix, + const T* weights) { + + // If we've already computed the eigendecomposition with the correct normalization, + // skip recomputing it. + if (data.eigenvalues.size() != n || data.normalization != Normalization::NONE) { + compute_laplacian_eigendecomp(n, row_starts, neighbors, Normalization::NONE, data, 1, nullptr, weights); + } + + // Allocate the space for the output and initialize to zero. + // The clear() call is so that resize() initializes all values to zero. + matrix.clear(); + matrix.resize(size_t(n) * n, T(0)); + const T maxEigenvalue = data.eigenvalues.back(); + // zero_threshold is an estimate of how accurately the diagonalization + // algorithm determines eigenvalues close to zero. Anything smaller + // should be considered zero for the pseudoinverse. + const T eigendecomp_relative_threshold = T(1e-6); + const T zero_threshold = n * eigendecomp_relative_threshold * maxEigenvalue; + for (size_t eigenIndex = 0; eigenIndex < n; ++eigenIndex) { + // This is a positive semi-definite matrix, so we don't need to take the absolute value + // when checking the threshold. + if (data.eigenvalues[eigenIndex] < zero_threshold) { + continue; + } + const T eigenvalueInverse = T(1) / data.eigenvalues[eigenIndex]; + const T* const eigenvector = data.vectors.data() + eigenIndex * n; + for (size_t row = 0, i = 0; row < n; ++row) { + for (size_t col = 0; col < n; ++col, ++i) { + const T value = eigenvalueInverse * eigenvector[row] * eigenvector[col]; + matrix[i] += value; + } + } + } +} + +// Explicit instantiations of `compute_laplacian_pseudoinverse` for `float` and `double` +template void compute_laplacian_pseudoinverse( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& matrix, + const float* weights); +template void compute_laplacian_pseudoinverse( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& matrix, + const double* weights); + +// Computes the "electrostatic interactions" between each pair of nodes, outputting to `matrix`. +// See the declaration in electrostatic.h for more details. +template +void compute_electrostatic_interactions( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const T* weights) { + + if (laplacian_pseudoinverse.size() == 0) { + compute_laplacian_pseudoinverse(n, row_starts, neighbors, data, laplacian_pseudoinverse, weights); + } + + // Allocate the memory for the output + matrix.resize(n * n); + + // Subtract the diagonal value from each column + for (size_t row = 0, i = 0; row < n; ++row) { + for (size_t col = 0, diag_index = 0; col < n; ++col, ++i, diag_index += (n+1)) { + matrix[i] = laplacian_pseudoinverse[i] - laplacian_pseudoinverse[diag_index]; + } + } +} + +// Explicit instantiations of `compute_electrostatic_interactions` for `float` and `double` +template void compute_electrostatic_interactions( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const float* weights); +template void compute_electrostatic_interactions( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const double* weights); diff --git a/graphium/graphium_cpp/electrostatic.h b/graphium/graphium_cpp/electrostatic.h new file mode 100644 index 000000000..5e10583ab --- /dev/null +++ b/graphium/graphium_cpp/electrostatic.h @@ -0,0 +1,98 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This header file declares the `compute_electrostatic_interactions` +//! and `compute_laplacian_pseudoinverse` functions, +//! defined in electrostatic.cpp and called from features.cpp and commute.cpp + +#pragma once + +#include "spectral.h" + +#include +#include + +//! Computes the pseudoinverse of the graph Laplacian matrix. +//! Template type `T` can be `float` or `double`. Implementation is in electrostatic.cpp +//! +//! @param n Number of nodes +//! @param row_starts Array of `n+1` indices into `neighbors`, indicating where each node's +//! neighbors start, plus one at the end to indicate the full length of +//! `neighbors` +//! @param neighbors Concatenated array of all neighbors of all nodes, in order +//! @param data Cache for the eigendecomposition of the graph Laplacian matrix +//! @param matrix The output pseudoinverse of the graph Laplacian matrix +//! @param weights Optional array of edge weights, in the order corresponding with neighbors. +//! If null, the edge weights are all 1. +template +void compute_laplacian_pseudoinverse( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& matrix, + const T* weights = nullptr); + +// Instantiation declarations of `compute_laplacian_pseudoinverse` for `float` and `double` +// The explicit instantiations are in electrostatic.cpp +extern template void compute_laplacian_pseudoinverse( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& matrix, + const float* weights); +extern template void compute_laplacian_pseudoinverse( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& matrix, + const double* weights); + +//! Computes the "electrostatic interactions", `P_ij - P_jj`, for each node pair `ij`, +//! where P is the Laplacian pseudoinverse. +//! Template type `T` can be `float` or `double`. Implementation is in electrostatic.cpp +//! +//! @param n Number of nodes +//! @param row_starts Array of `n+1` indices into `neighbors`, indicating where each node's +//! neighbors start, plus one at the end to indicate the full length of +//! `neighbors` +//! @param neighbors Concatenated array of all neighbors of all nodes, in order +//! @param data Cache for the eigendecomposition of the graph Laplacian matrix +//! @param laplacian_pseudoinverse If empty, this will be filled with the pseudoinverse of the +//! graph Laplacian matrix, else its contents will be assumed to +//! contain the cached pseudoinverse of the graph Laplacian +//! @param matrix The output electrostatic interactions for all `n^2` node pairs, i.e. the +//! pseudoinverse of the graph Laplacian matrix, with the diagonal subtracted from +//! each column, stored in row-major order. +//! @param weights Optional array of edge weights, in the order corresponding with neighbors. +//! If null, the edge weights are all 1. +template +void compute_electrostatic_interactions( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const T* weights = nullptr); + +// Instantiation declarations of `compute_electrostatic_interactions` for `float` and `double` +// The explicit instantiations are in electrostatic.cpp +extern template void compute_electrostatic_interactions( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const float* weights); +extern template void compute_electrostatic_interactions( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + LaplacianData& data, + std::vector& laplacian_pseudoinverse, + std::vector& matrix, + const double* weights); diff --git a/graphium/graphium_cpp/features.cpp b/graphium/graphium_cpp/features.cpp new file mode 100644 index 000000000..744d255f1 --- /dev/null +++ b/graphium/graphium_cpp/features.cpp @@ -0,0 +1,1538 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This file defines generic feature-related functions, +//! some of which are declared in features.h for exporting to Python. + +#define DEBUG_LOGGING 0 + +#include "features.h" + +#include "commute.h" +#include "electrostatic.h" +#include "float_features.h" +#include "graphormer.h" +#include "one_hot.h" +#include "random_walk.h" +#include "spectral.h" + +#include // For RDKit's addHs +#include // For RDKit's EmbedMolecule + +#include + +// This is called by `featurize_smiles` to parse the SMILES string into an RWMol and +// cache some data about the atoms and bonds. +static GraphData read_graph(const std::string& smiles_string, bool explicit_H) { + std::unique_ptr mol{ parse_mol(smiles_string, explicit_H) }; + + if (!mol) { + return GraphData{ 0, std::unique_ptr(), 0, std::unique_ptr(), std::move(mol) }; + } + + const size_t num_atoms = mol->getNumAtoms(); + const size_t num_bonds = mol->getNumBonds(); +#if DEBUG_LOGGING + printf("# atoms = %zu\n# bonds = %zu\n", num_atoms, num_bonds); +#endif +#if REPORT_STATS + ++statsMolAtomCounts[(num_atoms >= STATS_NUM_MOL_ATOM_COUNTS) ? (STATS_NUM_MOL_ATOM_COUNTS - 1) : num_atoms]; + ++statsMolBondCounts[(num_bonds >= STATS_NUM_MOL_BOND_COUNTS) ? (STATS_NUM_MOL_BOND_COUNTS - 1) : num_bonds]; + statsTotalNumAtoms += num_atoms; + statsTotalNumBonds += num_bonds; +#endif + +#if ORDER_ATOMS + // Determine a canonical ordering of the atoms, if desired. + std::vector atomOrder; + atomOrder.reserve(num_atoms); + RDKit::Canon::rankMolAtoms(*mol, atomOrder); + assert(atomOrder.size() == num_atoms); +#endif + + // Allocate an array of atom data, and fill it from the RDKit atom data. + std::unique_ptr atoms(new CompactAtom[num_atoms]); + for (size_t atomIdx = 0; atomIdx < num_atoms; ++atomIdx) { + const RDKit::Atom* const atom = mol->getAtomWithIdx(atomIdx); + auto atomicNum = atom->getAtomicNum(); + auto totalDegree = atom->getTotalDegree(); + auto formalCharge = atom->getFormalCharge(); + const RDKit::Atom::ChiralType chiralType = atom->getChiralTag(); + auto totalNumHs = atom->getTotalNumHs(); + const RDKit::Atom::HybridizationType hybridization = atom->getHybridization(); + + const bool isAromatic = atom->getIsAromatic(); +#if REPORT_STATS + ++statsElementCounts[(atomicNum < 0 || atomicNum >= STATS_NUM_ELEMENTS) ? (STATS_NUM_ELEMENTS - 1) : atomicNum]; + ++statsDegreeCounts[(totalDegree < 0 || totalDegree >= STATS_NUM_DEGREES) ? (STATS_NUM_DEGREES - 1) : totalDegree]; + size_t formalChargeIndex = formalCharge + int(STATS_CHARGE_OFFSET); + if (formalCharge < -int(STATS_CHARGE_OFFSET)) { + formalChargeIndex = 0; + } + else if (formalCharge > int(STATS_CHARGE_OFFSET)) { + formalChargeIndex = STATS_NUM_CHARGES - 1; + } + + ++statsChargeCounts[formalChargeIndex]; + ++statsChiralityCounts[(size_t(chiralType) >= STATS_NUM_CHIRALITIES) ? (STATS_NUM_CHIRALITIES - 1) : size_t(chiralType)]; + ++statsHCounts[(totalNumHs < 0 || totalNumHs >= STATS_NUM_HS) ? (STATS_NUM_HS - 1) : totalNumHs]; + ++statsHybridizationCounts[(size_t(hybridization) >= STATS_NUM_HYBRIDIZATIONS) ? (STATS_NUM_HYBRIDIZATIONS - 1) : size_t(hybridization)]; + statsAromaticAtomCount += (isAromatic ? 1 : 0); +#endif + const double mass = atom->getMass(); + +#if ORDER_ATOMS + const size_t destAtomIdx = atomOrder[atomIdx]; +#else + const size_t destAtomIdx = atomIdx; +#endif + atoms[destAtomIdx] = CompactAtom{ + uint8_t(atomicNum), + uint8_t(totalDegree), + int8_t(formalCharge), + uint8_t(chiralType), + uint8_t(totalNumHs), + uint8_t(hybridization), + isAromatic, + float(mass) + }; +#if DEBUG_LOGGING + printf( + "atom[%zu] = {%zu, %u, %d, %u, %u, %u, %s, %f}\n", + destAtomIdx, + int(atomicNum), + int(totalDegree), + int(formalCharge), + int(chiralType), + int(totalNumHs), + int(hybridization), + isAromatic ? "true" : "false", + mass + ); +#endif + } + + // Allocate an array of bond data, and fill it from the RDKit bond data. + std::unique_ptr bonds(new CompactBond[num_bonds]); + const RDKit::RingInfo* const ringInfo = mol->getRingInfo(); + for (size_t bondIdx = 0; bondIdx < num_bonds; ++bondIdx) { + const RDKit::Bond* const bond = mol->getBondWithIdx(bondIdx); + const RDKit::Bond::BondType bondType = bond->getBondType(); + const bool isConjugated = bond->getIsConjugated(); + // TODO: Verify that it's the same index as bond->getIdx() + const bool isInRing = (ringInfo->numBondRings(bondIdx) != 0); + const RDKit::Bond::BondStereo stereo = bond->getStereo(); + +#if REPORT_STATS + ++statsBondTypeCounts[(size_t(bondType) >= STATS_NUM_BOND_TYPES) ? (STATS_NUM_BOND_TYPES - 1) : size_t(bondType)]; + ++statsBondStereoCounts[(size_t(stereo) >= STATS_NUM_BOND_STEREOS) ? (STATS_NUM_BOND_STEREOS - 1) : size_t(stereo)]; + statsConjugatedBondCount += (isConjugated ? 1 : 0); + statsBondInRingCount += (isInRing ? 1 : 0); +#endif + + auto beginAtomIdx = bond->getBeginAtomIdx(); + auto endAtomIdx = bond->getEndAtomIdx(); +#if ORDER_ATOMS + beginAtomIdx = atomOrder[beginAtomIdx]; + endAtomIdx = atomOrder[endAtomIdx]; +#endif + bonds[bondIdx] = CompactBond{ + uint8_t(bondType), + isConjugated, + isInRing, + uint8_t(stereo), + beginAtomIdx, + endAtomIdx + }; +#if DEBUG_LOGGING + printf( + "bond[%zu] = {%u, %s, %s, %u, {%u, %u}}\n", + bondIdx, + int(bondType), + isConjugated ? "true" : "false", + isInRing ? "true" : "false", + int(stereo), + beginAtomIdx, + endAtomIdx + ); +#endif + } + + // Return a GraphData structure, taking ownership of the atom and bond data arrays. + return GraphData{ num_atoms, std::move(atoms), num_bonds, std::move(bonds), std::move(mol) }; +} + +// This is a structure for managing the adjacency data (CSR format) +struct NeighborData { + // This owns the data of all 3 arrays, which are actually a single, contiguous allocation. + std::unique_ptr deleter; + + // This is an array of indices into the other two arrays, indicating where + // each atom's neighbors start, including the first entry being 0 for the start of + // atom 0, and the num_atoms entry being 2*num_bonds (2x because each bond is on 2 atoms), + // so there are num_atoms+1 entries. The number of neighbors of an atom i is + // neighbor_starts[i+1]-neighbor_starts[i] + const uint32_t* neighbor_starts; + + // The neighbor atom for each bond, with each atom having an entry for each of + // its neighbors, so each bond occurs twice. + const uint32_t* neighbors; + + // This is in the same order as neighbors, but indicates the index of the bond. + // Each bond occurs twice, so each number occurs twice. + const uint32_t* bond_indices; +}; + +// Construct a NeighborData structure representing the molecule's graph in CSR format. +static NeighborData construct_neighbors(const GraphData& graph) { + const uint32_t num_atoms = graph.num_atoms; + const uint32_t num_bonds = graph.num_bonds; + // Do a single allocation for all 3 arrays. + std::unique_ptr deleter(new uint32_t[num_atoms + 1 + 4 * num_bonds]); + + uint32_t* neighbor_starts = deleter.get(); + for (uint32_t i = 0; i <= num_atoms; ++i) { + neighbor_starts[i] = 0; + } + + // First, get atom neighbor counts + const CompactBond* const bonds = graph.bonds.get(); + for (uint32_t i = 0; i < num_bonds; ++i) { + uint32_t a = bonds[i].beginAtomIdx; + uint32_t b = bonds[i].endAtomIdx; + // NOTE: +1 is because first entry will stay zero. + ++neighbor_starts[a + 1]; + ++neighbor_starts[b + 1]; + } + + // Find the starts by partial-summing the neighbor counts. + // NOTE: +1 is because first entry will stay zero. + std::partial_sum(neighbor_starts + 1, neighbor_starts + 1 + num_atoms, neighbor_starts + 1); + + // Fill in the neighbors and bond_indices arrays. + uint32_t* neighbors = neighbor_starts + num_atoms + 1; + uint32_t* bond_indices = neighbors + 2 * num_bonds; + for (uint32_t i = 0; i < num_bonds; ++i) { + uint32_t a = bonds[i].beginAtomIdx; + uint32_t b = bonds[i].endAtomIdx; + + uint32_t ai = neighbor_starts[a]; + neighbors[ai] = b; + bond_indices[ai] = i; + ++neighbor_starts[a]; + + uint32_t bi = neighbor_starts[b]; + neighbors[bi] = a; + bond_indices[bi] = i; + ++neighbor_starts[b]; + } + + // Shift neighbor_starts forward one after incrementing it. + uint32_t previous = 0; + for (uint32_t i = 0; i < num_atoms; ++i) { + uint32_t next = neighbor_starts[i]; + neighbor_starts[i] = previous; + previous = next; + } + + // NeighborData takes ownership of the memory. + return NeighborData{ std::move(deleter), neighbor_starts, neighbors, bond_indices }; +} + +// This is called by `create_all_features` to create a Torch tensor representing 3D atom +// positions. All atom positions are concatenated into a 1D tensor of length `3*num_atoms`. +template +at::Tensor get_conformer_features( + RDKit::ROMol &mol, + bool already_has_Hs, + c10::ScalarType dtype, + MaskNaNStyle mask_nan_style, + T mask_nan_value, + int64_t &num_nans, + const std::string& smiles_string) { + + const size_t n = mol.getNumAtoms(); + std::unique_ptr conformer_data(new T[3 * n]); + T* data = conformer_data.get(); + + std::unique_ptr mol_with_Hs_added; + RDKit::ROMol* mol_with_Hs = &mol; + if (mol.beginConformers() == mol.endConformers()) { + // No conformers. + // Before generating conformers, it's recommended to add Hs explicitly. + if (!already_has_Hs) { + // Add Hs. They're added at the end, so the original atoms + // will have the same indices as before. + mol_with_Hs_added.reset(new RDKit::RWMol(mol)); + RDKit::MolOps::addHs(*mol_with_Hs_added); + mol_with_Hs = mol_with_Hs_added.get(); + } + + // Default Python arguments to EmbedMolecule + int conformer_id = RDKit::DGeomHelpers::EmbedMolecule( + *mol_with_Hs, + 0, // maxIterations + -1, // seed + true, // clearConfs + false, // useRandomCoords + 2.0, // boxSizeMult + true, // randNedEig + 1, // numZeroFail + nullptr,// coordMap + 1e-3, // optimizerForceTol + false, // ignoreSmoothingFailures + true, // enforceChirality + true, // useExpTorsionAnglePrefs (default in Python; non-default in C++) + true, // useBasicKnowledge (default in Python; non-default in C++) + false, // verbose + 5.0, // basinThresh + false, // onlyHeavyAtomsForRMS + 1, // ETversion + false, // useSmallRingTorsions + false, // useMacrocycleTorsions + false // useMacrocycle14config + ); + + if (conformer_id == -1) { + // Custom arguments as fallback + RDKit::DGeomHelpers::EmbedMolecule( + *mol_with_Hs, + 0, // maxIterations + -1, // seed + true, // clearConfs + false, // useRandomCoords (TODO: consider using true) + 2.0, // boxSizeMult + true, // randNedEig + 1, // numZeroFail + nullptr,// coordMap + 0.1, // optimizerForceTol (changed) + true, // ignoreSmoothingFailures (changed) + false, // enforceChirality (changed) + true, // useExpTorsionAnglePrefs (default in Python; non-default in C++) + true, // useBasicKnowledge (default in Python; non-default in C++) + false, // verbose + 5.0, // basinThresh + false, // onlyHeavyAtomsForRMS + 1, // ETversion + false, // useSmallRingTorsions + false, // useMacrocycleTorsions + false // useMacrocycle14config + ); + } + } + if (mol_with_Hs->beginConformers() == mol_with_Hs->endConformers()) { + // Still no conformers: treat as NaN + for (size_t i = 0; i < 3 * n; ++i) { + data[i] = mask_nan_value; + } + if (mask_nan_style == MaskNaNStyle::REPORT) { + num_nans += 3*n; + } + printf("Warning: Couldn't compute conformer for molecule \"%s\"\n", smiles_string.c_str()); + } + else { + const RDKit::Conformer& conformer = mol_with_Hs->getConformer(); + const auto& positions = conformer.getPositions(); + assert(positions.size() >= n); + for (size_t i = 0; i < n; ++i, data += 3) { + const auto& position = positions[i]; + data[0] = FeatureValues::convertToFeatureType(position.x); + data[1] = FeatureValues::convertToFeatureType(position.y); + data[2] = FeatureValues::convertToFeatureType(position.z); + } + + num_nans += mask_nans(data, 3 * n, mask_nan_style, mask_nan_value); + } + + const int64_t dims[1] = { int64_t(3 * n) }; + return torch_tensor_from_array(std::move(conformer_data), dims, 1, dtype); +} + +// Maps float atom feature name strings to `AtomFloatFeature` enum values +static const std::unordered_map atom_float_name_to_enum{ + {std::string("atomic-number"), int64_t(AtomFloatFeature::ATOMIC_NUMBER)}, + {std::string("mass"), int64_t(AtomFloatFeature::MASS)}, + {std::string("weight"), int64_t(AtomFloatFeature::MASS)}, + {std::string("valence"), int64_t(AtomFloatFeature::VALENCE)}, + {std::string("total-valence"), int64_t(AtomFloatFeature::VALENCE)}, + {std::string("implicit-valence"), int64_t(AtomFloatFeature::IMPLICIT_VALENCE)}, + {std::string("hybridization"), int64_t(AtomFloatFeature::HYBRIDIZATION)}, + {std::string("chirality"), int64_t(AtomFloatFeature::CHIRALITY)}, + {std::string("aromatic"), int64_t(AtomFloatFeature::AROMATIC)}, + {std::string("ring"), int64_t(AtomFloatFeature::IN_RING)}, + {std::string("in-ring"), int64_t(AtomFloatFeature::IN_RING)}, + {std::string("min-ring"), int64_t(AtomFloatFeature::MIN_RING)}, + {std::string("max-ring"), int64_t(AtomFloatFeature::MAX_RING)}, + {std::string("num-ring"), int64_t(AtomFloatFeature::NUM_RING)}, + {std::string("degree"), int64_t(AtomFloatFeature::DEGREE)}, + {std::string("radical-electron"), int64_t(AtomFloatFeature::RADICAL_ELECTRON)}, + {std::string("formal-charge"), int64_t(AtomFloatFeature::FORMAL_CHARGE)}, + {std::string("vdw-radius"), int64_t(AtomFloatFeature::VDW_RADIUS)}, + {std::string("covalent-radius"), int64_t(AtomFloatFeature::COVALENT_RADIUS)}, + {std::string("electronegativity"), int64_t(AtomFloatFeature::ELECTRONEGATIVITY)}, + {std::string("ionization"), int64_t(AtomFloatFeature::IONIZATION)}, + {std::string("first-ionization"), int64_t(AtomFloatFeature::IONIZATION)}, + {std::string("melting-point"), int64_t(AtomFloatFeature::MELTING_POINT)}, + {std::string("metal"), int64_t(AtomFloatFeature::METAL)}, + {std::string("group"), int64_t(AtomFloatFeature::GROUP)}, + {std::string("period"), int64_t(AtomFloatFeature::PERIOD)}, + {std::string("single-bond"), int64_t(AtomFloatFeature::SINGLE_BOND)}, + {std::string("aromatic-bond"), int64_t(AtomFloatFeature::AROMATIC_BOND)}, + {std::string("double-bond"), int64_t(AtomFloatFeature::DOUBLE_BOND)}, + {std::string("triple-bond"), int64_t(AtomFloatFeature::TRIPLE_BOND)}, + {std::string("is-carbon"), int64_t(AtomFloatFeature::IS_CARBON)}, +}; + +// This is called from Python to list atom float features in a format that will be faster +// to interpret inside `featurize_smiles`, passed in the `atom_property_list_float` parameter. +// See the declaration in features.h for more details. +at::Tensor atom_float_feature_names_to_tensor(const std::vector& features) { + const size_t num_features = features.size(); + std::unique_ptr feature_enum_values(new int64_t[num_features]); + for (size_t i = 0; i < num_features; ++i) { + auto it = atom_float_name_to_enum.find(features[i]); + if (it != atom_float_name_to_enum.end()) { + feature_enum_values[i] = it->second; + } + else { + feature_enum_values[i] = int64_t(AtomFloatFeature::UNKNOWN); + } + } + const int64_t dims[1] = { int64_t(num_features) }; + return torch_tensor_from_array(std::move(feature_enum_values), dims, 1, c10::ScalarType::Long); +} + +// Maps one-hot atom feature name strings to `AtomOneHotFeature` enum values +static const std::unordered_map atom_onehot_name_to_enum{ + {std::string("atomic-number"), int64_t(AtomOneHotFeature::ATOMIC_NUM)}, + {std::string("degree"), int64_t(AtomOneHotFeature::DEGREE)}, + {std::string("valence"), int64_t(AtomOneHotFeature::VALENCE)}, + {std::string("total-valence"), int64_t(AtomOneHotFeature::VALENCE)}, + {std::string("implicit-valence"), int64_t(AtomOneHotFeature::IMPLICIT_VALENCE)}, + {std::string("hybridization"), int64_t(AtomOneHotFeature::HYBRIDIZATION)}, + {std::string("chirality"), int64_t(AtomOneHotFeature::CHIRALITY)}, + {std::string("phase"), int64_t(AtomOneHotFeature::PHASE)}, + {std::string("type"), int64_t(AtomOneHotFeature::TYPE)}, + {std::string("group"), int64_t(AtomOneHotFeature::GROUP)}, + {std::string("period"), int64_t(AtomOneHotFeature::PERIOD)}, +}; + +// This is called from Python to list atom one-hot features in a format that will be faster +// to interpret inside `featurize_smiles`, passed in the `atom_property_list_onehot` parameter. +// See the declaration in features.h for more details. +at::Tensor atom_onehot_feature_names_to_tensor(const std::vector& features) { + const size_t num_features = features.size(); + std::unique_ptr feature_enum_values(new int64_t[num_features]); + for (size_t i = 0; i < num_features; ++i) { + auto it = atom_onehot_name_to_enum.find(features[i]); + if (it != atom_onehot_name_to_enum.end()) { + feature_enum_values[i] = it->second; + } + else { + feature_enum_values[i] = int64_t(AtomOneHotFeature::UNKNOWN); + } + } + const int64_t dims[1] = { int64_t(num_features) }; + return torch_tensor_from_array(std::move(feature_enum_values), dims, 1, c10::ScalarType::Long); +} + +// Maps bond feature name strings to `BondFeature` enum values +static const std::unordered_map bond_name_to_enum{ + {std::string("bond-type-onehot"), int64_t(BondFeature::TYPE_ONE_HOT)}, + {std::string("bond-type-float"), int64_t(BondFeature::TYPE_FLOAT)}, + {std::string("stereo"), int64_t(BondFeature::STEREO_ONE_HOT)}, + {std::string("in-ring"), int64_t(BondFeature::IN_RING)}, + {std::string("conjugated"), int64_t(BondFeature::CONJUGATED)}, + {std::string("conformer-bond-length"), int64_t(BondFeature::CONFORMER_BOND_LENGTH)}, + {std::string("estimated-bond-length"), int64_t(BondFeature::ESTIMATED_BOND_LENGTH)}, +}; + +// This is called from Python to list bond features in a format that will be faster +// to interpret inside `featurize_smiles`, passed in the `bond_property_list` parameter. +// See the declaration in features.h for more details. +at::Tensor bond_feature_names_to_tensor(const std::vector& features) { + const size_t num_features = features.size(); + std::unique_ptr feature_enum_values(new int64_t[num_features]); + for (size_t i = 0; i < num_features; ++i) { + auto it = bond_name_to_enum.find(features[i]); + if (it != bond_name_to_enum.end()) { + feature_enum_values[i] = it->second; + } + else { + feature_enum_values[i] = int64_t(BondFeature::UNKNOWN); + } + } + const int64_t dims[1] = { int64_t(num_features) }; + return torch_tensor_from_array(std::move(feature_enum_values), dims, 1, c10::ScalarType::Long); +} + +// Maps positional feature name strings to `PositionalFeature` enum values +static const std::unordered_map positional_name_to_enum{ + {std::string("laplacian_eigvec"), int64_t(PositionalFeature::LAPLACIAN_EIGENVEC)}, + {std::string("laplacian_eigval"), int64_t(PositionalFeature::LAPLACIAN_EIGENVAL)}, + {std::string("rw_return_probs"), int64_t(PositionalFeature::RW_RETURN_PROBS)}, + {std::string("rw_transition_probs"), int64_t(PositionalFeature::RW_TRANSITION_PROBS)}, + {std::string("electrostatic"), int64_t(PositionalFeature::ELECTROSTATIC)}, + {std::string("commute"), int64_t(PositionalFeature::COMMUTE)}, + {std::string("graphormer"), int64_t(PositionalFeature::GRAPHORMER)}, +}; + +// Maps feature level strings to `FeatureLevel` enum values +static const std::unordered_map feature_level_to_enum{ + {std::string("node"), int64_t(FeatureLevel::NODE)}, + {std::string("edge"), int64_t(FeatureLevel::EDGE)}, + {std::string("nodepair"), int64_t(FeatureLevel::NODEPAIR)}, + {std::string("graph"), int64_t(FeatureLevel::GRAPH)}, +}; + +// Maps normalization type strings to `Normalization` enum values +static const std::unordered_map normalization_to_enum{ + {std::string("none"), int64_t(Normalization::NONE)}, + {std::string("inv"), int64_t(Normalization::INVERSE)}, + {std::string("sym"), int64_t(Normalization::SYMMETRIC)}, +}; + +// This is called from Python to list positional features and their options in a format that +// will be faster to interpret inside `featurize_smiles`, passed in the `bond_property_list` +// parameter. +// See the declaration in features.h for more details. +std::pair,at::Tensor> positional_feature_options_to_tensor( + const pybind11::dict& dict) { + size_t num_features = 0; + size_t num_values = 0; + for (const auto& pair : dict) { + // The string keys (pair.first) of the outer dictionary aren't needed for this + if (!pybind11::isinstance(pair.second)) { + continue; + } + pybind11::dict feature_dict = pair.second.cast(); + pybind11::handle feature_name_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "pos_type")); + pybind11::handle feature_level_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "pos_level")); + if (!feature_name_handle || !feature_level_handle) { + continue; + } + std::string feature_name{ pybind11::str(feature_name_handle) }; + std::string feature_level{ pybind11::str(feature_level_handle) }; + + auto feature_it = positional_name_to_enum.find(feature_name); + auto level_it = feature_level_to_enum.find(feature_level); + if (feature_it == positional_name_to_enum.end() || level_it == feature_level_to_enum.end()) { + continue; + } + + PositionalFeature feature = PositionalFeature(feature_it->second); + switch (feature) { + case PositionalFeature::LAPLACIAN_EIGENVEC: + case PositionalFeature::LAPLACIAN_EIGENVAL: { + // Required int num_pos + pybind11::handle num_pos_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "num_pos")); + if (!num_pos_handle || !pybind11::isinstance(num_pos_handle)) { + break; + } + // Optional string normalization + pybind11::handle normalization_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "normalization")); + if (normalization_handle) { + if (!pybind11::isinstance(normalization_handle)) { + break; + } + std::string normalization_name{ pybind11::str(normalization_handle) }; + if (!normalization_to_enum.contains(normalization_name)) { + break; + } + } + // Optional bool disconnected_comp + pybind11::handle disconnected_comp_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "disconnected_comp")); + if (disconnected_comp_handle && !pybind11::isinstance(disconnected_comp_handle)) { + break; + } + num_values += 3 + 3; + ++num_features; + break; + } + case PositionalFeature::RW_RETURN_PROBS: + case PositionalFeature::RW_TRANSITION_PROBS: { + pybind11::handle ksteps_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "ksteps")); + if (!ksteps_handle) { + break; + } + int64_t power_count = 0; + if (pybind11::isinstance(ksteps_handle)) { + power_count = int64_t(ksteps_handle.cast()); + } + else if (pybind11::isinstance(ksteps_handle)) { + power_count = ksteps_handle.cast().size(); + } + if (power_count < 1) { + break; + } + pybind11::handle space_dim_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "space_dim")); + if (space_dim_handle && !pybind11::isinstance(space_dim_handle)) { + break; + } + num_values += 3 + 1 + power_count; + ++num_features; + break; + } + case PositionalFeature::ELECTROSTATIC: + case PositionalFeature::COMMUTE: + case PositionalFeature::GRAPHORMER: + num_values += 3; + ++num_features; + break; + } + } + + std::unique_ptr values(new int64_t[num_values]); + std::vector names(num_features); + + size_t prev_feature_index = 0; + size_t feature_index = 0; + size_t value_index = 0; + for (const auto& pair : dict) { + // The string keys (pair.first) of the outer dictionary aren't needed for this + if (!pybind11::isinstance(pair.second)) { + continue; + } + pybind11::dict feature_dict = pair.second.cast(); + pybind11::handle feature_name_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "pos_type")); + pybind11::handle feature_level_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "pos_level")); + if (!feature_name_handle || !feature_level_handle) { + continue; + } + std::string feature_name{ pybind11::str(feature_name_handle) }; + std::string feature_level{ pybind11::str(feature_level_handle) }; + + auto feature_it = positional_name_to_enum.find(feature_name); + auto level_it = feature_level_to_enum.find(feature_level); + if (feature_it == positional_name_to_enum.end() || level_it == feature_level_to_enum.end()) { + continue; + } + + PositionalFeature feature = PositionalFeature(feature_it->second); + switch (feature) { + case PositionalFeature::LAPLACIAN_EIGENVEC: + case PositionalFeature::LAPLACIAN_EIGENVAL: { + // Required int num_pos + pybind11::handle num_pos_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "num_pos")); + if (!num_pos_handle || !pybind11::isinstance(num_pos_handle)) { + continue; + } + // Optional string normalization + pybind11::handle normalization_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "normalization")); + Normalization normalization = Normalization::NONE; + if (normalization_handle) { + if (!pybind11::isinstance(normalization_handle)) { + continue; + } + std::string normalization_name{ pybind11::str(normalization_handle) }; + auto it = normalization_to_enum.find(normalization_name); + if (it == normalization_to_enum.end()) { + continue; + } + normalization = Normalization(it->second); + } + // Optional bool disconnected_comp + pybind11::handle disconnected_comp_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "disconnected_comp")); + if (disconnected_comp_handle && !pybind11::isinstance(disconnected_comp_handle)) { + continue; + } + values[value_index] = feature_it->second; + values[value_index + 1] = 3; + values[value_index + 2] = level_it->second; + values[value_index + 3] = int64_t(num_pos_handle.cast()); + values[value_index + 4] = int64_t(normalization); + values[value_index + 5] = disconnected_comp_handle ? bool(disconnected_comp_handle.cast()) : true; + value_index += 3 + 3; + ++feature_index; + break; + } + case PositionalFeature::RW_RETURN_PROBS: + case PositionalFeature::RW_TRANSITION_PROBS: { + // Required int or list[int] ksteps + pybind11::handle ksteps_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "ksteps")); + if (!ksteps_handle) { + continue; + } + int64_t power_count = 0; + if (pybind11::isinstance(ksteps_handle)) { + // Integer means use all powers from 1 up to this value, inclusive. + power_count = int64_t(ksteps_handle.cast()); + } + else if (pybind11::isinstance(ksteps_handle)) { + power_count = ksteps_handle.cast().size(); + } + if (power_count < 1) { + break; + } + // Optional int space_dim + pybind11::handle space_dim_handle = pybind11::handle(PyDict_GetItemString(feature_dict.ptr(), "space_dim")); + if (space_dim_handle && !pybind11::isinstance(space_dim_handle)) { + break; + } + values[value_index] = feature_it->second; + values[value_index + 1] = 1 + power_count; + values[value_index + 2] = level_it->second; + + int64_t space_dim = space_dim_handle ? int64_t(space_dim_handle.cast()) : 0; + values[value_index + 3] = space_dim; + if (pybind11::isinstance(ksteps_handle)) { + for (int64_t power = 1; power <= power_count; ++power) { + values[value_index + 3 + power] = power; + } + } + else if (pybind11::isinstance(ksteps_handle)) { + size_t power_index = 0; + int64_t prev_power = 0; + for(const auto item : ksteps_handle.cast()) { + int64_t next_power = pybind11::isinstance(item) ? int64_t(item.cast()) : prev_power; + if (next_power < prev_power) { + // Force the integers to be ascending + next_power = prev_power; + } + values[value_index + 3 + 1 + power_index] = next_power; + prev_power = next_power; + ++power_index; + } + } + value_index += 3 + 1 + power_count; + ++feature_index; + break; + } + case PositionalFeature::ELECTROSTATIC: + case PositionalFeature::COMMUTE: + case PositionalFeature::GRAPHORMER: + values[value_index] = feature_it->second; + values[value_index + 1] = 0; + values[value_index + 2] = level_it->second; + value_index += 3; + ++feature_index; + break; + } + if (feature_index != prev_feature_index) { + names[prev_feature_index] = (level_it->second == int64_t(FeatureLevel::NODE)) ? feature_name : (feature_level + std::string("_") + feature_name); + ++prev_feature_index; + } + } + assert(feature_index == num_features && prev_feature_index == num_features && value_index == num_values); + + const int64_t dims[1] = { int64_t(num_values) }; + return std::make_pair( + std::move(names), + torch_tensor_from_array(std::move(values), dims, 1, c10::ScalarType::Long)); +} + +// This is called by `create_all_features` to create the edge weights Torch tensor, +// including duplicating edges for both directions, and optionally adding self loops. +template +at::Tensor create_edge_weights( + const GraphData& graph, + bool duplicate_edges, + bool add_self_loop, + bool use_bonds_weights, + c10::ScalarType dtype) { + + const size_t edge_coo_count = (duplicate_edges ? 2*graph.num_bonds : graph.num_bonds) + + (add_self_loop ? graph.num_atoms : 0); + std::unique_ptr edge_weights(new T[edge_coo_count]); + + // TODO: Use use_bonds_weights to optionally give weights + // in same order as other edge features + for (size_t i = 0; i < edge_coo_count; ++i) { + edge_weights[i] = FeatureValues::one; + } + + const int64_t dims[1] = { int64_t(edge_coo_count) }; + return torch_tensor_from_array(std::move(edge_weights), dims, 1, dtype); +} + +// This is called by `create_all_features` to create the atom (node) features Torch tensor. +template +at::Tensor create_atom_features( + const GraphData& graph, + const at::Tensor& atom_property_list_onehot, + const at::Tensor& atom_property_list_float, + bool offset_carbon, + c10::ScalarType dtype, + MaskNaNStyle mask_nan_style, + T mask_nan_value, + int64_t &num_nans) { + + const size_t num_onehot_properties = (atom_property_list_onehot.scalar_type() == c10::ScalarType::Long && atom_property_list_onehot.ndimension() == 1) ? atom_property_list_onehot.size(0) : 0; + // NOTE: If TensorBase::data_ptr is ever removed, change it to TensorBase::const_data_ptr. + // Some torch version being used doesn't have const_data_ptr yet. + const int64_t* const property_list_onehot = (num_onehot_properties != 0) ? atom_property_list_onehot.data_ptr() : nullptr; + const size_t num_float_properties = (atom_property_list_float.scalar_type() == c10::ScalarType::Long && atom_property_list_float.ndimension() == 1) ? atom_property_list_float.size(0) : 0; + const int64_t* const property_list_float = (num_float_properties != 0) ? atom_property_list_float.data_ptr() : nullptr; + + size_t single_atom_float_count = num_float_properties; + for (size_t i = 0; i < num_onehot_properties; ++i) { + const int64_t property = property_list_onehot[i]; + single_atom_float_count += get_one_hot_atom_feature_size(AtomOneHotFeature(property)); + } + const size_t atom_float_count = single_atom_float_count * graph.num_atoms; + + std::unique_ptr atom_data(new T[atom_float_count]); + + T* current_atom_data = atom_data.get(); + + for (size_t i = 0; i < num_float_properties; ++i) { + const int64_t property = property_list_float[i]; + get_atom_float_feature(graph, current_atom_data, AtomFloatFeature(property), single_atom_float_count, offset_carbon); + ++current_atom_data; + } + for (size_t i = 0; i < num_onehot_properties; ++i) { + const int64_t property = property_list_onehot[i]; + current_atom_data += get_one_hot_atom_feature(graph, current_atom_data, AtomOneHotFeature(property), single_atom_float_count); + } + + num_nans += mask_nans(atom_data.get(), atom_float_count, mask_nan_style, mask_nan_value); + + const int64_t dims[2] = { int64_t(graph.num_atoms), int64_t(single_atom_float_count) }; + return torch_tensor_from_array(std::move(atom_data), dims, 2, dtype); +} + +// This is called by `create_all_features` to create the bond (edge) features Torch tensor. +template +at::Tensor create_bond_features( + const GraphData& graph, + const at::Tensor& bond_property_list, + const bool duplicate_edges, + bool add_self_loop, + c10::ScalarType dtype, + MaskNaNStyle mask_nan_style, + T mask_nan_value, + int64_t& num_nans) { + + const size_t num_properties = (bond_property_list.scalar_type() == c10::ScalarType::Long && bond_property_list.ndimension() == 1) ? bond_property_list.size(0) : 0; + const int64_t* const property_list = (num_properties != 0) ? bond_property_list.data_ptr() : nullptr; + + size_t single_bond_float_count = 0; + for (size_t i = 0; i < num_properties; ++i) { + const int64_t property = property_list[i]; + if (BondFeature(property) == BondFeature::TYPE_ONE_HOT || BondFeature(property) == BondFeature::STEREO_ONE_HOT) { + single_bond_float_count += get_one_hot_bond_feature_size(BondFeature(property)); + } + else { + ++single_bond_float_count; + } + } + // add_self_loop is only supported if duplicating edges + add_self_loop = add_self_loop && duplicate_edges; + size_t total_edge_count = graph.num_bonds; + if (duplicate_edges) { + total_edge_count = 2*total_edge_count + size_t(add_self_loop); + } + const size_t bond_float_count = single_bond_float_count * total_edge_count; + + std::unique_ptr bond_data(new T[bond_float_count]); + + T* current_bond_data = bond_data.get(); + + // This is the stride length (in floats) for each unique bond + const size_t duplicated_bond_float_count = duplicate_edges ? (2*single_bond_float_count) : single_bond_float_count; + + for (size_t i = 0; i < num_properties; ++i) { + const int64_t property = property_list[i]; + if (BondFeature(property) == BondFeature::TYPE_ONE_HOT || BondFeature(property) == BondFeature::STEREO_ONE_HOT) { + current_bond_data += get_one_hot_bond_feature(graph, current_bond_data, BondFeature(property), duplicated_bond_float_count); + } + else { + get_bond_float_feature(graph, current_bond_data, BondFeature(property), duplicated_bond_float_count); + ++current_bond_data; + } + } + + if (duplicate_edges) { + current_bond_data = bond_data.get(); + // Duplicate the data for each bond + for (size_t i = 0; i < graph.num_bonds; ++i) { + for (size_t j = 0; j < single_bond_float_count; ++j) { + current_bond_data[j+single_bond_float_count] = current_bond_data[j]; + } + current_bond_data += duplicated_bond_float_count; + } + if (add_self_loop) { + // Self loops don't have valid bond data, but don't treat them as NaNs. + // Fill with zeros, instead. + memset(current_bond_data, 0, graph.num_atoms * graph.num_atoms); + } + } + + num_nans += mask_nans(bond_data.get(), bond_float_count, mask_nan_style, mask_nan_value); + + int64_t dims[2] = { int64_t(total_edge_count), int64_t(single_bond_float_count) }; + return torch_tensor_from_array(std::move(bond_data), dims, 2, dtype); +} + +// This is called by `create_positional_features` to convert node-level feature data +// to edge-level feature data. Each edge has the average of the two nodes for all floats +// of the feature, followed by the absolute difference of the two nodes for all floats of +// the feature. +template +void node_to_edge( + std::unique_ptr& output_ptr, + size_t& floats_per_half_edge, + const IN_T* input, + const size_t n, + const size_t floats_per_node, + const GraphData& graph) { + + // Edge order must be consistent with the edges in the graph, + // which is not necessarily lexicographic order. + const size_t num_half_edges = 2*graph.num_bonds; + floats_per_half_edge = 2 * floats_per_node; + output_ptr.reset(new OUT_T[num_half_edges * 2 * floats_per_node]); + OUT_T* output = output_ptr.get(); + for (size_t bond = 0; bond < graph.num_bonds; ++bond, output += 2*floats_per_half_edge) { + const size_t atomi = graph.bonds[bond].beginAtomIdx; + const size_t atomj = graph.bonds[bond].endAtomIdx; + const IN_T* input_i = input + atomi * floats_per_node; + const IN_T* input_j = input + atomj * floats_per_node; + // For each edge, record all of the sums followed by all of the absolute differences + OUT_T* output_sum = output; + OUT_T* output_absdiff = output + floats_per_node; + for (size_t float_index = 0; float_index < floats_per_node; ++float_index) { + const IN_T sum = input_i[float_index] + input_j[float_index]; + const IN_T diff = input_i[float_index] - input_j[float_index]; + const IN_T absdiff = std::abs(diff); + const OUT_T sum_out = FeatureValues::convertToFeatureType(sum); + const OUT_T absdiff_out = FeatureValues::convertToFeatureType(absdiff); + output_sum[float_index] = sum_out; + output_absdiff[float_index] = absdiff_out; + // Same values for opposite direction + output_sum[floats_per_half_edge + float_index] = sum_out; + output_absdiff[floats_per_half_edge + float_index] = absdiff_out; + } + } +} + +// This is called by `create_positional_features` to convert node-level feature data +// to node-pair-level feature data. Each pair has the average of the two nodes for all floats +// of the feature, followed by the absolute difference of the two nodes for all floats of +// the feature. +template +void node_to_node_pair( + std::unique_ptr& output_ptr, + size_t& floats_per_pair, + const IN_T* input, + const size_t n, + const size_t floats_per_node) { + + floats_per_pair = 2 * floats_per_node; + output_ptr.reset(new OUT_T[n * n * floats_per_pair]); + OUT_T* output = output_ptr.get(); + const IN_T* input_i = input; + for (size_t i = 0; i < n; ++i, input_i += floats_per_node) { + const IN_T* input_j = input; + for (size_t j = 0; j < n; ++j, input_j += floats_per_node, output += floats_per_pair) { + // For each pair, record all of the sums followed by all of the absolute differences + OUT_T* output_sum = output; + OUT_T* output_absdiff = output + floats_per_node; + for (size_t float_index = 0; float_index < floats_per_node; ++float_index) { + const IN_T sum = input_i[float_index] + input_j[float_index]; + const IN_T diff = input_i[float_index] - input_j[float_index]; + const IN_T absdiff = std::abs(diff); + output_sum[float_index] = FeatureValues::convertToFeatureType(sum); + output_absdiff[float_index] = FeatureValues::convertToFeatureType(absdiff); + } + } + } +} + +// Used by `node_pair_to_node_helper` +enum class StatOperation { + MINIMUM, + MEAN +}; + +// Used by `node_pair_to_node_helper` +template +void stat_accum(T& accum, T v) { + switch (op) { + case StatOperation::MINIMUM: + accum = (v < accum) ? v : accum; + break; + case StatOperation::MEAN: + accum += v; + break; + } +} + +// Used by `node_pair_to_node_helper` +template +T stat_accum_finish(T accum, size_t n) { + switch (op) { + case StatOperation::MINIMUM: + return accum; + case StatOperation::MEAN: + return accum / n; + } +} + +// Used by `node_pair_to_node` +template +void node_pair_to_node_helper( + OUT_T* output, + const IN_T* input, + const size_t n, + const size_t floats_per_pair, + const size_t node_index) { + + // for each float per pair + for (size_t float_index = 0; float_index < floats_per_pair; ++float_index, output += 2) { + // across all rows (axis 0) of column node_index, then across all columns (axis 1) of row node_index + IN_T accum = input[node_index * floats_per_pair + float_index]; + for (size_t row = 1; row < n; ++row) { + stat_accum(accum, input[(row * n + node_index) * floats_per_pair + float_index]); + } + output[0] = FeatureValues::convertToFeatureType(stat_accum_finish(accum, n)); + accum = input[node_index * n * floats_per_pair + float_index]; + for (size_t col = 1; col < n; ++col) { + stat_accum(accum, input[(node_index * n + col) * floats_per_pair + float_index]); + } + output[1] = FeatureValues::convertToFeatureType(stat_accum_finish(accum, n)); + } +} + +// Used by `node_pair_to_node` +template +void node_pair_to_node_helper_stdev( + OUT_T* output, + const IN_T* input, + const size_t n, + const size_t floats_per_pair, + const size_t node_index) { + + // for each float per pair + for (size_t float_index = 0; float_index < floats_per_pair; ++float_index, output += 2) { + // across all rows (axis 0) of column node_index, then across all columns (axis 1) of row node_index + IN_T v = input[node_index * floats_per_pair + float_index]; + IN_T accum = v; + IN_T accum2 = v * v; + for (size_t row = 1; row < n; ++row) { + v = input[(row * n + node_index) * floats_per_pair + float_index]; + accum += v; + accum2 += v * v; + } + // NOTE: Using divisor n, the default in numpy.std, not n-1, the default elsewhere + accum /= n; + accum2 /= n; + output[0] = FeatureValues::convertToFeatureType(std::sqrt(accum2 - accum*accum)); + + v = input[node_index * n * floats_per_pair + float_index]; + accum = v; + accum2 = v * v; + for (size_t col = 1; col < n; ++col) { + v = input[(node_index * n + col) * floats_per_pair + float_index]; + accum += v; + accum2 += v * v; + } + // NOTE: Using divisor n, the default in numpy.std, not n-1, the default elsewhere + accum /= n; + accum2 /= n; + output[1] = FeatureValues::convertToFeatureType(std::sqrt(accum2 - accum*accum)); + } +} + +// This is called by `create_positional_features` to convert node-pair-level feature data +// to node-level feature data. Each node has the minimum of values in the column and the +// minimum of values in the row, for each float of the feature, followed by the mean of the +// values in the column and of values in the row, for each float, followed by the standard +// deviation of the values in the column and of values in the row, for each float. +template +void node_pair_to_node( + std::unique_ptr& output_ptr, + size_t& floats_per_node, + const IN_T* input, + const size_t n, + const size_t floats_per_pair) { + + const size_t num_ops = 3; + floats_per_node = num_ops * 2 * floats_per_pair; + output_ptr.reset(new OUT_T[n * floats_per_node]); + OUT_T* output = output_ptr.get(); + for (size_t node_index = 0; node_index < n; ++node_index) { + // min, mean, stdev (using divisor N, the default in numpy.std, not N-1, the default elsewhere) + node_pair_to_node_helper(output, input, n, floats_per_pair, node_index); + output += 2 * floats_per_pair; + node_pair_to_node_helper(output, input, n, floats_per_pair, node_index); + output += 2 * floats_per_pair; + node_pair_to_node_helper_stdev(output, input, n, floats_per_pair, node_index); + output += 2 * floats_per_pair; + } +} + +// This is called by `create_positional_features` to convert node-pair-level feature data +// to edge-level feature data. Each edge has the upper triangular floats for the pair, +// followed by the lower triangular floats for the pair. +template +void node_pair_to_edge( + std::unique_ptr& output_ptr, + size_t& floats_per_edge, + const IN_T* input, + const size_t n, + const size_t floats_per_pair, + const GraphData& graph) { + + // Edge order must be consistent with the edges in the graph, + // which is not necessarily lexicographic order. + const size_t num_half_edges = 2*graph.num_bonds; + floats_per_edge = floats_per_pair; + output_ptr.reset(new OUT_T[num_half_edges * floats_per_pair]); + OUT_T* output = output_ptr.get(); + for (size_t bond = 0; bond < graph.num_bonds; ++bond) { + const size_t atomi = graph.bonds[bond].beginAtomIdx; + const size_t atomj = graph.bonds[bond].endAtomIdx; + const IN_T* input_ij = input + ((atomi * n) + atomj) * floats_per_pair; + for (size_t float_index = 0; float_index < floats_per_pair; ++float_index, ++output) { + *output = FeatureValues::convertToFeatureType(input_ij[float_index]); + } + + const IN_T* input_ji = input + ((atomj * n) + atomi) * floats_per_pair; + for (size_t float_index = 0; float_index < floats_per_pair; ++float_index, ++output) { + *output = FeatureValues::convertToFeatureType(input_ji[float_index]); + } + } +} + +// This is called by `create_all_features` to create a Torch tensor for each +// "positional" feature, (not to be confused with a conformer feature.) +template +void create_positional_features( + const GraphData& graph, + const at::Tensor& positional_property_list, + c10::ScalarType dtype, + MaskNaNStyle mask_nan_style, + T mask_nan_value, + int64_t& num_nans, + int64_t& nan_tensor_index, + std::vector& tensors) { + + const size_t size = (positional_property_list.scalar_type() == c10::ScalarType::Long && positional_property_list.ndimension() == 1) ? positional_property_list.size(0) : 0; + const int64_t* const property_list = (size >= 3) ? positional_property_list.data_ptr() : nullptr; + + if (property_list == nullptr) { + return; + } + NeighborData neighbors = construct_neighbors(graph); + + LaplacianData laplacian_data; + LaplacianData laplacian_data_comp; + size_t num_components = 0; // 0 indicates that the components haven't been computed yet + std::vector components; + std::vector laplacian_pseudoinverse; + std::vector matrix; + size_t i = 0; + while (size >= i + 3) { + int64_t property = property_list[i]; + int64_t current_size = property_list[i + 1]; + FeatureLevel feature_level = FeatureLevel(property_list[i + 2]); + i += 3; + if (i + current_size > size || i + current_size < i) { + break; + } + FeatureLevel base_level; + std::unique_ptr base_data; + int64_t base_dims[3] = { 1,1,1 }; + size_t base_dim_count; + if ((property == int64_t(PositionalFeature::LAPLACIAN_EIGENVEC) || property == int64_t(PositionalFeature::LAPLACIAN_EIGENVAL)) && current_size == 3) { + size_t num_pos = (property_list[i] >= 0) ? size_t(property_list[i]) : 0; + Normalization normalization = Normalization(property_list[i + 1]); + bool disconnected_comp = (property_list[i + 2] != 0); + i += 3; + + // The common case is that there's only 1 component, even if disconnected_comp is true, + // so find the number of components, first. + if (disconnected_comp && num_components == 0) { + num_components = find_components(graph.num_atoms, neighbors.neighbor_starts, neighbors.neighbors, components); + } + const bool multiple_components = disconnected_comp && (num_components > 1); + + LaplacianData& current_data = multiple_components ? laplacian_data_comp : laplacian_data; + if (current_data.eigenvalues.size() == 0 || current_data.normalization != normalization) { + compute_laplacian_eigendecomp( + graph.num_atoms, + neighbors.neighbor_starts, + neighbors.neighbors, + normalization, + current_data, + multiple_components ? num_components : 1, + components.data()); + } + + const bool isVec = (property == int64_t(PositionalFeature::LAPLACIAN_EIGENVEC)); + base_level = FeatureLevel::NODE; + base_dims[0] = graph.num_atoms; + base_dims[1] = num_pos; + base_dim_count = 2; + base_data.reset(new double[graph.num_atoms * num_pos]); + + // Ensure exactly the tensor dimensions of num_atoms x num_pos before changing the level. + if (isVec) { + double* data = base_data.get(); + for (size_t atom_index = 0; atom_index < graph.num_atoms; ++atom_index, data += num_pos) { + for (size_t i = 0; i < num_pos && i < graph.num_atoms; ++i) { + // Row eigenvectors to column eigenvectors + data[i] = current_data.vectors[atom_index + i * graph.num_atoms]; + // There's no plausible way the eigenvectors should end up with NaNs, + // so just assert in debug builds. + assert(std::isfinite(data[i])); + } + // NOTE: Do not treat extra values as NaN. The original code filled them with zeros. + for (size_t i = graph.num_atoms; i < num_pos; ++i) { + data[i] = 0; + } + } + } + else { + double* data = base_data.get(); + const bool is_multi_component = (current_data.eigenvalues.size() == size_t(graph.num_atoms)*graph.num_atoms); + assert(is_multi_component || (current_data.eigenvalues.size() == graph.num_atoms)); + size_t source_row_start = 0; + for (size_t atom_index = 0; atom_index < graph.num_atoms; ++atom_index, data += num_pos) { + for (size_t i = 0; i < num_pos && i < graph.num_atoms; ++i) { + // Duplicate the eigenvalue for each atom + data[i] = current_data.eigenvalues[source_row_start + i]; + // There's no plausible way the eigenvalues should end up with NaNs, + // so just assert in debug builds. + assert(std::isfinite(data[i])); + } + // NOTE: Do not treat extra values as NaN. The original code filled them with zeros. + for (size_t i = graph.num_atoms; i < num_pos; ++i) { + data[i] = 0; + } + if (is_multi_component) { + source_row_start += graph.num_atoms; + } + } + } + } + else if ((property == int64_t(PositionalFeature::RW_RETURN_PROBS) || property == int64_t(PositionalFeature::RW_TRANSITION_PROBS)) && current_size >= 1) { + int space_dim = property_list[i]; + ++i; + uint32_t num_powers = current_size - 1; + const uint64_t* powers = reinterpret_cast(property_list + i); + i += num_powers; + const bool isProbs = (property == int64_t(PositionalFeature::RW_RETURN_PROBS)); + RandomWalkDataOption option = isProbs ? RandomWalkDataOption::PROBABILITIES : RandomWalkDataOption::MATRIX; + + std::vector output; + compute_rwse(num_powers, powers, graph.num_atoms, neighbors.neighbor_starts, neighbors.neighbors, option, output, space_dim); + + base_level = isProbs ? FeatureLevel::NODE : FeatureLevel::NODEPAIR; + + base_dims[0] = graph.num_atoms; + base_dims[1] = isProbs ? num_powers : graph.num_atoms; + base_dims[2] = isProbs ? 1 : num_powers; + base_dim_count = isProbs ? 2 : 3; + base_data.reset(new double[output.size()]); + std::copy(output.begin(), output.end(), base_data.get()); + } + else if (property == int64_t(PositionalFeature::ELECTROSTATIC) && current_size == 0) { + const double* weights = nullptr; + compute_electrostatic_interactions(graph.num_atoms, neighbors.neighbor_starts, neighbors.neighbors, laplacian_data, laplacian_pseudoinverse, matrix, weights); + + base_level = FeatureLevel::NODEPAIR; + base_dims[0] = graph.num_atoms; + base_dims[1] = graph.num_atoms; + base_dim_count = 2; + assert(matrix.size() == graph.num_atoms * size_t(graph.num_atoms)); + base_data.reset(new double[matrix.size()]); + std::copy(matrix.begin(), matrix.end(), base_data.get()); + } + else if (property == int64_t(PositionalFeature::COMMUTE) && current_size == 0) { + const double* weights = nullptr; + compute_commute_distances(graph.num_atoms, neighbors.neighbor_starts, neighbors.neighbors, laplacian_data, laplacian_pseudoinverse, matrix, weights); + + base_level = FeatureLevel::NODEPAIR; + base_dims[0] = graph.num_atoms; + base_dims[1] = graph.num_atoms; + base_dim_count = 2; + assert(matrix.size() == graph.num_atoms * size_t(graph.num_atoms)); + base_data.reset(new double[matrix.size()]); + std::copy(matrix.begin(), matrix.end(), base_data.get()); + } + else if (property == int64_t(PositionalFeature::GRAPHORMER) && current_size == 0) { + std::vector> queue; + std::vector all_pairs_distances; + compute_graphormer_distances(graph.num_atoms, neighbors.neighbor_starts, neighbors.neighbors, queue, all_pairs_distances); + + base_level = FeatureLevel::NODEPAIR; + base_dims[0] = graph.num_atoms; + base_dims[1] = graph.num_atoms; + base_dim_count = 2; + assert(all_pairs_distances.size() == graph.num_atoms * size_t(graph.num_atoms)); + base_data.reset(new double[all_pairs_distances.size()]); + std::copy(all_pairs_distances.begin(), all_pairs_distances.end(), base_data.get()); + } + + if (base_data.get() == nullptr) { + continue; + } + + // Change the level and convert to the correct type if needed. + std::unique_ptr final_data; + int64_t final_dims[3]; + std::copy(base_dims, base_dims + 3, final_dims); + size_t final_num_dims = base_dim_count; + if (feature_level != base_level) { + if (base_level == FeatureLevel::NODE) { + if (feature_level == FeatureLevel::EDGE) { + size_t floats_per_half_edge; + node_to_edge(final_data, floats_per_half_edge, base_data.get(), base_dims[0], base_dims[1], graph); + final_dims[0] = 2 * graph.num_bonds; + final_dims[1] = floats_per_half_edge; + final_dims[2] = 1; + } + else if (feature_level == FeatureLevel::NODEPAIR) { + size_t floats_per_pair; + node_to_node_pair(final_data, floats_per_pair, base_data.get(), base_dims[0], base_dims[1]); + final_num_dims = 3; + final_dims[1] = base_dims[0]; + final_dims[2] = floats_per_pair; + } + else { + // Not implemented + } + } + else if (base_level == FeatureLevel::NODEPAIR) { + if (feature_level == FeatureLevel::NODE) { + size_t floats_per_node; + node_pair_to_node(final_data, floats_per_node, base_data.get(), base_dims[0], base_dims[2]); + final_num_dims = 2; + final_dims[1] = floats_per_node; + final_dims[2] = 1; + } + else if (feature_level == FeatureLevel::EDGE) { + size_t floats_per_edge; + node_pair_to_edge(final_data, floats_per_edge, base_data.get(), base_dims[0], base_dims[2], graph); + final_num_dims = 2; + final_dims[0] = 2 * graph.num_bonds; + final_dims[1] = floats_per_edge; + final_dims[2] = 1; + } + else { + // Not implemented + } + } + else { + // Not implemented + } + } + else if (dtype != c10::ScalarType::Double) { + // Just convert + const size_t total_num_floats = final_dims[0] * final_dims[1] * final_dims[2]; + final_data.reset(new T[total_num_floats]); + for (size_t i = 0; i < total_num_floats; ++i) { + final_data[i] = FeatureValues::convertToFeatureType(base_data[i]); + } + } + else { + // Perfect match out of the box + // This will only be hit if T is double, but it still needs to compile + // for other cases, which is why the reinterpret_cast is needed. + final_data.reset(reinterpret_cast(base_data.release())); + } + + if (final_data.get() == nullptr) { + continue; + } + + tensors.push_back(torch_tensor_from_array(std::move(final_data), final_dims, final_num_dims, dtype)); + } +} + +// This is called by `featurize_smiles` after checking the tensor data type to create. +template +void create_all_features( + const GraphData& graph, + const at::Tensor& atom_property_list_onehot, + const at::Tensor& atom_property_list_float, + bool create_conformer_feature, + const at::Tensor& bond_property_list, + const at::Tensor& positional_property_list, + bool duplicate_edges, + bool add_self_loop, + bool already_has_Hs, + bool use_bonds_weights, + bool offset_carbon, + c10::ScalarType dtype, + MaskNaNStyle mask_nan_style, + T mask_nan_value, + int64_t& num_nans, + int64_t& nan_tensor_index, + const std::string& smiles_string, + std::vector& tensors) { + + if (mask_nan_style == MaskNaNStyle::NONE) { + // In some cases, the NONE and REPLACE styles can be combined. + mask_nan_value = FeatureValues::nan_value; + } + at::Tensor edge_weights_tensor = create_edge_weights( + graph, + duplicate_edges, + add_self_loop, + use_bonds_weights, + dtype); + tensors.push_back(std::move(edge_weights_tensor)); + at::Tensor atom_features_tensor = create_atom_features( + graph, + atom_property_list_onehot, + atom_property_list_float, + offset_carbon, + dtype, + mask_nan_style, + mask_nan_value, + num_nans); + tensors.push_back(std::move(atom_features_tensor)); + if (num_nans != 0) { + nan_tensor_index = tensors.size()-1; + } + at::Tensor bond_features_tensor = create_bond_features( + graph, + bond_property_list, + duplicate_edges, + add_self_loop, + dtype, + mask_nan_style, + mask_nan_value, + num_nans); + tensors.push_back(std::move(bond_features_tensor)); + if (nan_tensor_index < 0 && num_nans != 0) { + nan_tensor_index = tensors.size()-1; + } + if (create_conformer_feature) { + at::Tensor conformer_features_tensor = get_conformer_features( + *graph.mol, + already_has_Hs, + dtype, + mask_nan_style, + mask_nan_value, + num_nans, + smiles_string); + tensors.push_back(std::move(conformer_features_tensor)); + if (nan_tensor_index < 0 && num_nans != 0) { + nan_tensor_index = tensors.size(); + } + } + create_positional_features( + graph, + positional_property_list, + dtype, + mask_nan_style, + mask_nan_value, + num_nans, + nan_tensor_index, + tensors); +} + +// `featurize_smiles` is called from Python to get feature tensors for `smiles_string`. +// See the declaration in features.h for more details. +std::tuple, int64_t, int64_t> featurize_smiles( + const std::string& smiles_string, + const at::Tensor& atom_property_list_onehot, + const at::Tensor& atom_property_list_float, + bool create_conformer_feature, + const at::Tensor& bond_property_list, + const at::Tensor& positional_property_list, + bool duplicate_edges, + bool add_self_loop, + bool explicit_H, + bool use_bonds_weights, + bool offset_carbon, + int dtype_int, + int mask_nan_style_int, + double mask_nan_value) { + + GraphData graph = read_graph(smiles_string, explicit_H); + + const size_t edge_coo_count = 2*graph.num_bonds + (add_self_loop ? graph.num_atoms : 0); + std::unique_ptr edge_index(new int64_t[2*edge_coo_count]); + for (size_t i = 0; i < graph.num_bonds; ++i) { + // PyG has all directed edge begin indices followed by all end indices. + edge_index[2*i] = graph.bonds[i].beginAtomIdx; + edge_index[2*i+1] = graph.bonds[i].endAtomIdx; + edge_index[2*i + edge_coo_count] = graph.bonds[i].endAtomIdx; + edge_index[2*i+1 + edge_coo_count] = graph.bonds[i].beginAtomIdx; + } + if (add_self_loop) { + for (size_t i = 0; i < graph.num_atoms; ++i) { + edge_index[2*graph.num_bonds + i] = i; + edge_index[2*graph.num_bonds + i + edge_coo_count] = i; + } + } + int64_t edge_coo_dims[2] = { int64_t(2), int64_t(edge_coo_count) }; + at::Tensor edge_coo_tensor = torch_tensor_from_array(std::move(edge_index), edge_coo_dims, 2, c10::ScalarType::Long); + + std::vector tensors; + tensors.push_back(std::move(edge_coo_tensor)); + c10::ScalarType dtype = c10::ScalarType(dtype_int); + MaskNaNStyle mask_nan_style = MaskNaNStyle(mask_nan_style_int); + int64_t num_nans = 0; + int64_t nan_tensor_index = -1; + if (dtype == c10::ScalarType::Half) { + create_all_features( + graph, + atom_property_list_onehot, + atom_property_list_float, + create_conformer_feature, + bond_property_list, + positional_property_list, + duplicate_edges, + add_self_loop, + explicit_H, + use_bonds_weights, + offset_carbon, + dtype, + mask_nan_style, + FeatureValues::convertToFeatureType(mask_nan_value), + num_nans, + nan_tensor_index, + smiles_string, + tensors); + } + else if (dtype == c10::ScalarType::Float) { + create_all_features( + graph, + atom_property_list_onehot, + atom_property_list_float, + create_conformer_feature, + bond_property_list, + positional_property_list, + duplicate_edges, + add_self_loop, + explicit_H, + use_bonds_weights, + offset_carbon, + dtype, + mask_nan_style, + FeatureValues::convertToFeatureType(mask_nan_value), + num_nans, + nan_tensor_index, + smiles_string, + tensors); + } + else if (dtype == c10::ScalarType::Double) { + create_all_features( + graph, + atom_property_list_onehot, + atom_property_list_float, + create_conformer_feature, + bond_property_list, + positional_property_list, + duplicate_edges, + add_self_loop, + explicit_H, + use_bonds_weights, + offset_carbon, + dtype, + mask_nan_style, + FeatureValues::convertToFeatureType(mask_nan_value), + num_nans, + nan_tensor_index, + smiles_string, + tensors); + } + + return std::make_tuple(tensors, num_nans, nan_tensor_index); +} diff --git a/graphium/graphium_cpp/features.h b/graphium/graphium_cpp/features.h new file mode 100644 index 000000000..0164112db --- /dev/null +++ b/graphium/graphium_cpp/features.h @@ -0,0 +1,392 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This header file declares feature-related enums, functions, and structs, +//! some of which are defined in features.cpp and exported to Python. + +#pragma once + +#include +#include +#include +#include +#include + +// Torch tensor headers +#include +#include + +#include +#include + +// PyBind and Torch headers +#include +#include +#include + +//! Levels at which features or labels can be associated +//! String names are in `feature_level_to_enum` in features.cpp +enum class FeatureLevel { + NODE, //!< Values for each node (atom) + EDGE, //!< Values for each edge (bond) + NODEPAIR, //!< Values for each pair of nodes (pair of atoms), even if no edge (bond) + GRAPH //!< Values for whole molecule +}; + +//! Features for use by `get_atom_float_feature` in float_features.cpp +//! String names are in `atom_float_name_to_enum` in features.cpp +enum class AtomFloatFeature { + ATOMIC_NUMBER, + MASS, + VALENCE, + IMPLICIT_VALENCE, + HYBRIDIZATION, + CHIRALITY, + AROMATIC, + IN_RING, + MIN_RING, + MAX_RING, + NUM_RING, + DEGREE, + RADICAL_ELECTRON, + FORMAL_CHARGE, + VDW_RADIUS, + COVALENT_RADIUS, + ELECTRONEGATIVITY, + IONIZATION, + MELTING_POINT, + METAL, + GROUP, + PERIOD, + SINGLE_BOND, + AROMATIC_BOND, + DOUBLE_BOND, + TRIPLE_BOND, + IS_CARBON, + UNKNOWN +}; + +//! Features for use by `get_one_hot_atom_feature` in one_hot.cpp +//! String names are in `atom_onehot_name_to_enum` in features.cpp +enum class AtomOneHotFeature { + ATOMIC_NUM, //!< Selected atomic numbers specified in `atomicNumList` in one_hot.cpp + DEGREE, //!< Number of explicit neighboring atoms + VALENCE, //!< Total valence of the atom + IMPLICIT_VALENCE, //!< Implicit valence of the atom + HYBRIDIZATION, //!< Hybridizations specified in `hybridizationList` in one_hot.cpp + CHIRALITY, //!< "R", anything other value ("S") or no value, and an extra + //!< chirality-related value (independent of the other two, so can + //!< have a 2nd one value) + PHASE, //!< Specified by `ElementPhase` and `atomicNumToPhase` in one_hot.cpp + TYPE, //!< Specified by `ElementType` and `atomicNumToType` in one_hot.cpp + GROUP, //!< Specified by `atomicNumToGroupTable` in float_features.h + PERIOD, //!< Specified by `atomicNumToPeriodTable` in float_features.h + UNKNOWN //!< Sentinel value. Do not use. +}; + +//! Features for use by `get_one_hot_bond_feature` in one_hot.cpp (if ends in `ONE_HOT`), and +//! `get_bond_float_feature` in float_features.cpp +//! String names are in `bond_name_to_enum` in features.cpp +enum class BondFeature { + TYPE_FLOAT, //!< Bond type as a float, e.g. 2.0 for double, 1.5 for aromatic + TYPE_ONE_HOT, //!< Selected bond types specified in `bondTypeList` in one_hot.cpp + IN_RING, //!< 1.0 if the bond is in at least one ring, else 0.0 + CONJUGATED, //!< 1.0 if the bond is conjugated, else 0.0 + STEREO_ONE_HOT, //!< Selected bond stereo values specified in `bondStereoList` in + //!< one_hot.cpp + CONFORMER_BOND_LENGTH,//!< Length of the bond from a conformer (either first or computed) + ESTIMATED_BOND_LENGTH,//!< Length of the bond estimated with a fast heuristic + UNKNOWN //!< Sentinel value. Do not use. +}; + +//! Supported "positional" features +//! String names are in `positional_name_to_enum` in features.cpp +enum class PositionalFeature { + LAPLACIAN_EIGENVEC, //!< See `compute_laplacian_eigendecomp` in spectral.cpp + LAPLACIAN_EIGENVAL, //!< See `compute_laplacian_eigendecomp` in spectral.cpp + RW_RETURN_PROBS, //!< See `compute_rwse` in random_walk.cpp + RW_TRANSITION_PROBS,//!< See `compute_rwse` in random_walk.cpp + ELECTROSTATIC, //!< See `compute_electrostatic_interactions` in electrostatic.cpp + COMMUTE, //!< See `compute_commute_distances` in commute.cpp + GRAPHORMER //!< See `compute_graphormer_distances` in graphormer.cpp +}; + +//! Options for normalization of graph Laplacian matrix in positional features. +//! Not to be confused with the normalization of label data in `prepare_and_save_data`. +//! String names are in `normalization_to_enum` in features.cpp +enum class Normalization { + NONE, //!< Leaves the matrix unnormalized: `L = D - adj` + SYMMETRIC, //!< Corresponds with `L_s = (D^-0.5) L (D^-0.5)` + INVERSE //!< Corresponds with `L_i = (D^-1) L` +}; + +//! Options for handling NaN or infinite values, passed from Python to `featurize_smiles` in +//! features.cpp. Masking is done in `mask_nans` in features.h +enum class MaskNaNStyle { + NONE, //!< Ignore (keep) NaN values + REPORT, //!< (default behaviour) Count NaN values and report that with the index of the + //!< first tensor that contained NaNs + REPLACE //!< Replace NaN values with a specific value (defaults to zero) +}; + +//! Class for storing all supported options of all positional features, +//! even ones that are mutually exclusive with each other. +struct PositionalOptions { + PositionalFeature feature; + FeatureLevel level; + + //! Powers used by `PositionalFeature::RW_RETURN_PROBS` and `RW_TRANSITION_PROBS` + std::vector rw_powers; + int rw_space_dim = 0; + + uint32_t laplacian_num_pos = 8; + Normalization laplacian_normalization = Normalization::NONE; + bool laplacian_disconnected_comp = true; +}; + +//! Class to help supporting `int16_t` as if it's a 16-bit floating-point (FP16) type, +//! while still supporting `float` (FP32) and `double` (FP64). +template +struct FeatureValues {}; + +//! Explicit instantiation of `FeatureValues` for `int16_t` as if it's a 16-bit +//! floating-point (FP16) type. +template<> struct FeatureValues { + static constexpr int16_t zero = 0x0000; + static constexpr int16_t one = 0x3C00; + static constexpr int16_t nan_value = 0x7C01; + + template + static int16_t convertToFeatureType(T inputType) { + static_assert(std::is_floating_point_v); + return c10::detail::fp16_ieee_from_fp32_value(float(inputType)); + } + + static constexpr bool is_finite(int16_t v) { + // If the exponent bits are the maximum value, v is infinite or NaN + return (v & 0x7C00) != 0x7C00; + } + + using MathType = float; +}; +//! Explicit instantiation of `FeatureValues` for `float` (FP32) +template<> struct FeatureValues { + static constexpr float zero = 0.0f; + static constexpr float one = 1.0f; + static constexpr float nan_value = std::numeric_limits::quiet_NaN(); + + template + static float convertToFeatureType(T inputType) { + static_assert(std::is_floating_point_v); + return float(inputType); + } + + static bool is_finite(float v) { + return std::isfinite(v); + } + + using MathType = float; +}; +//! Explicit instantiation of `FeatureValues` for `double` (FP64) +template<> struct FeatureValues { + static constexpr double zero = 0.0; + static constexpr double one = 1.0; + static constexpr double nan_value = std::numeric_limits::quiet_NaN(); + + template + static double convertToFeatureType(T inputType) { + static_assert(std::is_floating_point_v); + return double(inputType); + } + + static constexpr bool is_finite(double v) { + return std::isfinite(v); + } + + using MathType = double; +}; + +//! Handling for NaN or infinite values in an array, `data`, of `n` values. +//! @see MaskNaNStyle +template +constexpr int64_t mask_nans(T* data, size_t n, MaskNaNStyle style, T value) { + if (style == MaskNaNStyle::NONE) { + return 0; + } + if (style == MaskNaNStyle::REPLACE) { + for (size_t i = 0; i < n; ++i) { + if (!FeatureValues::is_finite(data[i])) { + data[i] = value; + } + } + return 0; + } + + assert(mask_nan_style == MaskNaNStyle::REPORT); + int64_t num_nans = 0; + for (size_t i = 0; i < n; ++i) { + num_nans += (!FeatureValues::is_finite(data[i])); + } + return num_nans; +} + + +// This is just a function to provide to torch, so that we don't have to copy +// the tensor data to put it in a torch tensor, and torch can delete the data +// when it's no longer needed. +template +void deleter(void* p) { + delete[](T*)p; +} + +//! Helper function to construct a torch `Tensor` from a C++ array. +//! The `Tensor` takes ownership of the memory owned by `source`. +template +at::Tensor torch_tensor_from_array(std::unique_ptr&& source, const int64_t* dims, size_t num_dims, c10::ScalarType type) { + return at::from_blob( + source.release(), + at::IntArrayRef(dims, num_dims), + deleter, c10::TensorOptions(type)); +} + +//! Most of the data needed about an atom +struct CompactAtom { + uint8_t atomicNum; + uint8_t totalDegree; + int8_t formalCharge; + uint8_t chiralTag; + uint8_t totalNumHs; + uint8_t hybridization; + bool isAromatic; + float mass; +}; + +//! Most of the data needed about a bond +struct CompactBond { + uint8_t bondType; + bool isConjugated; + bool isInRing; + uint8_t stereo; + uint32_t beginAtomIdx; + uint32_t endAtomIdx; +}; + +//! Data representing a molecule before featurization +struct GraphData { + const size_t num_atoms; + std::unique_ptr atoms; + const size_t num_bonds; + std::unique_ptr bonds; + + std::unique_ptr mol; +}; + + +//! This is called from Python to list atom one-hot features in a format that will be faster +//! to interpret inside `featurize_smiles`, passed in the `atom_property_list_onehot` parameter. +//! Implemented in features.cpp, but declared here so that graphium_cpp.cpp can expose them to +//! Python via pybind. +at::Tensor atom_onehot_feature_names_to_tensor(const std::vector& features); + +//! This is called from Python to list atom float features in a format that will be faster +//! to interpret inside `featurize_smiles`, passed in the `atom_property_list_float` parameter. +//! Implemented in features.cpp, but declared here so that graphium_cpp.cpp can expose them to +//! Python via pybind. +at::Tensor atom_float_feature_names_to_tensor(const std::vector& features); + +//! This is called from Python to list bond features in a format that will be faster +//! to interpret inside `featurize_smiles`, passed in the `bond_property_list` parameter. +//! Implemented in features.cpp, but declared here so that graphium_cpp.cpp can expose them to +//! Python via pybind. +at::Tensor bond_feature_names_to_tensor(const std::vector& features); + +//! This is called from Python to list positional features and their options in a format that +//! will be faster to interpret inside `featurize_smiles`, passed in the `bond_property_list` +//! parameter. Implemented in features.cpp, but declared here so that graphium_cpp.cpp can +//! expose them to Python via pybind. +std::pair,at::Tensor> positional_feature_options_to_tensor(const pybind11::dict& dict); + +//! `featurize_smiles` is called from Python to get feature tensors for `smiles_string`. +//! +//! @param smiles_string SMILES string of the molecule to featurize +//! @param atom_property_list_onehot Torch `Tensor` returned by +//! `atom_onehot_feature_names_to_tensor` representing the +//! list of one-hot atom features to create. +//! @param atom_property_list_float Torch `Tensor` returned by +//! `atom_float_feature_names_to_tensor` representing the +//! list of float atom features to create. +//! @param create_conformer_feature If true, a feature `Tensor` for a conformer is created. +//! @param bond_property_list Torch `Tensor` returned by `bond_feature_names_to_tensor` +//! representing the list of bond features to create. +//! @param positional_property_list Torch `Tensor` returned by +//! `positional_feature_options_to_tensor` representing the list +//! of positional features to create and their options. +//! @param duplicate_edges If true (the default), bond features will have values stored for +//! both edge directions. +//! @param add_self_loop If true (default false), bond features will have values stored for +//! self-edges. +//! @param explicit_H If true (default false), implicit hydrogen atoms will be added explicitly +//! before featurizing. +//! @param use_bonds_weights If true (default false), some features may use the bond type as an +//! edge weight, e.g. 2.0 for double bonds or 1.5 for aromatic bonds. +//! @param offset_carbon If true (the default), some atom float features will subtract a +//! value representing carbon, so that carbon atoms would have value zero. +//! @param dtype_int Value representing the torch data type to use for the output `Tensor`s. +//! Allowed values are 5 (FP16), 6 (FP32), and 7 (FP64), corresponding with +//! `c10::ScalarType`. +//! @param mask_nan_style_int Value representing the behaviour for handling NaN and infinite +//! output values. Allowed values are 0 (ignore NaNs), 1 (return +//! the number of NaNs and the index of the first output `Tensor` +//! containing NaNs), and 2 (replace NaN values with `mask_nan_value`) +//! corresponding with the `MaskNaNStyle` enum. +//! @param mask_nan_value Value to replace NaN and infinite values with if `mask_nan_style_int` +//! is 2 (`MaskNaNStyle::REPLACE`) +//! @return A vector of torch `Tensor`s for the features, as well as two integers representing +//! the number of NaN values and the index of the first output `Tensor` containing NaNs +//! if `mask_nan_style_int` is 1 (`MaskNaNStyle::REPORT`). The first tensor is a 2 by +//! `num_edges` (taking into account `duplicate_edges` and `add_self_loop`) 64-bit +//! integer `Tensor` with the atom indices on either side of each edge. The second +//! tensor is a 1D `Tensor` with length `num_edges`, containing all ones, even if +//! `use_bonds_weights` is true. The third tensor is the atom features tensor, +//! `num_atoms` by the number of values required for all one-hot and float atom +//! features. The fourth tensor is the bond features tensor, `num_edges` by the number +//! of values required for all bond features. If `create_conformer_feature` is true, +//! the fifth tensor is a 1D tensor of length `3*num_atoms` for the conformer positions. +//! The rest of the tensors are the positional feature tensors, one for each positional +//! feature. +std::tuple, int64_t, int64_t> featurize_smiles( + const std::string& smiles_string, + const at::Tensor& atom_property_list_onehot, + const at::Tensor& atom_property_list_float, + bool create_conformer_feature, + const at::Tensor& bond_property_list, + const at::Tensor& positional_property_list, + bool duplicate_edges = true, + bool add_self_loop = false, + bool explicit_H = false, + bool use_bonds_weights = false, + bool offset_carbon = true, + int dtype_int = int(c10::ScalarType::Half), + int mask_nan_style_int = int(MaskNaNStyle::REPORT), + double mask_nan_value = 0.0); + +//! Creates an RWMol from a SMILES string. +//! +//! If `ordered` is true, and the string contains atom classes, called "bookmarks" in RDKit, +//! that form a complete (0-based) ordering of the atoms, the atoms will be reordered according +//! to this explicit order, and the bookmarks will be removed, so that canonical orders +//! can be correctly compared later. +//! +//! This is implemented in graphium_cpp.cpp, but is declared in this header so +//! that both labels.cpp and features.cpp can call it. +std::unique_ptr parse_mol( + const std::string& smiles_string, + bool explicit_H, + bool ordered = true); + +//! Determines a canonical ordering of the atoms in `mol` +//! +//! This is implemented in graphium_cpp.cpp, to keep it near `parse_mol` +void get_canonical_atom_order( + const RDKit::ROMol& mol, + std::vector& atom_order); diff --git a/graphium/graphium_cpp/float_features.cpp b/graphium/graphium_cpp/float_features.cpp new file mode 100644 index 000000000..8b2b27d92 --- /dev/null +++ b/graphium/graphium_cpp/float_features.cpp @@ -0,0 +1,537 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This file defines functions for float-valued atom and bond features, +//! declared in float_features.h and called from features.cpp + +#include "float_features.h" + +#include "features.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +static constexpr double qNaN = std::numeric_limits::quiet_NaN(); + +// This table is from the Electronegativity column of graphium/features/periodic_table.csv +const double electronegativityTable[] = { + 2.20, qNaN, 0.98, 1.57, 2.04, 2.55, 3.04, 3.44, 3.98, + qNaN, 0.93, 1.31, 1.61, 1.90, 2.19, 2.58, 3.16, qNaN, 0.82, + 1.00, 1.36, 1.54, 1.63, 1.66, 1.55, 1.83, 1.88, 1.91, 1.90, + 1.65, 1.81, 2.01, 2.18, 2.55, 2.96, qNaN, 0.82, 0.95, 1.22, + 1.33, 1.60, 2.16, 1.90, 2.20, 2.28, 2.20, 1.93, 1.69, 1.78, + 1.96, 2.05, 2.10, 2.66, qNaN, 0.79, 0.89, 1.10, 1.12, 1.13, + 1.14, 1.13, 1.17, 1.20, 1.20, 1.20, 1.22, 1.23, 1.24, 1.25, + 1.10, 1.27, 1.30, 1.50, 2.36, 1.90, 2.20, 2.20, 2.28, 2.54, + 2.00, 2.04, 2.33, 2.02, 2.00, 2.20, qNaN, 0.70, 0.90, 1.10, + 1.30, 1.50, 1.38, 1.36, 1.28, 1.30, 1.30, 1.30, 1.30, 1.30, + 1.30, 1.30, 1.30, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, + qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, +}; + +// This table is from the FirstIonization column of graphium/features/periodic_table.csv +const double firstIonizationTable[] = { + 13.5984, 24.5874, 5.3917, 9.3227, 8.2980, 11.2603, 14.5341, 13.6181, 17.4228, + 21.5645, 5.1391, 7.6462, 5.9858, 8.1517, 10.4867, 10.3600, 12.9676, 15.7596, 4.3407, + 6.1132, 6.5615, 6.8281, 6.7462, 6.7665, 7.4340, 7.9024, 7.8810, 7.6398, 7.7264, + 9.3942, 5.9993, 7.8994, 9.7886, 9.7524, 11.8138, 13.9996, 4.1771, 5.6949, 6.2173, + 6.6339, 6.7589, 7.0924, 7.2800, 7.3605, 7.4589, 8.3369, 7.5762, 8.9938, 5.7864, + 7.3439, 8.6084, 9.0096, 10.4513, 12.1298, 3.8939, 5.2117, 5.5769, 5.5387, 5.4730, + 5.5250, 5.5820, 5.6437, 5.6704, 6.1501, 5.8638, 5.9389, 6.0215, 6.1077, 6.1843, + 6.2542, 5.4259, 6.8251, 7.5496, 7.8640, 7.8335, 8.4382, 8.9670, 8.9587, 9.2255, + 10.4375, 6.1082, 7.4167, 7.2856, 8.4170, 9.3000, 10.7485, 4.0727, 5.2784, 5.1700, + 6.3067, 5.8900, 6.1941, 6.2657, 6.0262, 5.9738, 5.9915, 6.1979, 6.2817, 6.4200, + 6.5000, 6.5800, 6.6500, qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , + qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , +}; + +// This table is from the MeltingPoint column of graphium/features/periodic_table.csv +const double meltingPointTable[] = { + 14.175, qNaN , 453.85, 1560.15, 2573.15, 3948.15, 63.29, 50.50, 53.63, + 24.703, 371.15, 923.15, 933.40, 1683.15, 317.25, 388.51, 172.31, 83.96, 336.50, + 1112.15, 1812.15, 1933.15, 2175.15, 2130.15, 1519.15, 1808.15, 1768.15, 1726.15, 1357.75, + 692.88, 302.91, 1211.45, 1090.15, 494.15, 266.05, 115.93, 312.79, 1042.15, 1799.15, + 2125.15, 2741.15, 2890.15, 2473.15, 2523.15, 2239.15, 1825.15, 1234.15, 594.33, 429.91, + 505.21, 904.05, 722.80, 386.65, 161.45, 301.70, 1002.15, 1193.15, 1071.15, 1204.15, + 1289.15, 1204.15, 1345.15, 1095.15, 1585.15, 1630.15, 1680.15, 1743.15, 1795.15, 1818.15, + 1097.15, 1936.15, 2500.15, 3269.15, 3680.15, 3453.15, 3300.15, 2716.15, 2045.15, 1337.73, + 234.43, 577.15, 600.75, 544.67, 527.15, 575.15, 202.15, 300.15, 973.15, 1323.15, + 2028.15, 1873.15, 1405.15, 913.15, 913.15, 1267.15, 1340.15, 1259.15, 1925.15, 1133.15, + qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , + qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , qNaN , +}; + +// This table is 2x the Metal column plus the Metalloid column of graphium/features/periodic_table.csv +const uint8_t metalTable[] = { + 0, 0, 2, 2, 1, 0, 0, 0, 0, + 0, 2, 2, 2, 1, 0, 0, 0, 0, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 1, 1, 0, 0, 0, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 1, 1, 0, 0, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 1, 0, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 0, 0, +}; + +// Fills in a particular atom float `feature` into `data`, for all atoms. +// See the declaration in float_features.h for more details. +template +void get_atom_float_feature(const GraphData& graph, T* data, AtomFloatFeature feature, size_t stride, bool offset_carbon) { + const uint32_t num_atoms = graph.num_atoms; + constexpr uint32_t carbon_atomic_num = 6; + using MT = typename FeatureValues::MathType; + switch (feature) { + case AtomFloatFeature::ATOMIC_NUMBER: { + const MT offset = offset_carbon ? carbon_atomic_num : 0; + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType((MT(graph.atoms[i].atomicNum) - offset) / MT(5)); + data += stride; + } + return; + } + case AtomFloatFeature::MASS: { + const RDKit::ROMol& mol = *graph.mol.get(); + constexpr MT carbon_mass = MT(12.011); + const MT offset = offset_carbon ? carbon_mass : 0; + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType((MT(mol.getAtomWithIdx(i)->getMass()) - offset) / MT(10)); + data += stride; + } + return; + } + case AtomFloatFeature::VALENCE: { + const RDKit::ROMol& mol = *graph.mol.get(); + const MT offset = offset_carbon ? 4 : 0; + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(mol.getAtomWithIdx(i)->getTotalValence()) - offset); + data += stride; + } + return; + } + case AtomFloatFeature::IMPLICIT_VALENCE: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(mol.getAtomWithIdx(i)->getImplicitValence())); + data += stride; + } + return; + } + case AtomFloatFeature::HYBRIDIZATION: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(mol.getAtomWithIdx(i)->getHybridization())); + data += stride; + } + return; + } + case AtomFloatFeature::CHIRALITY: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (uint32_t i = 0; i < num_atoms; ++i) { + const RDKit::Atom* atom = mol.getAtomWithIdx(i); + std::string prop; + bool has_prop = atom->getPropIfPresent(RDKit::common_properties::_CIPCode, prop); + *data = FeatureValues::convertToFeatureType(has_prop ? MT(prop.length() == 1 && prop[0] == 'R') : MT(2)); + data += stride; + } + return; + } + case AtomFloatFeature::AROMATIC: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(mol.getAtomWithIdx(i)->getIsAromatic())); + data += stride; + } + return; + } + case AtomFloatFeature::IN_RING: { + const RDKit::ROMol& mol = *graph.mol.get(); + const RDKit::RingInfo* ring_info = mol.getRingInfo(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(ring_info->numAtomRings(i) != 0)); + data += stride; + } + return; + } + case AtomFloatFeature::MIN_RING: { + const RDKit::ROMol& mol = *graph.mol.get(); + const RDKit::RingInfo* ring_info = mol.getRingInfo(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(ring_info->minAtomRingSize(i))); + data += stride; + } + return; + } + case AtomFloatFeature::MAX_RING: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (uint32_t i = 0; i < num_atoms; ++i) { + data[i * stride] = FeatureValues::zero; + } + const RDKit::RingInfo* ring_info = mol.getRingInfo(); + const auto& rings = ring_info->atomRings(); + for (const auto& ring : rings) { + const T size = FeatureValues::convertToFeatureType(MT(ring.size())); + for (const auto atom_index : ring) { + if (size > data[atom_index * stride]) { + data[atom_index * stride] = size; + } + } + } + return; + } + case AtomFloatFeature::NUM_RING: { + const RDKit::ROMol& mol = *graph.mol.get(); + const RDKit::RingInfo* ring_info = mol.getRingInfo(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(ring_info->numAtomRings(i))); + data += stride; + } + return; + } + case AtomFloatFeature::DEGREE: { + const RDKit::ROMol& mol = *graph.mol.get(); + const MT offset = offset_carbon ? 2 : 0; + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(mol.getAtomWithIdx(i)->getTotalDegree()) - offset); + data += stride; + } + return; + } + case AtomFloatFeature::RADICAL_ELECTRON: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(mol.getAtomWithIdx(i)->getNumRadicalElectrons())); + data += stride; + } + return; + } + case AtomFloatFeature::FORMAL_CHARGE: { + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(graph.atoms[i].formalCharge)); + data += stride; + } + return; + } + case AtomFloatFeature::VDW_RADIUS: { + const RDKit::PeriodicTable* table = RDKit::PeriodicTable::getTable(); + const MT offset = offset_carbon ? MT(table->getRvdw(carbon_atomic_num)) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(table->getRvdw(graph.atoms[i].atomicNum)) - offset); + data += stride; + } + return; + } + case AtomFloatFeature::COVALENT_RADIUS: { + const RDKit::PeriodicTable* table = RDKit::PeriodicTable::getTable(); + const MT offset = offset_carbon ? MT(table->getRcovalent(carbon_atomic_num)) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(table->getRcovalent(graph.atoms[i].atomicNum)) - offset); + data += stride; + } + return; + } + case AtomFloatFeature::ELECTRONEGATIVITY: { + const MT offset = offset_carbon ? MT(electronegativityTable[carbon_atomic_num-1]) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i, data += stride) { + const uint32_t atomic_num = graph.atoms[i].atomicNum; + if (atomic_num <= 0 || atomic_num > 118 || electronegativityTable[atomic_num - 1] == 0) { + *data = FeatureValues::nan_value; + continue; + } + *data = FeatureValues::convertToFeatureType(MT(electronegativityTable[atomic_num - 1]) - offset); + } + return; + } + case AtomFloatFeature::IONIZATION: { + const T offset = offset_carbon ? T(firstIonizationTable[carbon_atomic_num-1]) : T(0); + for (uint32_t i = 0; i < num_atoms; ++i, data += stride) { + const uint32_t atomic_num = graph.atoms[i].atomicNum; + if (atomic_num <= 0 || atomic_num > 118 || firstIonizationTable[atomic_num - 1] == 0) { + *data = FeatureValues::nan_value; + continue; + } + *data = FeatureValues::convertToFeatureType((MT(firstIonizationTable[atomic_num - 1]) - offset) / MT(5)); + } + return; + } + case AtomFloatFeature::MELTING_POINT: { + const MT offset = offset_carbon ? MT(meltingPointTable[carbon_atomic_num-1]) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i, data += stride) { + const uint32_t atomic_num = graph.atoms[i].atomicNum; + if (atomic_num <= 0 || atomic_num > 118 || meltingPointTable[atomic_num - 1] == 0) { + *data = FeatureValues::nan_value; + continue; + } + *data = FeatureValues::convertToFeatureType((MT(meltingPointTable[atomic_num - 1]) - offset) / MT(200)); + } + return; + } + case AtomFloatFeature::METAL: { + for (uint32_t i = 0; i < num_atoms; ++i) { + const uint32_t atomic_num = graph.atoms[i].atomicNum; + *data = (atomic_num <= 0 || atomic_num > 118) ? FeatureValues::nan_value : FeatureValues::convertToFeatureType(MT(metalTable[atomic_num - 1])); + data += stride; + } + return; + } + case AtomFloatFeature::GROUP: { + const MT offset = offset_carbon ? MT(atomicNumToGroupTable[carbon_atomic_num - 1]) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i) { + const uint32_t atomic_num = graph.atoms[i].atomicNum; + *data = (atomic_num <= 0 || atomic_num > 118) ? FeatureValues::nan_value : FeatureValues::convertToFeatureType(MT(atomicNumToGroupTable[atomic_num - 1]) - offset); + data += stride; + } + return; + } + case AtomFloatFeature::PERIOD: { + const MT offset = offset_carbon ? MT(atomicNumToPeriodTable[carbon_atomic_num - 1]) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i) { + const uint32_t atomic_num = graph.atoms[i].atomicNum; + *data = (atomic_num <= 0 || atomic_num > 118) ? FeatureValues::nan_value : FeatureValues::convertToFeatureType(MT(atomicNumToPeriodTable[atomic_num - 1]) - offset); + data += stride; + } + return; + } + case AtomFloatFeature::SINGLE_BOND: + case AtomFloatFeature::AROMATIC_BOND: + case AtomFloatFeature::DOUBLE_BOND: + case AtomFloatFeature::TRIPLE_BOND: + { + const RDKit::ROMol& mol = *graph.mol.get(); + const RDKit::Bond::BondType type = + (feature == AtomFloatFeature::SINGLE_BOND) ? RDKit::Bond::SINGLE : ( + (feature == AtomFloatFeature::AROMATIC_BOND) ? RDKit::Bond::AROMATIC : ( + (feature == AtomFloatFeature::DOUBLE_BOND) ? RDKit::Bond::DOUBLE : ( + RDKit::Bond::TRIPLE))); + for (uint32_t i = 0; i < num_atoms; ++i) { + auto [begin, end] = mol.getAtomBonds(mol.getAtomWithIdx(i)); + uint32_t count = 0; + for (; begin != end; ++begin) { + count += (mol[*begin]->getBondType() == type); + } + *data = FeatureValues::convertToFeatureType(MT(count)); + data += stride; + } + return; + } + case AtomFloatFeature::IS_CARBON: { + const MT offset = offset_carbon ? MT(1) : MT(0); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::convertToFeatureType(MT(graph.atoms[i].atomicNum == carbon_atomic_num) - offset); + data += stride; + } + return; + } + default: + break; + } + + // Missing implementation + assert(0); + for (uint32_t i = 0; i < num_atoms; ++i) { + *data = FeatureValues::nan_value; + data += stride; + } +} + +// Explicit instantiations, so that the function can be templated +// but still be used from other cpp files. +template void get_atom_float_feature(const GraphData& graph, int16_t* data, AtomFloatFeature feature, size_t stride, bool offset_carbon); +template void get_atom_float_feature(const GraphData& graph, float* data, AtomFloatFeature feature, size_t stride, bool offset_carbon); +template void get_atom_float_feature(const GraphData& graph, double* data, AtomFloatFeature feature, size_t stride, bool offset_carbon); + +// This table is from the SingleBondRadius column of graphium/features/periodic_table.csv +const double single_bond_lengths[] = { + 0.32, 0.46, 1.33, 1.02, 0.85, 0.75, 0.71, 0.63, 0.64, + 0.67, 1.55, 1.39, 1.26, 1.16, 1.11, 1.03, 0.99, 0.96, 1.96, + 1.71, 1.48, 1.36, 1.34, 1.22, 1.19, 1.16, 1.11, 1.10, 1.12, + 1.18, 1.24, 1.21, 1.21, 1.16, 1.14, 1.17, 2.10, 1.85, 1.63, + 1.54, 1.47, 1.38, 1.28, 1.25, 1.25, 1.20, 1.28, 1.36, 1.42, + 1.40, 1.40, 1.36, 1.33, 1.31, 2.32, 1.96, 1.80, 1.63, 1.76, + 1.74, 1.73, 1.72, 1.68, 1.69, 1.68, 1.67, 1.66, 1.65, 1.64, + 1.70, 1.62, 1.52, 1.46, 1.37, 1.31, 1.29, 1.22, 1.23, 1.24, + 1.33, 1.44, 1.44, 1.51, 1.45, 1.47, 1.42, 2.23, 2.01, 1.86, + 1.75, 1.69, 1.70, 1.71, 1.72, 1.66, 1.66, 1.68, 1.68, 1.65, + 1.67, 1.73, 1.76, 1.61, 1.57, 1.49, 1.43, 1.41, 1.34, 1.29, + 1.28, 1.21, 1.22, 1.36, 1.43, 1.62, 1.75, 1.65, 1.57, +}; +// This table is from the DoubleBondRadius column of graphium/features/periodic_table.csv +const double double_bond_lengths[] = { + qNaN, qNaN, 1.24, 0.90, 0.78, 0.67, 0.60, 0.57, 0.59, + 0.96, 1.60, 1.32, 1.13, 1.07, 1.02, 0.94, 0.95, 1.07, 1.93, + 1.47, 1.16, 1.17, 1.12, 1.11, 1.05, 1.09, 1.03, 1.01, 1.15, + 1.20, 1.17, 1.11, 1.14, 1.07, 1.09, 1.21, 2.02, 1.57, 1.30, + 1.27, 1.25, 1.21, 1.20, 1.14, 1.10, 1.17, 1.39, 1.44, 1.36, + 1.30, 1.33, 1.28, 1.29, 1.35, 2.09, 1.61, 1.39, 1.37, 1.38, + 1.37, 1.35, 1.34, 1.34, 1.35, 1.35, 1.33, 1.33, 1.33, 1.31, + 1.29, 1.31, 1.28, 1.26, 1.20, 1.19, 1.16, 1.15, 1.12, 1.21, + 1.42, 1.42, 1.35, 1.41, 1.35, 1.38, 1.45, 2.18, 1.73, 1.53, + 1.43, 1.38, 1.34, 1.36, 1.35, 1.35, 1.36, 1.39, 1.40, 1.40, + qNaN, 1.39, qNaN, 1.41, 1.40, 1.36, 1.28, 1.28, 1.25, 1.25, + 1.16, 1.16, 1.37, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, +}; +// This table is from the TripleBondRadius column of graphium/features/periodic_table.csv +const double triple_bond_lengths[] = { + qNaN, qNaN, qNaN, 0.85, 0.73, 0.60, 0.54, 0.53, 0.53, + qNaN, qNaN, 1.27, 1.11, 1.02, 0.94, 0.95, 0.93, 0.96, qNaN, + 1.33, 1.14, 1.08, 1.06, 1.03, 1.03, 1.02, 0.96, 1.01, 1.20, + qNaN, 1.21, 1.14, 1.06, 1.07, 1.10, 1.08, qNaN, 1.39, 1.24, + 1.21, 1.16, 1.13, 1.10, 1.03, 1.06, 1.12, 1.37, qNaN, 1.46, + 1.32, 1.27, 1.21, 1.25, 1.22, qNaN, 1.49, 1.39, 1.31, 1.28, + qNaN, qNaN, qNaN, qNaN, 1.32, qNaN, qNaN, qNaN, qNaN, qNaN, + qNaN, 1.31, 1.22, 1.19, 1.15, 1.10, 1.09, 1.07, 1.10, 1.23, + qNaN, 1.50, 1.37, 1.35, 1.29, 1.38, 1.33, qNaN, 1.59, 1.40, + 1.36, 1.29, 1.18, 1.16, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, + qNaN, qNaN, qNaN, qNaN, 1.31, 1.26, 1.21, 1.19, 1.18, 1.13, + 1.12, 1.18, 1.30, qNaN, qNaN, qNaN, qNaN, qNaN, qNaN, +}; + +// Fills in a particular bond float `feature` into `data`, for all bonds. +// See the declaration in float_features.h for more details. +template +void get_bond_float_feature(const GraphData& graph, T* data, BondFeature feature, size_t stride) { + const uint32_t num_bonds = graph.num_bonds; + switch (feature) { + case BondFeature::TYPE_FLOAT: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (size_t i = 0; i < num_bonds; ++i, data += stride) { + auto type = graph.bonds[i].bondType; + double value = 0; + switch (type) { + case RDKit::Bond::BondType::SINGLE: value = 1.0; break; + case RDKit::Bond::BondType::DOUBLE: value = 2.0; break; + case RDKit::Bond::BondType::TRIPLE: value = 3.0; break; + case RDKit::Bond::BondType::AROMATIC: value = 1.5; break; + default: value = mol.getBondWithIdx(i)->getBondTypeAsDouble(); + } + *data = FeatureValues::convertToFeatureType(value); + } + return; + } + case BondFeature::IN_RING: { + const RDKit::ROMol& mol = *graph.mol.get(); + for (size_t i = 0; i < num_bonds; ++i, data += stride) { + bool is_in_ring = mol.getRingInfo()->numBondRings(i) != 0; + *data = is_in_ring ? FeatureValues::one : FeatureValues::zero; + } + return; + } + case BondFeature::CONJUGATED: { + for (size_t i = 0; i < num_bonds; ++i, data += stride) { + bool is_conjugated = graph.bonds[i].isConjugated; + *data = is_conjugated ? FeatureValues::one : FeatureValues::zero; + } + return; + } + case BondFeature::CONFORMER_BOND_LENGTH: { + RDKit::ROMol& mol = *graph.mol.get(); + if (mol.beginConformers() == mol.endConformers()) { + // Try to generate a conformer + RDKit::DGeomHelpers::EmbedParameters params; + params.enforceChirality = false; + params.ignoreSmoothingFailures = true; + params.useBasicKnowledge = true; + params.useExpTorsionAnglePrefs = true; + params.optimizerForceTol = 0.1; + int id = RDKit::DGeomHelpers::EmbedMolecule(mol, params); + if (id == -1) { + // Failed to generate a conformer + const uint32_t num_bonds = graph.num_bonds; + for (uint32_t i = 0; i < num_bonds; ++i, data += stride) { + *data = FeatureValues::nan_value; + } + return; + } + assert(mol.beginConformers() != mol.endConformers()); + } + const RDKit::Conformer& conformer = mol.getConformer(); + const auto& positions = conformer.getPositions(); + for (uint32_t i = 0; i < num_bonds; ++i, data += stride) { + const uint32_t begin_atom = graph.bonds[i].beginAtomIdx; + const uint32_t end_atom = graph.bonds[i].endAtomIdx; + const RDGeom::Point3D diff = (positions[end_atom] - positions[begin_atom]); + // Unfortunately, the length() function on Point3D is virtual, so compute it manually. + const double length = std::sqrt(diff.x * diff.x + diff.y * diff.y + diff.z * diff.z); + *data = FeatureValues::convertToFeatureType(length); + } + return; + } + case BondFeature::ESTIMATED_BOND_LENGTH: { + for (uint32_t i = 0; i < num_bonds; ++i, data += stride) { + const uint32_t begin_atom = graph.bonds[i].beginAtomIdx; + const uint32_t end_atom = graph.bonds[i].endAtomIdx; + const int atomic_num1 = graph.atoms[begin_atom].atomicNum; + const bool atom1_valid = (atomic_num1 >= 1 && atomic_num1 <= 118); + const int atomic_num2 = graph.atoms[end_atom].atomicNum; + const bool atom2_valid = (atomic_num2 >= 1 && atomic_num2 <= 118); + assert(atom1_valid && atom2_valid); + if (!atom1_valid || !atom2_valid) { + *data = FeatureValues::nan_value; + continue; + } + + const auto type = graph.bonds[i].bondType; + if (type == RDKit::Bond::BondType::SINGLE) { + // All atoms have a single bond length + *data = FeatureValues::convertToFeatureType( + single_bond_lengths[atomic_num1 - 1] + single_bond_lengths[atomic_num2 - 1]); + continue; + } + if (type == RDKit::Bond::BondType::DOUBLE) { + const double length1 = (double_bond_lengths[atomic_num1 - 1] >= 0) ? + double_bond_lengths[atomic_num1 - 1] : single_bond_lengths[atomic_num1 - 1]; + const double length2 = (double_bond_lengths[atomic_num2 - 1] >= 0) ? + double_bond_lengths[atomic_num2 - 1] : single_bond_lengths[atomic_num2 - 1]; + *data = FeatureValues::convertToFeatureType(length1 + length2); + continue; + } + if (type == RDKit::Bond::BondType::TRIPLE) { + const double length1 = (triple_bond_lengths[atomic_num1 - 1] >= 0) ? + triple_bond_lengths[atomic_num1 - 1] : single_bond_lengths[atomic_num1 - 1]; + const double length2 = (triple_bond_lengths[atomic_num2 - 1] >= 0) ? + triple_bond_lengths[atomic_num2 - 1] : single_bond_lengths[atomic_num2 - 1]; + *data = FeatureValues::convertToFeatureType(length1 + length2); + continue; + } + if (type != RDKit::Bond::BondType::AROMATIC) { + *data = FeatureValues::nan_value; + } + + // Aromatic case + double length1 = single_bond_lengths[atomic_num1 - 1]; + double length2 = single_bond_lengths[atomic_num2 - 1]; + if (double_bond_lengths[atomic_num1] >= 0) { + length1 = 0.5 * (length1 + double_bond_lengths[atomic_num1 - 1]); + } + if (double_bond_lengths[atomic_num2] >= 0) { + length2 = 0.5 * (length2 + double_bond_lengths[atomic_num2 - 1]); + } + *data = FeatureValues::convertToFeatureType(length1 + length2); + } + return; + } + default: + // Missing implementation + assert(0); + for (uint32_t i = 0; i < num_bonds; ++i, data += stride) { + *data = FeatureValues::nan_value; + } + return; + } +} + +// Explicit instantiations, so that the function can be templated +// but still be used from other cpp files. +template void get_bond_float_feature(const GraphData& graph, int16_t* data, BondFeature feature, size_t stride); +template void get_bond_float_feature(const GraphData& graph, float* data, BondFeature feature, size_t stride); +template void get_bond_float_feature(const GraphData& graph, double* data, BondFeature feature, size_t stride); diff --git a/graphium/graphium_cpp/float_features.h b/graphium/graphium_cpp/float_features.h new file mode 100644 index 000000000..9ec49d97a --- /dev/null +++ b/graphium/graphium_cpp/float_features.h @@ -0,0 +1,93 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This header file declares functions for float-valued atom and bond features, +//! defined in float_features.cpp and called from features.cpp + +#pragma once + +#include "features.h" + +#include + +#include + +//! Fills in a particular atom float `feature` into `data`, for all atoms. +//! Template type `T` can be `int16_t` (FP16), `float`, or `double`. +//! Implementation is in float_features.cpp +//! +//! @param graph Molecule containing the source data +//! @param data Destination array, pointing to the first atom's value for this +//! feature to be filled in. Each atom's data for this feature is just 1 value, +//! but because different features are interleaved, the values for +//! each atom are spaced `stride` values apart. +//! @param feature The atom feature to write into `data` +//! @param stride The number of values from the beginning of one atom's data to the beginning +//! of the next atom's data, which may include values for other features +//! @param offset_carbon If true (the default), a reference value for carbon is subtracted, +//! so that carbon atoms would usually have value zero, if applicable. +//! @see AtomFloatFeature +template +void get_atom_float_feature(const GraphData& graph, T* data, AtomFloatFeature feature, size_t stride, bool offset_carbon = true); + +// Instantiation declarations of `get_atom_float_feature` for `int16_t` (FP16), +// `float` (FP32), and `double` (FP64). The explicit instantiations are in float_features.cpp +extern template void get_atom_float_feature(const GraphData& graph, int16_t* data, AtomFloatFeature feature, size_t stride, bool offset_carbon); +extern template void get_atom_float_feature(const GraphData& graph, float* data, AtomFloatFeature feature, size_t stride, bool offset_carbon); +extern template void get_atom_float_feature(const GraphData& graph, double* data, AtomFloatFeature feature, size_t stride, bool offset_carbon); + +//! Fills in a particular bond float `feature` into `data`, for all bonds. +//! Template type `T` can be `int16_t` (FP16), `float`, or `double`. +//! Implementation is in float_features.cpp +//! +//! @param graph Molecule containing the source data +//! @param data Destination array, pointing to the first bond's value for this +//! feature to be filled in. Each bond's data for this feature is just 1 value, +//! but because different features are interleaved, the values for +//! each bond are spaced `stride` values apart. +//! @param feature The bond feature to write into `data` +//! @param stride The number of values from the beginning of one bond's data to the beginning +//! of the next bond's data, which may include values for other features +//! @see BondFeature +template +void get_bond_float_feature(const GraphData& graph, T* data, BondFeature feature, size_t stride); + +// Instantiation declarations of `get_bond_float_feature` for `int16_t` (FP16), +// `float` (FP32), and `double` (FP64). The explicit instantiations are in float_features.cpp +extern template void get_bond_float_feature(const GraphData& graph, int16_t* data, BondFeature feature, size_t stride); +extern template void get_bond_float_feature(const GraphData& graph, float* data, BondFeature feature, size_t stride); +extern template void get_bond_float_feature(const GraphData& graph, double* data, BondFeature feature, size_t stride); + +// This table is from the Group column of graphium/features/periodic_table.csv +constexpr uint8_t atomicNumToGroupTable[] = { + 1, 18, 1, 2, 13, 14, 15, 16, 17, + 18, 1, 2, 13, 14, 15, 16, 17, 18, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 1, 2, 3, 19, 19, + 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, + 19, 19, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 1, 2, 3, + 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, + 19, 19, 19, 19, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, +}; +constexpr size_t groupCount = 19; + +// This table is from the Period column of graphium/features/periodic_table.csv +constexpr uint8_t atomicNumToPeriodTable[] = { + 1, 1, 2, 2, 2, 2, 2, 2, 2, + 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7 +}; +constexpr size_t periodCount = 7; diff --git a/graphium/graphium_cpp/graphium_cpp.cpp b/graphium/graphium_cpp/graphium_cpp.cpp new file mode 100644 index 000000000..84909d083 --- /dev/null +++ b/graphium/graphium_cpp/graphium_cpp.cpp @@ -0,0 +1,139 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This file specifies which functions are exported to Python, +//! as well as defining `parse_mol` and `get_canonical_atom_order`, +//! declared in features.h and called from features.cpp and labels.cpp + +#include "features.h" +#include "labels.h" + +// C++ standard library headers +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// RDKit headers +#include +#include +#include +#include +#include +#include +#include +#include + + +// PyBind and Torch headers for use by library to be imported by Python +#include +#include +#include +#include + +// Creates an RWMol from a SMILES string. +// See the declaration in features.h for more details. +std::unique_ptr parse_mol( + const std::string& smiles_string, + bool explicit_H, + bool ordered) { + + // Parse SMILES string with default options + RDKit::SmilesParserParams params; + std::unique_ptr mol{ RDKit::SmilesToMol(smiles_string, params) }; + if (!mol) { + return mol; + } + + if (ordered) { + // Do not order atoms to the canonical order. + // Order them based only on the atom map, and only + // if they indicate a valid order. + const unsigned int num_atoms = mol->getNumAtoms(); + std::vector atom_order(num_atoms); + for (unsigned int i = 0; i < num_atoms; ++i) { + RDKit::Atom* atom = mol->getAtomWithIdx(i); + if (!atom->hasProp(RDKit::common_properties::molAtomMapNumber)) { + ordered = false; + // Don't break, because the property needs to be cleared + // from any following atoms that might have it. + } + else { + atom_order[i] = (unsigned int)atom->getAtomMapNum(); + + // 0-based, and must be in range + if (atom_order[i] >= num_atoms) { + ordered = false; + } + + // Clear the property, so that any equivalent molecules will + // get the same canoncial order. + atom->clearProp(RDKit::common_properties::molAtomMapNumber); + } + } + + if (ordered) { + // Invert the order + // Use max value as a "not found yet" value + constexpr unsigned int not_found_value = std::numeric_limits::max(); + std::vector inverse_order(num_atoms, not_found_value); + for (unsigned int i = 0; i < num_atoms; ++i) { + unsigned int index = atom_order[i]; + // Can't have the same index twice + if (inverse_order[index] != not_found_value) { + ordered = false; + break; + } + inverse_order[index] = i; + } + + if (ordered) { + // Reorder the atoms to the explicit order + mol.reset(static_cast(RDKit::MolOps::renumberAtoms(*mol, inverse_order))); + } + } + } + if (explicit_H) { + RDKit::MolOps::addHs(*mol); + } + else { + // Default params for SmilesToMol already calls removeHs, + // and calling it again shouldn't have any net effect. + //RDKit::MolOps::removeHs(*mol); + } + return mol; +} + +// Determines a canonical ordering of the atoms in `mol` +// See the declaration in features.h for more details. +void get_canonical_atom_order(const RDKit::ROMol& mol, std::vector& atom_order) { + RDKit::Canon::rankMolAtoms(mol, atom_order); + assert(atom_order.size() == mol->getNumAtoms()); +} + +// This is necessary to export Python functions in a Python module named graphium_cpp. +PYBIND11_MODULE(graphium_cpp, m) { + m.doc() = "graphium C++ plugin"; // Python module docstring + + // Functions in labels.cpp + m.def("load_num_cols_and_dtypes", &load_num_cols_and_dtypes, "Loads from a cache file, a list of integers representing the number of columns in each task, and a list of integers representing the torch ScalarType of the task's data."); + m.def("load_metadata_tensors", &load_metadata_tensors, "Loads from cache files for a specific stage, a torch tensor containing all SMILES strings contatenated, another with the offsets of all SMILES strings, two for the nubmer of nodes and edges in each molecule, and optionally another representing the offsets of molecules in files."); + m.def("load_stats", &load_stats, "Loads from a cache file of a specific task, the stats for each column, for use in denormalization."); + m.def("concatenate_strings", &concatenate_strings, "Accepts a Numpy array of strings or Python list of strings and returns a PyTorch tensor of all of the characters and another tensor containing indices into the other tensor indicating where each string begins."); + m.def("prepare_and_save_data", &prepare_and_save_data, "Accepts a dict mapping dataset (task) names to dicts with \"smiles\", \"labels\", and \"label_offsets\" data, and returns the data that would be returned by load_metadata_tensors, load_stats, and load_num_cols_and_dtypes."); + m.def("load_labels_from_index", &load_labels_from_index, "Loads label data from disk, for a specific stage and molecule."); + m.def("extract_string", &extract_string, "Extracts a single string from a Tensor of contatenated strings."); + + // Functions in features.cpp + m.def("atom_onehot_feature_names_to_tensor", &atom_onehot_feature_names_to_tensor, "Accepts feature names and returns a tensor representing them as integers"); + m.def("atom_float_feature_names_to_tensor", &atom_float_feature_names_to_tensor, "Accepts feature names and returns a tensor representing them as integers"); + m.def("bond_feature_names_to_tensor", &bond_feature_names_to_tensor, "Accepts feature names and returns a tensor representing them as integers"); + m.def("positional_feature_options_to_tensor", &positional_feature_options_to_tensor, "Accepts feature names, levels, and options, and returns a tensor representing them as integers"); + m.def("featurize_smiles", &featurize_smiles, "Accepts a SMILES string and returns tensors representing the features"); +} diff --git a/graphium/graphium_cpp/graphormer.cpp b/graphium/graphium_cpp/graphormer.cpp new file mode 100644 index 000000000..a822c8e85 --- /dev/null +++ b/graphium/graphium_cpp/graphormer.cpp @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This file defines and instantiates the `compute_graphormer_distances` function, +//! declared in graphormer.h and called from features.cpp + +#include "graphormer.h" + +#include +#include +#include +#include + +// Computes the shortest path distance, along edges, between all pairs of nodes, +// outputting to `all_pairs_distances`. +// See the declaration in graphormer.h for more details. +template +void compute_graphormer_distances( + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + std::vector>& queue, + std::vector& all_pairs_distances) { + + // Compute all pairs shortest paths. + // Because this is a sparse graph treated as having unweighted edges, + // BFS on each node is faster than Dijkstra's or Floyd-Warshall's. + + if (queue.capacity() == 0) { + queue.reserve(n); + } + + all_pairs_distances.resize(size_t(n) * n); + std::fill(all_pairs_distances.begin(), all_pairs_distances.end(), T(-1)); + + for (uint32_t start_index = 0; start_index < n; ++start_index) { + queue.resize(0); + size_t queue_head = 0; + queue.push_back({ start_index,0 }); + T* const distances = all_pairs_distances.data() + start_index * n; + while (queue.size() != queue_head) { + auto [current_node, current_distance] = queue[queue_head]; + ++queue_head; + + if (distances[current_node] != T(-1)) { + continue; + } + + distances[current_node] = T(current_distance); + + ++current_distance; + + const uint32_t* neighbor_start = neighbors + neighbor_starts[current_node]; + const uint32_t* neighbor_end = neighbors + neighbor_starts[current_node+1]; + for (; neighbor_start != neighbor_end; ++neighbor_start) { + queue.push_back({ *neighbor_start,current_distance }); + } + } + } +} + +// Explicit instantiations of `compute_graphormer_distances` for `float` and `double` +template void compute_graphormer_distances( + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + std::vector>& queue, + std::vector& all_pairs_distances); +template void compute_graphormer_distances( + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + std::vector>& queue, + std::vector& all_pairs_distances); diff --git a/graphium/graphium_cpp/graphormer.h b/graphium/graphium_cpp/graphormer.h new file mode 100644 index 000000000..0ff1acc33 --- /dev/null +++ b/graphium/graphium_cpp/graphormer.h @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This header file declares the `compute_graphormer_distances` function, +//! defined in graphormer.cpp and called from features.cpp + +#pragma once + +#include +#include +#include + +//! Computes the shortest path distance, along edges, between all pairs of nodes, outputting to +//! `all_pairs_distances`. +//! Template type `T` can be `float` or `double`. Implementation is in graphormer.cpp +//! +//! @param n Number of nodes +//! @param neighbor_starts Array of `n+1` indices into `neighbors`, indicating where each node's +//! neighbors start, plus one at the end to indicate the full length of +//! `neighbors` +//! @param neighbors Concatenated array of all neighbors of all nodes, in order +//! @param queue vector used for temporary storage internally +//! @param all_pairs_distances This will be filled with the unweighted lengths of the shortest +//! path between each pair of nodes. +template +void compute_graphormer_distances( + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + std::vector>& queue, + std::vector& all_pairs_distances); + +// Instantiation declarations of `compute_graphormer_distances` for `float` and `double` +// The explicit instantiations are in graphormer.cpp +extern template void compute_graphormer_distances( + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + std::vector>& queue, + std::vector& all_pairs_distances); +extern template void compute_graphormer_distances( + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + std::vector>& queue, + std::vector& all_pairs_distances); diff --git a/graphium/graphium_cpp/labels.cpp b/graphium/graphium_cpp/labels.cpp new file mode 100644 index 000000000..974ba24d7 --- /dev/null +++ b/graphium/graphium_cpp/labels.cpp @@ -0,0 +1,2185 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This file defines functions for preprocessing and looking up label data, +//! some of which are declared in labels.h for exporting to Python. + +#include "labels.h" + +#include "features.h" + +// C++ standard library headers +#include +#include +#include + +// RDKit headers +#include +#include +#include +#include + +// Numpy array headers +#include +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include + +#ifdef _WIN32 +// Windows file handling wrappers +#define WIN32_LEAN_AND_MEAN +#include + +using FileType = HANDLE; +const auto INVALID_FILE = INVALID_HANDLE_VALUE; + +// Opens a file for reading on Windows, working around an issue with non-ASCII +// file paths in fopen on Windows. +static FileType fopen_read_wrapper(const std::filesystem::path& file_path) { + return CreateFileW( + file_path.wstring().c_str(), + GENERIC_READ, + FILE_SHARE_READ, + nullptr, + OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, + nullptr); +} + +// Opens a file for writing on Windows, working around an issue with non-ASCII +// file paths in fopen on Windows. +static FileType fopen_write_wrapper(const std::filesystem::path& file_path) { + return CreateFileW( + file_path.wstring().c_str(), + GENERIC_WRITE, + 0, + nullptr, + CREATE_ALWAYS, + FILE_ATTRIBUTE_NORMAL, + nullptr); +} + +// Reads from a file on Windows +static size_t fread_wrapper(void* buffer, size_t bytes, FileType file) { + size_t total_bytes_read = 0; + while (bytes > 0) { + // NOTE: ReadFile should support reads up to (2^32 - 1) bytes, + // but might as well limit it to 1GB (2^30 bytes) at a time, + // just in case there are issues at or above 2GB. + const DWORD max_read_size = 1024 * 1024 * 1024; + const DWORD bytes_to_read = (bytes > max_read_size) ? max_read_size : (DWORD)bytes; + DWORD bytes_read; + BOOL success = ReadFile(file, buffer, bytes_to_read, &bytes_read, nullptr); + total_bytes_read += (success ? bytes_read : 0); + if (!success || bytes_read != bytes_to_read) { + return total_bytes_read; + } + bytes -= bytes_read; + } + return total_bytes_read; +} + +// Writes to a file on Windows +static size_t fwrite_wrapper(const void* buffer, size_t bytes, FileType file) { + size_t total_bytes_written = 0; + while (bytes > 0) { + // NOTE: ReadFile should support reads up to (2^32 - 1) bytes, + // but might as well limit it to 1GB (2^30 bytes) at a time, + // just in case there are issues at or above 2GB. + const DWORD max_write_size = 1024 * 1024 * 1024; + const DWORD bytes_to_write = (bytes > max_write_size) ? max_write_size : (DWORD)bytes; + DWORD bytes_written; + BOOL success = WriteFile(file, buffer, bytes_to_write, &bytes_written, nullptr); + total_bytes_written += (success ? bytes_written : 0); + if (!success || bytes_written != bytes_to_write) { + return total_bytes_written; + } + bytes -= bytes_written; + } + return total_bytes_written; +} + +// Seeks to a 64-bit absolute position in a file on Windows. +static int fseek_wrapper(FileType file, int64_t file_pointer) { + LARGE_INTEGER file_pointer_union; + file_pointer_union.QuadPart = (LONGLONG)file_pointer; + BOOL success = SetFilePointerEx(file, file_pointer_union, nullptr, FILE_BEGIN); + return (success == 0); +} + +// Closes a file on Windows +static void fclose_wrapper(FileType file) { + CloseHandle(file); +} + +#else +// Linux file handling wrappers +#include + +using FileType = FILE*; +const auto INVALID_FILE = (FILE*)nullptr; + +// Opens a file for reading on non-Windows platforms, where fopen supports UTF-8 file paths. +static FileType fopen_read_wrapper(const std::filesystem::path& file_path) { + return fopen(file_path.string().c_str(), "rb"); +} + +// Opens a file for writing on non-Windows platforms, where fopen supports UTF-8 file paths. +static FileType fopen_write_wrapper(const std::filesystem::path& file_path) { + return fopen(file_path.string().c_str(), "wb"); +} + +// Reads from a file on non-Windows platforms +static size_t fread_wrapper(void* buffer, size_t bytes, FileType file) { + return fread(buffer, 1, bytes, file); +} + +// Writes to a file on non-Windows platforms +static size_t fwrite_wrapper(const void* buffer, size_t bytes, FileType file) { + return fwrite(buffer, 1, bytes, file); +} + +// Seeks to a 64-bit absolute position in a file on non-Windows platforms. +static int fseek_wrapper(FileType file, int64_t file_pointer) { + // NOTE: If these files could ever be larger than 2GB each, fseek won't + // work on platforms where "long" is a 32-bit type (e.g. 32-bit Linux) + static_assert(sizeof(long) == sizeof(int64_t)); + return fseek(file, (long)file_pointer, SEEK_SET); +} + +// Closes a file on non-Windows platforms +static void fclose_wrapper(FileType file) { + fclose(file); +} + +#endif // End of file handling wrappers + +struct InitNumpyArrayModule { + InitNumpyArrayModule() { + // This imports the numpy array module, and it must be + // called exactly once before numpy array functions are used. + if (_import_array() < 0) { + printf("ERROR: Failed to import numpy.core.multiarray from C++ in graphium_cpp module\n"); + } + } +}; +static void ensure_numpy_array_module_initialized() { + // Function scope static variables will be initialized upon the first call, + // and only once, in a threadsafe manner. + static InitNumpyArrayModule numpy_initializer; +} + +// Simple representation of the limited data needed for most molecules during label data +// merging, and for the cached `num_nodes` and `num_edges` tensors. +struct MolBriefData { + uint64_t unique_id[2]; + uint32_t num_nodes; + uint32_t num_edges; +}; + +// Computes `MolBriefData` from a molecule's SMILES string, optionally including +// a compacted InChI key for the `unique_id` values. InChI keys are very expensive to +// compute, and only used for identifying equivalent molecules, so if not merging equivalent +// molecules (e.g. for inference), it saves time to skip computing them. +static MolBriefData smiles_to_brief_data( + const std::string& smiles_string, + bool add_self_loop, + bool explicit_H, + bool compute_inchi_key) { + + // Don't add explicit_H here, in case it affects MolToInchiKey (though it really shouldn't) + std::unique_ptr mol{ parse_mol(smiles_string, false) }; + if (!mol) { + return MolBriefData{ {0,0}, 0, 0 }; + } + + uint64_t id0 = 0; + uint64_t id1 = 0; + if (compute_inchi_key) { + const std::string inchiKeyString = MolToInchiKey(*mol, "/FixedH /SUU /RecMet /KET /15T"); + size_t n = inchiKeyString.size(); + // Format: AAAAAAAAAAAAAA-BBBBBBBBFV-P + // According to https://www.inchi-trust.org/technical-faq/ + assert(n == 27 && inchiKeyString[14] == '-' && inchiKeyString[25] == '-'); + // Convert from capital letter characters to 64-bit integers: + // 13 characters for first integer, 12 characters for 2nd integer. + // Neither should overflow a 64-bit unsigned integer. + id0 = (n > 0) ? (inchiKeyString[0] - 'A') : 0; + for (size_t i = 1; i < 13 && i < n; ++i) { + id0 = 26*id0 + (inchiKeyString[i] - 'A'); + } + id1 = (13 < n) ? (inchiKeyString[13] - 'A') : 0; + for (size_t i = 15; i < 25 && i < n; ++i) { + id1 = 26*id1 + (inchiKeyString[i] - 'A'); + } + if (26 < n) { + id1 = 26*id1 + (inchiKeyString[26] - 'A'); + } + } + + // Now handle explicit_H + if (explicit_H) { + RDKit::MolOps::addHs(*mol); + } + else { + // Default params for SmilesToMol already calls removeHs, + // and calling it again shouldn't have any net effect. + //RDKit::MolOps::removeHs(*mol); + } + + return MolBriefData{ + {id0, id1}, + mol->getNumAtoms(), + 2*mol->getNumBonds() + (add_self_loop ? mol->getNumAtoms() : 0) + }; +} + +//! Normalization methods for use in `prepare_and_save_data`. +//! Not to be confused with the normalization of the graph Laplacian matrix in +//! positional features. +enum class NormalizationMethod { + NONE, //!< No normalization + NORMAL, //!< Subtract mean and divide by standard deviation + UNIT //!< Subtract minimum and divide by range +}; +//! Normalization options for use in `prepare_and_save_data`. +//! All data will be clamped to be between `min_clipping` and `max_clipping`. +struct NormalizationOptions { + NormalizationMethod method = NormalizationMethod::NONE; + double min_clipping = -std::numeric_limits::infinity(); + double max_clipping = std::numeric_limits::infinity(); +}; + +//! To avoid having one gigantic label data file or millions of tiny label data files, +//! store label data for 1024 molecules in each file. +constexpr size_t num_mols_per_file = 1024; + +//! Quickly creates a filename of the format 0000000.tmp, with at least 7 digits. +static void get_mol_label_filename( + char filename[25], + uint64_t file_num) { + + size_t filename_index = 0; + while (file_num != 0) { + filename[filename_index] = '0' + (file_num % 10); + ++filename_index; + file_num /= 10; + } + while (filename_index < 7) { + filename[filename_index] = '0'; + ++filename_index; + } + std::reverse(filename, filename + filename_index); + filename[filename_index] = '.'; + filename[filename_index+1] = 't'; + filename[filename_index+2] = 'm'; + filename[filename_index+3] = 'p'; + filename[filename_index+4] = 0; +} + +struct Types { + size_t size; + int numpy_type; + c10::ScalarType torch_type; +}; +constexpr size_t num_supported_types = 3; +constexpr Types supported_types[num_supported_types] = { + {2, NPY_FLOAT16, c10::ScalarType::Half}, + {4, NPY_FLOAT32, c10::ScalarType::Float}, + {8, NPY_FLOAT64, c10::ScalarType::Double} +}; +static bool is_supported_numpy_type(int type) { + return (type == supported_types[0].numpy_type) || + (type == supported_types[1].numpy_type) || + (type == supported_types[2].numpy_type); +}; +static size_t numpy_type_index(int type) { + if (type == supported_types[0].numpy_type) { + return 0; + } + if (type == supported_types[1].numpy_type) { + return 1; + } + if (type == supported_types[2].numpy_type) { + return 2; + } + return num_supported_types; +}; +static size_t torch_type_index(c10::ScalarType type) { + if (type == supported_types[0].torch_type) { + return 0; + } + if (type == supported_types[1].torch_type) { + return 1; + } + if (type == supported_types[2].torch_type) { + return 2; + } + return num_supported_types; +}; + +// Filenames for cached data other than label data +constexpr const char*const label_metadata_filename = "label_metadata.tmp"; +constexpr const char*const file_data_offsets_filename = "file_data_offsets.tmp"; +constexpr const char*const concat_smiles_filename = "concat_smiles.tmp"; +constexpr const char*const smiles_offsets_filename = "smiles_offsets.tmp"; +constexpr const char*const num_nodes_filename = "num_nodes.tmp"; +constexpr const char*const num_edges_filename = "num_edges.tmp"; + +// Called by `prepare_and_save_data` to write out a file representing the number of +// columns and data type for each task/label. +static bool save_num_cols_and_dtypes( + const std::filesystem::path& common_path, + const std::vector& label_num_cols, + const std::vector& label_data_types) { + + const uint64_t num_labels = label_num_cols.size(); + if (num_labels != label_data_types.size()) { + return false; + } + std::filesystem::path file_path(common_path / label_metadata_filename); + FileType file = fopen_write_wrapper(file_path); + if (file == INVALID_FILE) { + return false; + } + size_t num_bytes_written = fwrite_wrapper(&num_labels, sizeof(num_labels), file); + num_bytes_written += fwrite_wrapper(label_num_cols.data(), sizeof(label_num_cols[0])*num_labels, file); + num_bytes_written += fwrite_wrapper(label_data_types.data(), sizeof(label_data_types[0])*num_labels, file); + fclose_wrapper(file); + if (num_bytes_written != sizeof(num_labels) + (sizeof(label_num_cols[0]) + sizeof(label_data_types[0]))*num_labels) { + return false; + } + return true; +} + +// Reads the number of columns and data type for each task, from the common label +// metadata file. +// See the declaration in labels.h for more details. +std::tuple< + std::vector, + std::vector +> load_num_cols_and_dtypes( + const std::string& processed_graph_data_path, + const std::string& data_hash) { + + std::vector label_num_cols; + std::vector label_data_types; + std::filesystem::path file_path( + std::filesystem::path(processed_graph_data_path) / data_hash / label_metadata_filename + ); + FileType file = fopen_read_wrapper(file_path); + if (file == INVALID_FILE) { + return std::make_tuple(std::move(label_num_cols), std::move(label_data_types)); + } + uint64_t num_labels = 0; + size_t num_bytes_read = fread_wrapper(&num_labels, sizeof(num_labels), file); + // Trying to allocate 2^60 would fail, unless it overflows and then crashes + if (num_bytes_read != sizeof(num_labels) || num_labels == 0 || num_labels >= (uint64_t(1) << (64-4))) { + fclose_wrapper(file); + return std::make_tuple(std::move(label_num_cols), std::move(label_data_types)); + } + label_num_cols.resize(num_labels, 0); + num_bytes_read = fread_wrapper(label_num_cols.data(), sizeof(label_num_cols[0])*num_labels, file); + if (num_bytes_read != sizeof(label_num_cols[0])*num_labels) { + fclose_wrapper(file); + label_num_cols.resize(0); + return std::make_tuple(std::move(label_num_cols), std::move(label_data_types)); + } + label_data_types.resize(num_labels, -1); + num_bytes_read = fread_wrapper(label_data_types.data(), sizeof(label_data_types[0])*num_labels, file); + fclose_wrapper(file); + if (num_bytes_read != sizeof(label_data_types[0])*num_labels) { + label_num_cols.resize(0); + label_data_types.resize(0); + } + return std::make_tuple(std::move(label_num_cols), std::move(label_data_types)); +} + +// Helper function to save a file containing a 64-bit number of values, +// followed by all of those values, e.g. for saving 1D tensors. +// Does not save a representation of the data type, so callers must keep track. +// Pairs with `load_array_from_file`. +template +bool save_array_to_file( + const std::filesystem::path& directory, + const char*const filename, + const T* data, + const uint64_t n) { + + std::filesystem::path file_path(directory / filename); + FileType file = fopen_write_wrapper(file_path); + if (file == INVALID_FILE) { + return false; + } + size_t num_bytes_written = fwrite_wrapper(&n, sizeof(n), file); + num_bytes_written += fwrite_wrapper(data, sizeof(T)*n, file); + fclose_wrapper(file); + if (num_bytes_written != sizeof(n) + sizeof(T)*n) { + return false; + } + return true; +} + +// Helper function to load values from a file containing a 64-bit number of values, +// followed by all of those values, e.g. for loading 1D tensors. +// The caller must know the correct data type. +// Pairs with `save_array_to_file`. +template +[[nodiscard]] uint64_t load_array_from_file( + const std::filesystem::path& directory, + const char*const filename, + std::unique_ptr& data) { + + data.reset(nullptr); + + std::filesystem::path file_path(directory / filename); + FileType file = fopen_read_wrapper(file_path); + if (file == INVALID_FILE) { + return 0; + } + uint64_t n; + size_t num_bytes_read = fread_wrapper(&n, sizeof(n), file); + // Trying to allocate 2^60 would fail, unless it overflows and then crashes + if (num_bytes_read != sizeof(n) || n == 0 || n >= (uint64_t(1) << (64-4))) { + fclose_wrapper(file); + return 0; + } + data.reset(new T[n]); + num_bytes_read = fread_wrapper(data.get(), sizeof(T)*n, file); + fclose_wrapper(file); + if (num_bytes_read != sizeof(T)*n) { + data.reset(nullptr); + return 0; + } + return n; +} + +// Reads data from the stage-specific label metadata files. +// See the declaration in labels.h for more details. +std::vector load_metadata_tensors( + const std::string processed_graph_data_path, + const std::string stage, + const std::string data_hash) { + + std::filesystem::path base_path{processed_graph_data_path}; + std::filesystem::path directory = base_path / (stage + "_" + data_hash); + + std::unique_ptr concatenated_smiles; + uint64_t concatenated_smiles_size = + load_array_from_file(directory, concat_smiles_filename, concatenated_smiles); + + std::unique_ptr smiles_offsets; + uint64_t num_smiles_offsets = + load_array_from_file(directory, smiles_offsets_filename, smiles_offsets); + + std::unique_ptr num_nodes; + uint64_t num_num_nodes = + load_array_from_file(directory, num_nodes_filename, num_nodes); + + std::unique_ptr num_edges; + uint64_t num_num_edges = + load_array_from_file(directory, num_edges_filename, num_edges); + + std::unique_ptr mol_data_offsets; + uint64_t num_mol_data_offsets = + load_array_from_file(directory, file_data_offsets_filename, mol_data_offsets); + + if (num_num_nodes == 0 || num_num_edges != num_num_nodes || num_smiles_offsets != (num_num_nodes+1) || + concatenated_smiles_size == 0 || concatenated_smiles_size != uint64_t(smiles_offsets[num_num_edges]) || + (num_mol_data_offsets != num_num_nodes + (num_num_nodes + num_mols_per_file-1)/num_mols_per_file && num_mol_data_offsets != 0)) { + printf("ERROR: graphium_cpp.load_metadata_tensors failed to load valid metadata files\n"); + printf(" len(concat_smiles) is %zu\n", size_t(concatenated_smiles_size)); + printf(" len(smiles_offsets) is %zu\n", size_t(num_smiles_offsets)); + printf(" len(num_nodes) is %zu\n", size_t(num_num_nodes)); + printf(" len(num_edges) is %zu\n", size_t(num_num_edges)); + printf(" len(file_data_offsets) is %zu\n", size_t(num_mol_data_offsets)); + return std::vector(); + } + + // The above conditions should ensure that none of these arrays are empty, + // but assert in debug builds just in case. + assert(concatenated_smiles && smiles_offsets && num_nodes && num_edges); + + const int64_t concatenated_smiles_dims[1] = { int64_t(concatenated_smiles_size) }; + at::Tensor smiles_tensor = torch_tensor_from_array(std::move(concatenated_smiles), concatenated_smiles_dims, 1, c10::ScalarType::Char); + const int64_t smiles_offsets_dims[1] = { int64_t(num_num_nodes+1) }; + at::Tensor smiles_offsets_tensor = torch_tensor_from_array(std::move(smiles_offsets), smiles_offsets_dims, 1, c10::ScalarType::Long); + const int64_t num_nodes_dims[1] = { int64_t(num_num_nodes) }; + at::Tensor num_nodes_tensor = torch_tensor_from_array(std::move(num_nodes), num_nodes_dims, 1, c10::ScalarType::Int); + const int64_t num_edges_dims[1] = { int64_t(num_num_nodes) }; + at::Tensor num_edges_tensor = torch_tensor_from_array(std::move(num_edges), num_edges_dims, 1, c10::ScalarType::Int); + + std::vector stage_return_data; + stage_return_data.reserve((num_mol_data_offsets > 0) ? 5 : 4); + + stage_return_data.push_back(std::move(smiles_tensor)); + stage_return_data.push_back(std::move(smiles_offsets_tensor)); + stage_return_data.push_back(std::move(num_nodes_tensor)); + stage_return_data.push_back(std::move(num_edges_tensor)); + + if (num_mol_data_offsets > 0) { + const int64_t data_offsets_dims[1] = { int64_t(num_mol_data_offsets) }; + at::Tensor data_offsets_tensor = torch_tensor_from_array(std::move(mol_data_offsets), data_offsets_dims, 1, c10::ScalarType::Long); + + stage_return_data.push_back(std::move(data_offsets_tensor)); + } + + return stage_return_data; +} + +// Reads data from the task-specific stats file. +// See the declaration in labels.h for more details. +std::vector load_stats( + const std::string processed_graph_data_path, + const std::string data_hash, + const std::string task_name) { + + std::filesystem::path base_path{processed_graph_data_path}; + std::filesystem::path directory = base_path / data_hash; + const std::string filename(task_name + "_stats.tmp"); + + std::unique_ptr task_stats; + uint64_t num_stat_floats = + load_array_from_file(directory, filename.c_str(), task_stats); + + if (num_stat_floats == 0 || num_stat_floats % 4 != 0) { + return std::vector(); + } + + const uint64_t num_cols = num_stat_floats / 4; + std::vector return_stats(4); + for (size_t stat_index = 0; stat_index < 4; ++stat_index) { + std::unique_ptr single_stat(new double[num_cols]); + for (size_t i = 0; i < num_cols; ++i) { + single_stat[i] = task_stats[4*i + stat_index]; + } + const int64_t stat_dims[1] = { int64_t(num_cols) }; + at::Tensor stat_tensor = torch_tensor_from_array(std::move(single_stat), stat_dims, 1, c10::ScalarType::Double); + return_stats.push_back(std::move(stat_tensor)); + } + + return return_stats; +} + +// Accepts a Numpy array of strings or Python list of strings, and returns a PyTorch tensor +// of all of the characters and another tensor containing indices into the other tensor +// indicating where each string begins, plus one extra index indicating the end. +// See the declaration in labels.h. +std::pair concatenate_strings(pybind11::handle handle) { + using return_type = std::pair; + + ensure_numpy_array_module_initialized(); + + at::Tensor concatenated_strings; + at::Tensor offsets; + + PyObject* obj_ptr = handle.ptr(); + if (PyArray_Check(obj_ptr)) { + PyArrayObject* numpy_array = reinterpret_cast(obj_ptr); + int type_num = PyArray_TYPE(numpy_array); + int ndims = PyArray_NDIM(numpy_array); + if (type_num != NPY_OBJECT || ndims != 1) { + return return_type(std::move(concatenated_strings), std::move(offsets)); + } + intptr_t n = PyArray_DIM(numpy_array, 0); + if (n <= 0) { + return return_type(std::move(concatenated_strings), std::move(offsets)); + } + + size_t total_characters = 0; + for (intptr_t i = 0; i < n; ++i) { + pybind11::handle string_handle(*(PyObject**)PyArray_GETPTR1(numpy_array, i)); + if (!pybind11::isinstance(string_handle)) { + continue; + } + // TODO: Consider trying to avoid constructing std::string here + std::string string{pybind11::str{string_handle}}; + // +1 is for null terminator + total_characters += string.size() + 1; + } + std::unique_ptr concatenated_chars(new char[total_characters]); + std::unique_ptr offsets_buffer(new int64_t[n+1]); + int64_t offset = 0; + for (intptr_t i = 0; i < n; ++i) { + offsets_buffer[i] = offset; + pybind11::handle string_handle(*(PyObject**)PyArray_GETPTR1(numpy_array, i)); + if (!pybind11::isinstance(string_handle)) { + continue; + } + // TODO: Consider trying to avoid constructing std::string here + std::string string{pybind11::str{string_handle}}; + memcpy(concatenated_chars.get(), string.c_str(), string.size()); + offset += string.size(); + concatenated_chars[offset] = 0; + ++offset; + } + offsets_buffer[n] = offset; + + const int64_t concatenated_strings_dims[1] = { int64_t(total_characters) }; + concatenated_strings = torch_tensor_from_array(std::move(concatenated_chars), concatenated_strings_dims, 1, c10::ScalarType::Char); + const int64_t offsets_dims[1] = { int64_t(n+1) }; + offsets = torch_tensor_from_array(std::move(offsets_buffer), offsets_dims, 1, c10::ScalarType::Long); + } + if (pybind11::isinstance(handle)) { + pybind11::list list = handle.cast(); + size_t n = list.size(); + + size_t total_characters = 0; + for (size_t i = 0; i < n; ++i) { + pybind11::handle string_handle(list[i]); + if (!pybind11::isinstance(string_handle)) { + continue; + } + // TODO: Consider trying to avoid constructing std::string here + std::string string{pybind11::str{string_handle}}; + // +1 is for null terminator + total_characters += string.size() + 1; + } + std::unique_ptr concatenated_chars(new char[total_characters]); + std::unique_ptr offsets_buffer(new int64_t[n+1]); + int64_t offset = 0; + for (size_t i = 0; i < n; ++i) { + offsets_buffer[i] = offset; + pybind11::handle string_handle(list[i]); + if (!pybind11::isinstance(string_handle)) { + continue; + } + // TODO: Consider trying to avoid constructing std::string here + std::string string{pybind11::str{string_handle}}; + memcpy(concatenated_chars.get(), string.c_str(), string.size()); + offset += string.size(); + concatenated_chars[offset] = 0; + ++offset; + } + offsets_buffer[n] = offset; + + const int64_t concatenated_strings_dims[1] = { int64_t(total_characters) }; + concatenated_strings = torch_tensor_from_array(std::move(concatenated_chars), concatenated_strings_dims, 1, c10::ScalarType::Char); + const int64_t offsets_dims[1] = { int64_t(n+1) }; + offsets = torch_tensor_from_array(std::move(offsets_buffer), offsets_dims, 1, c10::ScalarType::Long); + } + return return_type(std::move(concatenated_strings), std::move(offsets)); +} + +// There are exactly 3 allowed stages. +constexpr size_t num_stages = 3; + +// The names of the 3 allowed stages. +// NOTE: Computing stats below depends on that "train" is stage 0. +const std::string stages[num_stages] = { + std::string("train"), + std::string("val"), + std::string("test") +}; + +// Called by `prepare_and_save_data` to get pointers to the relevant NumPy arrays for each task +static void get_task_data( + const pybind11::list& task_names, + pybind11::dict& task_dataset_args, + const pybind11::dict& task_label_normalization, + int64_t* return_label_num_cols, + int32_t* return_label_data_types, + size_t* task_col_starts, + size_t* task_bytes_per_float, + NormalizationOptions* task_normalization_options, + PyArrayObject** smiles_numpy_arrays, + PyArrayObject** labels_numpy_arrays, + PyArrayObject** label_offsets_numpy_arrays, + FeatureLevel* task_levels +) { + size_t total_num_cols = 0; + size_t task_index = 0; + for (const auto& task : task_names) { + const size_t current_task_index = task_index; + task_col_starts[current_task_index] = total_num_cols; + task_bytes_per_float[current_task_index] = 0; + smiles_numpy_arrays[current_task_index] = nullptr; + labels_numpy_arrays[current_task_index] = nullptr; + label_offsets_numpy_arrays[current_task_index] = nullptr; + ++task_index; + if (!pybind11::isinstance(task)) { + continue; + } + const std::string task_name{ pybind11::str(task) }; + pybind11::handle task_dataset_handle = pybind11::handle(PyDict_GetItemString(task_dataset_args.ptr(), task_name.c_str())); + if (!task_dataset_handle || !pybind11::isinstance(task_dataset_handle)) { + continue; + } + pybind11::dict dataset_dict = task_dataset_handle.cast(); + pybind11::handle smiles_handle = pybind11::handle(PyDict_GetItemString(dataset_dict.ptr(), "smiles")); + if (!smiles_handle) { + continue; + } + PyObject* smiles_obj_ptr = smiles_handle.ptr(); + if (!PyArray_Check(smiles_obj_ptr)) { + continue; + } + PyArrayObject* smiles_numpy_array = reinterpret_cast(smiles_obj_ptr); + int smiles_type_num = PyArray_TYPE(smiles_numpy_array); + int smiles_ndims = PyArray_NDIM(smiles_numpy_array); + if (smiles_type_num != NPY_OBJECT || smiles_ndims != 1) { + continue; + } + intptr_t num_smiles = PyArray_DIM(smiles_numpy_array, 0); + if (num_smiles <= 0) { + continue; + } + + // smiles array is okay + smiles_numpy_arrays[current_task_index] = smiles_numpy_array; + + // Check for labels. There might not be labels in inference case. + pybind11::handle labels_handle = pybind11::handle(PyDict_GetItemString(dataset_dict.ptr(), "labels")); + if (!labels_handle) { + continue; + } + pybind11::handle label_offsets_handle = pybind11::handle(PyDict_GetItemString(dataset_dict.ptr(), "label_offsets")); + PyObject* labels_obj_ptr = labels_handle.ptr(); + PyObject* label_offsets_obj_ptr = label_offsets_handle.ptr(); + const bool is_labels_numpy = PyArray_Check(labels_obj_ptr); + const bool is_labels_multi_row = label_offsets_obj_ptr && PyArray_Check(label_offsets_obj_ptr); + if (!is_labels_numpy) { + continue; + } + PyArrayObject* labels_numpy_array = reinterpret_cast(labels_obj_ptr); + PyArrayObject* label_offsets_numpy_array = is_labels_multi_row ? reinterpret_cast(label_offsets_obj_ptr) : nullptr; + int labels_type_num = PyArray_TYPE(labels_numpy_array); + int labels_ndims = PyArray_NDIM(labels_numpy_array); +#if GRAPHIUM_CPP_DEBUGGING + printf("\"%s\" labels numpy type %d, %d dims\n", task_name.c_str(), labels_type_num, labels_ndims); +#endif + if (!is_supported_numpy_type(labels_type_num) || labels_ndims != 2) { + continue; + } + if (is_labels_multi_row) { + int label_offsets_type_num = PyArray_TYPE(label_offsets_numpy_array); + int label_offsets_ndims = PyArray_NDIM(label_offsets_numpy_array); + // Only int64 is supported, for simplicity + if (label_offsets_type_num != NPY_INT64 || label_offsets_ndims != 1) { + continue; + } + } + intptr_t num_label_rows = PyArray_DIM(labels_numpy_array, 0); + intptr_t num_molecules = num_label_rows; + if (is_labels_multi_row) { + intptr_t num_offsets_rows = PyArray_DIM(label_offsets_numpy_array, 0); + if (num_offsets_rows == 0) { + continue; + } + // -1 is because last offset is the end offset + num_molecules = num_offsets_rows - 1; + + // Verify that the first offset is zero + if (*(const int64_t*)PyArray_GETPTR1(label_offsets_numpy_array, 0) != 0) { + continue; + } + // Verify that the last offset is the end offset + if (*(const int64_t*)PyArray_GETPTR1(label_offsets_numpy_array, num_molecules) != num_label_rows) { + continue; + } + } + intptr_t num_label_cols = PyArray_DIM(labels_numpy_array, 1); +#if GRAPHIUM_CPP_DEBUGGING + printf("\"%s\" labels[%zd][%zd] (%zd molecules)\n", task_name.c_str(), num_label_rows, num_label_cols, num_molecules); +#endif + if (num_smiles != num_molecules || num_label_cols <= 0) { + continue; + } + + const size_t supported_type_index = numpy_type_index(labels_type_num); + const size_t bytes_per_float = supported_types[supported_type_index].size; + labels_numpy_arrays[current_task_index] = labels_numpy_array; + label_offsets_numpy_arrays[current_task_index] = is_labels_multi_row ? label_offsets_numpy_array : nullptr; + return_label_num_cols[current_task_index] = num_label_cols; + return_label_data_types[current_task_index] = int(supported_types[supported_type_index].torch_type); + total_num_cols += size_t(num_label_cols); + task_bytes_per_float[current_task_index] = bytes_per_float; + + pybind11::handle task_normalization_handle = pybind11::handle(PyDict_GetItemString(task_label_normalization.ptr(), task_name.c_str())); + if (!task_normalization_handle || !pybind11::isinstance(task_normalization_handle)) { + continue; + } + pybind11::dict normalization_dict = task_normalization_handle.cast(); + pybind11::handle method_handle = pybind11::handle(PyDict_GetItemString(normalization_dict.ptr(), "method")); + pybind11::handle min_handle = pybind11::handle(PyDict_GetItemString(normalization_dict.ptr(), "min_clipping")); + pybind11::handle max_handle = pybind11::handle(PyDict_GetItemString(normalization_dict.ptr(), "max_clipping")); + if (method_handle && pybind11::isinstance(method_handle)) { + std::string method{pybind11::str(method_handle)}; + if (strcmp(method.c_str(), "normal") == 0) { + task_normalization_options[current_task_index].method = NormalizationMethod::NORMAL; + } + else if (strcmp(method.c_str(), "unit") == 0) { + task_normalization_options[current_task_index].method = NormalizationMethod::UNIT; + } + } + if (min_handle && pybind11::isinstance(min_handle)) { + task_normalization_options[current_task_index].min_clipping = double(int64_t(min_handle.cast())); + } + else if (min_handle && pybind11::isinstance(min_handle)) { + task_normalization_options[current_task_index].min_clipping = double(min_handle.cast()); + } + if (max_handle && pybind11::isinstance(max_handle)) { + task_normalization_options[current_task_index].max_clipping = double(int64_t(max_handle.cast())); + } + else if (max_handle && pybind11::isinstance(max_handle)) { + task_normalization_options[current_task_index].max_clipping = double(max_handle.cast()); + } + } + const size_t num_tasks = task_names.size(); + assert(task_index == num_tasks); + task_col_starts[num_tasks] = total_num_cols; + + // Determine the level of each task's data, for node reordering. + for (size_t task_index = 0; task_index < num_tasks; ++task_index) { + pybind11::handle task = task_names[task_index]; + if (!smiles_numpy_arrays[task_index]) { + continue; + } + const std::string task_name{ pybind11::str(task) }; + + constexpr const char* graph_prefix = "graph_"; + constexpr const char* node_prefix = "node_"; + constexpr const char* edge_prefix = "edge_"; + constexpr const char* nodepair_prefix = "nodepair_"; + constexpr size_t graph_prefix_length{ std::char_traits::length(graph_prefix) }; + constexpr size_t node_prefix_length{ std::char_traits::length(node_prefix) }; + constexpr size_t edge_prefix_length{ std::char_traits::length(edge_prefix) }; + constexpr size_t nodepair_prefix_length{ std::char_traits::length(nodepair_prefix) }; + + if (std::strncmp(task_name.c_str(), graph_prefix, graph_prefix_length) == 0) { + task_levels[task_index] = FeatureLevel::GRAPH; + } + else if (std::strncmp(task_name.c_str(), node_prefix, node_prefix_length) == 0) { + task_levels[task_index] = FeatureLevel::NODE; + } + else if (std::strncmp(task_name.c_str(), edge_prefix, edge_prefix_length) == 0) { + task_levels[task_index] = FeatureLevel::EDGE; + } + else if (std::strncmp(task_name.c_str(), nodepair_prefix, nodepair_prefix_length) == 0) { + task_levels[task_index] = FeatureLevel::NODEPAIR; + } + else { + // Invalid, but for now, just default to graph-level + task_levels[task_index] = FeatureLevel::GRAPH; + continue; + } + } +} + +// Called by `prepare_and_save_data` to get the indices and SMILES strings of all molecules +// being kept, (could be subsampled, or could be all), from each dataset (task) for each stage +// (train/val/test). The indices (strings) are all added to `task_mol_indices` +// (`smiles_strings`), with the beginnings of each of the `num_stages*num_tasks` ranges being +// recorded in `task_mol_start`, plus one extra for the end. +static void get_indices_and_strings( + const pybind11::list& task_names, + const pybind11::dict& task_train_indices, + const pybind11::dict& task_val_indices, + const pybind11::dict& task_test_indices, + size_t* task_mol_start, + std::vector& task_mol_indices, + PyArrayObject*const*const smiles_numpy_arrays, + std::vector& smiles_strings +) { + const size_t num_tasks = task_names.size(); + + const pybind11::dict* stage_task_indices[num_stages] = { + &task_train_indices, + &task_val_indices, + &task_test_indices + }; + + // Get the total number of molecules, by stage and task + size_t total_num_mols = 0; + for (size_t stage_index = 0; stage_index < num_stages; ++stage_index) { + const pybind11::dict& task_indices_dict = *stage_task_indices[stage_index]; + + for (size_t task_index = 0; task_index < num_tasks; ++task_index) { + pybind11::handle task = task_names[task_index]; + if (!smiles_numpy_arrays[task_index]) { + continue; + } + const std::string task_name{ pybind11::str(task) }; + pybind11::handle task_indices_handle = pybind11::handle(PyDict_GetItemString(task_indices_dict.ptr(), task_name.c_str())); + if (!task_indices_handle || !pybind11::isinstance(task_indices_handle)) { + printf("Error: Task %s indices list isn't valid.\n", task_name.c_str()); + continue; + } + const pybind11::list task_indices_list = task_indices_handle.cast(); + const size_t current_num_mols = task_indices_list.size(); + if (current_num_mols == 0) { + printf("Error: Task %s indices list is empty.\n", task_name.c_str()); + } + total_num_mols += current_num_mols; + } + } + + // Get the mol indices for all stages and tasks + task_mol_indices.reserve(total_num_mols); + // Unfortunately, reading strings from a numpy array isn't threadsafe, + // so we have to do that single-threaded first, too. + smiles_strings.reserve(total_num_mols); + for (size_t stage_index = 0; stage_index < num_stages; ++stage_index) { + const pybind11::dict& task_indices_dict = *stage_task_indices[stage_index]; + + for (size_t task_index = 0; task_index < num_tasks; ++task_index) { + // Update task_mol_start here, in case any indices aren't integers + // or any SMILES strings aren't strings below. + task_mol_start[stage_index*num_tasks + task_index] = task_mol_indices.size(); + + pybind11::handle task = task_names[task_index]; + if (!smiles_numpy_arrays[task_index]) { + continue; + } + const std::string task_name{ pybind11::str(task) }; + pybind11::handle task_indices_handle = pybind11::handle(PyDict_GetItemString(task_indices_dict.ptr(), task_name.c_str())); + if (!task_indices_handle || !pybind11::isinstance(task_indices_handle)) { + continue; + } + + const pybind11::list task_indices_list = task_indices_handle.cast(); + const size_t current_num_mols = task_indices_list.size(); + + PyArrayObject*const smiles_numpy_array = smiles_numpy_arrays[task_index]; + const size_t smiles_array_size = PyArray_DIM(smiles_numpy_array, 0); + + for (size_t indices_index = 0; indices_index < current_num_mols; ++indices_index) { + const auto list_item = task_indices_list[indices_index]; + if (!pybind11::isinstance(list_item)) { + continue; + } + + size_t task_mol_index = size_t(list_item.cast()); + if (task_mol_index >= smiles_array_size) { + continue; + } + + pybind11::handle single_smiles_handle(*(PyObject**)PyArray_GETPTR1(smiles_numpy_array, task_mol_index)); + if (!pybind11::isinstance(single_smiles_handle)) { + continue; + } + + task_mol_indices.push_back(task_mol_index); + smiles_strings.push_back(std::string(pybind11::str(single_smiles_handle))); + } + + } + } + total_num_mols = task_mol_indices.size(); + task_mol_start[num_stages*num_tasks] = total_num_mols; +} + +// Molecule data used for sorting in `prepare_and_save_data`, in order to quickly merge +// equivalent molecules. +struct MolKey { + uint64_t id0; + uint64_t id1; + uint32_t num_nodes; + uint32_t num_edges; + uint64_t task_index; + uint64_t task_mol_index; + uint64_t mol_index; + + // This is the comparison operator used for sorting + bool operator<(const MolKey& other) const { + if (id0 != other.id0) { + return (id0 < other.id0); + } + if (id1 != other.id1) { + return (id1 < other.id1); + } + if (num_nodes != other.num_nodes) { + return (num_nodes < other.num_nodes); + } + if (num_edges != other.num_edges) { + return (num_edges < other.num_edges); + } + if (task_index != other.task_index) { + return (task_index < other.task_index); + } + return (task_mol_index < other.task_mol_index); + } + + // This is used for identifying keys of molecules with invalid SMILES strings. + // They show up as having no nodes, no edges, and ID 0. + bool isInvalid() const { + return id0 == 0 && id1 == 0 && num_nodes == 0 && num_edges == 0; + } +}; + +// Called by `prepare_and_save_data` to fill in the `MolKey` structure for each molecule, +// in parallel if needed, because the InChI key computation is very slow. When not merging +// equivalent molecules, e.g. for inference, the InChI keys are not needed, but the SMILES +// parsing might as well still be parallelized. +static void compute_mol_keys( + MolKey*const keys, + const size_t total_num_mols, + const size_t num_tasks, + int max_threads, + const size_t*const task_mol_start, + const bool add_self_loop, + const bool explicit_H, + const bool merge_equivalent_mols, + const size_t*const task_mol_indices, + const std::vector& smiles_strings) { + + // Determine the number of threads to use for computing MolKey values + const size_t num_mols_per_block = 512; + const size_t num_blocks = (total_num_mols + num_mols_per_block-1) / num_mols_per_block; + const size_t num_processors = std::thread::hardware_concurrency(); + size_t num_threads = (num_processors == 1 || num_blocks <= 4) ? 1 : std::min(num_processors, num_blocks/2); + // max_threads of -1 means n-1 threads, to avoid starving other processes + if (max_threads < 0) { + max_threads += num_processors; + // Don't hit zero or remain negative, because that would skip applying the limit + if (max_threads < 1) { + max_threads = 1; + } + } + // max_threads of 0 means to not limit the number of threads + if (max_threads > 0 && num_threads > size_t(max_threads)) { + num_threads = size_t(max_threads); + } + + auto&& get_single_mol_key = [task_mol_start,add_self_loop,explicit_H,task_mol_indices,&smiles_strings,num_tasks,merge_equivalent_mols](size_t mol_index) -> MolKey { + // Find which task this mol is in. If there could be many tasks, + // this could be a binary search, but for small numbers of tasks, + // a linear search is fine. + size_t task_index = 0; + while (task_mol_start[task_index+1] <= mol_index) { + ++task_index; + } + const size_t task_mol_index = task_mol_indices[mol_index]; + + const std::string& smiles_str = smiles_strings[mol_index]; + MolBriefData mol_data = smiles_to_brief_data(smiles_str, add_self_loop, explicit_H, merge_equivalent_mols); + + if (!merge_equivalent_mols) { + // mol_index is, by definition, distinct for each input index, + // so no molecules will be identified as equivalent below. + mol_data.unique_id[0] = mol_index; + mol_data.unique_id[1] = 0; + } + + return MolKey{mol_data.unique_id[0], mol_data.unique_id[1], mol_data.num_nodes, mol_data.num_edges, task_index % num_tasks, task_mol_index, mol_index}; + }; + if (num_threads == 1) { + for (size_t mol_index = 0; mol_index < total_num_mols; ++mol_index) { + keys[mol_index] = get_single_mol_key(mol_index); + } + } + else { + std::atomic next_block_index(0); + auto&& thread_functor = [keys,&next_block_index,num_blocks,num_mols_per_block,total_num_mols,&get_single_mol_key]() { + while (true) { + const size_t block_index = next_block_index.fetch_add(1); + if (block_index >= num_blocks) { + return; + } + const size_t begin_index = block_index * num_mols_per_block; + const size_t end_index = std::min((block_index+1) * num_mols_per_block, total_num_mols); + for (size_t mol_index = begin_index; mol_index < end_index; ++mol_index) { + keys[mol_index] = get_single_mol_key(mol_index); + } + } + }; + std::vector threads; + for (size_t thread_index = 0; thread_index < num_threads; ++thread_index) { + threads.push_back(std::thread(thread_functor)); + } + for (size_t thread_index = 0; thread_index < num_threads; ++thread_index) { + threads[thread_index].join(); + } + } +} + +constexpr size_t stat_min_offset = 0; +constexpr size_t stat_max_offset = 1; +constexpr size_t stat_mean_offset = 2; +constexpr size_t stat_std_offset = 3; +constexpr size_t num_stats = 4; + +// Called by `prepare_and_save_data` to compute the minimum, maximum, mean, and standard +// deviation of each column in each dataset (task), using only the molecules in the +// training stage. This does nothing if there is no label data, e.g. for inference. +static auto compute_stats( + const std::filesystem::path& common_path, + const size_t total_num_cols, + const pybind11::list& task_names, + const size_t*const task_mol_start, + const size_t*const task_col_starts, + const size_t*const task_bytes_per_float, + const NormalizationOptions*const task_normalization_options, + PyArrayObject*const*const labels_numpy_arrays, + PyArrayObject*const*const label_offsets_numpy_arrays, + const MolKey*const keys, + std::unique_ptr& all_task_stats) { + + std::unordered_map> all_stats_return_data; + if (total_num_cols == 0) { + return all_stats_return_data; + } + + const size_t num_tasks = task_names.size(); + + // Compute stats on the train stage only (stage 0), like how the python code did it. + // Normalization will be applied to all stages later. + // TODO: Does it matter that stats calculations will include all copies of molecules + // that occur multiple times in the same dataset? + size_t stats_floats = num_stats*total_num_cols; + all_task_stats.reset((stats_floats > 0) ? new double[stats_floats] : nullptr); + + std::unique_ptr all_task_num_non_nan(new intptr_t[total_num_cols]); + for (size_t task_index = 0; task_index < num_tasks; ++task_index) { + const size_t task_num_mols = task_mol_start[task_index+1] - task_mol_start[task_index]; + const size_t task_first_col = task_col_starts[task_index]; + const size_t task_num_cols = task_col_starts[task_index+1] - task_first_col; + if (task_num_mols == 0 || task_num_cols == 0) { + continue; + } + // Initialize stats for accumulation + double*const task_stats = all_task_stats.get() + num_stats*task_first_col; + intptr_t*const task_num_non_nan = all_task_num_non_nan.get() + task_first_col; + for (size_t task_col_index = 0; task_col_index < task_num_cols; ++task_col_index) { + task_stats[num_stats*task_col_index + stat_min_offset] = std::numeric_limits::infinity(); + task_stats[num_stats*task_col_index + stat_max_offset] = -std::numeric_limits::infinity(); + task_stats[num_stats*task_col_index + stat_mean_offset] = 0.0; + task_stats[num_stats*task_col_index + stat_std_offset] = 0.0; + task_num_non_nan[task_col_index] = 0; + } + + const size_t bytes_per_float = task_bytes_per_float[task_index]; + + auto&& update_stats_single_row = [task_stats, task_num_non_nan](const char* col_data, const size_t task_num_cols, const size_t bytes_per_float, const intptr_t col_stride) { + double* stats = task_stats; + intptr_t* num_non_nan = task_num_non_nan; + for (size_t task_col_index = 0; task_col_index < task_num_cols; ++task_col_index, col_data += col_stride, stats += num_stats, ++num_non_nan) { + // TODO: Move the type check outside the loop if it's a bottleneck + double value; + if (bytes_per_float == sizeof(double)) { + value = *(const double*)(col_data); + } + else if (bytes_per_float == sizeof(float)) { + value = *(const float*)(col_data); + } + else { + assert(bytes_per_float == sizeof(uint16_t)); + value = c10::detail::fp16_ieee_to_fp32_value(*(const uint16_t*)(col_data)); + } + if (value != value) { + // NaN value, so skip it + continue; + } + stats[stat_min_offset] = std::min(stats[stat_min_offset], value); + stats[stat_max_offset] = std::max(stats[stat_max_offset], value); + stats[stat_mean_offset] += value; + // TODO: If summing the squares isn't accurate enough for computing the variance, + // consider other approaches. + stats[stat_std_offset] += value*value; + ++(*num_non_nan); + } + }; + + PyArrayObject*const labels_numpy_array = labels_numpy_arrays[task_index]; + if (labels_numpy_array != nullptr) { + const char* raw_data = (const char*)PyArray_DATA(labels_numpy_array); + const intptr_t* strides = PyArray_STRIDES(labels_numpy_array); + const intptr_t num_label_rows = PyArray_DIM(labels_numpy_array, 0); + PyArrayObject*const label_offsets_numpy_array = label_offsets_numpy_arrays[task_index]; + const char* offsets_raw_data = label_offsets_numpy_array ? (const char*)PyArray_DATA(label_offsets_numpy_array) : nullptr; + const intptr_t offsets_stride = label_offsets_numpy_array ? PyArray_STRIDES(label_offsets_numpy_array)[0] : 0; + // The -1 is because there's an extra entry at the end for the end offset. + const intptr_t num_mols = label_offsets_numpy_array ? PyArray_DIM(label_offsets_numpy_array, 0) - 1 : num_label_rows; + // The normalization is computed on the subsample being kept + for (size_t task_key_index = 0; task_key_index < task_num_mols; ++task_key_index) { + const size_t task_mol_index = keys[task_mol_start[task_index] + task_key_index].task_mol_index; + if (task_mol_index >= size_t(num_mols)) { + printf("Error: In task %zu, mol index %zu is past limit of %zu\n", size_t(task_index), task_mol_index, size_t(num_mols)); + continue; + } + if (offsets_raw_data == nullptr) { + const char* row_data = raw_data + strides[0]*task_mol_index; + update_stats_single_row(row_data, task_num_cols, bytes_per_float, strides[1]); + } + else { + size_t begin_offset = *reinterpret_cast(offsets_raw_data + offsets_stride*task_mol_index); + size_t end_offset = *reinterpret_cast(offsets_raw_data + offsets_stride*(task_mol_index+1)); + const char* row_data = raw_data + strides[0]*begin_offset; + for (size_t row = begin_offset; row < end_offset; ++row, row_data += strides[0]) { + update_stats_single_row(row_data, task_num_cols, bytes_per_float, strides[1]); + } + } + } + } + +#if GRAPHIUM_CPP_DEBUGGING + printf("Task %zu normalization method %zu\n", size_t(task_index), size_t(task_normalization_options[task_index].method)); + for (size_t task_col_index = 0; task_col_index < task_num_cols; ++task_col_index) { + printf("Task %zu col %zu, num non-nan = %zu, min = %e, max = %e\n", + size_t(task_index), task_col_index, + size_t(task_num_non_nan[task_col_index]), + task_stats[num_stats*task_col_index + stat_min_offset], + task_stats[num_stats*task_col_index + stat_max_offset]); + } +#endif + } + + for (size_t task_index = 0; task_index < num_tasks; ++task_index) { + const size_t task_first_col = task_col_starts[task_index]; + const size_t task_num_cols = task_col_starts[task_index+1] - task_first_col; + if (task_num_cols == 0) { + continue; + } + + // Finish accumulation + double*const task_stats = all_task_stats.get() + num_stats*task_first_col; + intptr_t*const task_num_non_nan = all_task_num_non_nan.get() + task_first_col; + for (size_t task_col_index = 0; task_col_index < task_num_cols; ++task_col_index) { + if (task_num_non_nan[task_col_index] == 0) { + task_stats[num_stats*task_col_index + stat_min_offset] = std::numeric_limits::quiet_NaN(); + task_stats[num_stats*task_col_index + stat_max_offset] = std::numeric_limits::quiet_NaN(); + task_stats[num_stats*task_col_index + stat_mean_offset] = std::numeric_limits::quiet_NaN(); + task_stats[num_stats*task_col_index + stat_std_offset] = std::numeric_limits::quiet_NaN(); + } + else { + if (task_normalization_options[task_index].min_clipping > task_stats[num_stats*task_col_index + stat_min_offset]) { + task_stats[num_stats*task_col_index + stat_min_offset] = task_normalization_options[task_index].min_clipping; + } + if (task_normalization_options[task_index].max_clipping < task_stats[num_stats*task_col_index + stat_max_offset]) { + task_stats[num_stats*task_col_index + stat_max_offset] = task_normalization_options[task_index].max_clipping; + } + const double n = double(task_num_non_nan[task_col_index]); + const double mean = task_stats[num_stats*task_col_index + stat_mean_offset] / n; + task_stats[num_stats*task_col_index + stat_mean_offset] = mean; + // sum((x[i] - m)^2)/(n-1) + // = sum(x[i]^2 -2mx[i] + m^2)/(n-1) + // = (sum(x[i]^2) - 2nm^2 + nm^2)/(n-1) + // = (sum(x[i]^2) - nm^2)/(n-1) + // except, for compatibility with numpy.nanstd, use n instead of n-1 + const double sum_sqaures = task_stats[num_stats*task_col_index + stat_std_offset]; + const double stdev = std::sqrt((sum_sqaures - n*mean*mean)/n); + task_stats[num_stats*task_col_index + stat_std_offset] = stdev; + } + } + + const std::string task_name{ pybind11::str(task_names[task_index]) }; +#if GRAPHIUM_CPP_DEBUGGING + for (size_t task_col_index = 0; task_col_index < task_num_cols; ++task_col_index) { + printf("%s %zu %lld %e %e %e %e\n", + task_name.c_str(), task_col_index, (long long)task_num_non_nan[task_col_index], + task_stats[num_stats*task_col_index + stat_min_offset], + task_stats[num_stats*task_col_index + stat_max_offset], + task_stats[num_stats*task_col_index + stat_mean_offset], + task_stats[num_stats*task_col_index + stat_std_offset]); + } +#endif + const std::string stats_filename = task_name + "_stats.tmp"; + save_array_to_file(common_path, stats_filename.c_str(), task_stats, num_stats*task_num_cols); + + // Make copies for returning in a format similar to the load_stats function. + std::vector task_stats_out; + for (size_t stat_index = 0; stat_index < num_stats; ++stat_index) { + const int64_t task_stats_dims[1] = { int64_t(task_num_cols) }; + std::unique_ptr task_stats_copy(new double[task_num_cols]); + for (size_t task_col_index = 0; task_col_index < task_num_cols; ++task_col_index) { + task_stats_copy[task_col_index] = task_stats[num_stats*task_col_index + stat_index]; + } + at::Tensor task_stats_tensor = torch_tensor_from_array(std::move(task_stats_copy), task_stats_dims, 1, c10::ScalarType::Double); + task_stats_out.push_back(std::move(task_stats_tensor)); + } + all_stats_return_data.insert(std::make_pair(std::move(task_name), std::move(task_stats_out))); + } + + return all_stats_return_data; +} + +// Called by `prepare_and_save_data` to save the SMILES string data, numbers of nodes, +// and numbers of edges, to cache files for each stage (train/val/test.) +static auto save_non_label_data( + const std::filesystem::path* stage_paths, + const size_t num_tasks, + const size_t*const task_mol_start, + const MolKey*const keys, + const std::vector& smiles_strings, + const size_t total_num_cols) { + + std::unordered_map> per_stage_return_data; + + for (size_t stage_index = 0; stage_index < num_stages; ++stage_index) { + size_t concatenated_smiles_size = 0; + uint64_t num_unique_mols = 0; + const size_t stage_begin_index = task_mol_start[stage_index*num_tasks]; + const size_t stage_end_index = task_mol_start[(stage_index+1)*num_tasks]; + for (size_t sorted_index = stage_begin_index; sorted_index < stage_end_index; ) { + if (keys[sorted_index].isInvalid()) { + ++sorted_index; + continue; + } + ++num_unique_mols; + + // Add the length of the smiles string to the total length, + // and include the terminating zero + const size_t smiles_length = smiles_strings[keys[sorted_index].mol_index].size(); + concatenated_smiles_size += (smiles_length+1); + + const uint64_t id0 = keys[sorted_index].id0; + const uint64_t id1 = keys[sorted_index].id1; + ++sorted_index; + while (sorted_index < stage_end_index && keys[sorted_index].id0 == id0 && keys[sorted_index].id1 == id1) { + ++sorted_index; + } + } + + std::unique_ptr concatenated_smiles(new char[concatenated_smiles_size]); + std::unique_ptr smiles_offsets(new int64_t[num_unique_mols+1]); + std::unique_ptr num_nodes(new int32_t[num_unique_mols]); + std::unique_ptr num_edges(new int32_t[num_unique_mols]); + size_t unique_index = 0; + int64_t smiles_offset = 0; + for (size_t sorted_index = stage_begin_index; sorted_index < stage_end_index; ) { + if (keys[sorted_index].isInvalid()) { + ++sorted_index; + continue; + } + smiles_offsets[unique_index] = smiles_offset; + + const uint64_t id0 = keys[sorted_index].id0; + const uint64_t id1 = keys[sorted_index].id1; + num_nodes[unique_index] = keys[sorted_index].num_nodes; + num_edges[unique_index] = keys[sorted_index].num_edges; + + // Copy the string + const std::string& smiles_string = smiles_strings[keys[sorted_index].mol_index]; + const size_t smiles_length = smiles_string.size(); + memcpy(concatenated_smiles.get() + smiles_offset, smiles_string.c_str(), smiles_length); + smiles_offset += smiles_length; + // Don't forget the terminating zero + concatenated_smiles[smiles_offset] = 0; + ++smiles_offset; + + ++unique_index; + ++sorted_index; + while (sorted_index < stage_end_index && keys[sorted_index].id0 == id0 && keys[sorted_index].id1 == id1) { + ++sorted_index; + } + } + smiles_offsets[unique_index] = smiles_offset; + + save_array_to_file(stage_paths[stage_index], concat_smiles_filename, concatenated_smiles.get(), concatenated_smiles_size); + save_array_to_file(stage_paths[stage_index], smiles_offsets_filename, smiles_offsets.get(), num_unique_mols+1); + save_array_to_file(stage_paths[stage_index], num_nodes_filename, num_nodes.get(), num_unique_mols); + save_array_to_file(stage_paths[stage_index], num_edges_filename, num_edges.get(), num_unique_mols); + + const int64_t concatenated_smiles_dims[1] = { int64_t(concatenated_smiles_size) }; + at::Tensor smiles_tensor = torch_tensor_from_array(std::move(concatenated_smiles), concatenated_smiles_dims, 1, c10::ScalarType::Char); + const int64_t smiles_offsets_dims[1] = { int64_t(num_unique_mols+1) }; + at::Tensor smiles_offsets_tensor = torch_tensor_from_array(std::move(smiles_offsets), smiles_offsets_dims, 1, c10::ScalarType::Long); + const int64_t num_nodes_dims[1] = { int64_t(num_unique_mols) }; + at::Tensor num_nodes_tensor = torch_tensor_from_array(std::move(num_nodes), num_nodes_dims, 1, c10::ScalarType::Int); + const int64_t num_edges_dims[1] = { int64_t(num_unique_mols) }; + at::Tensor num_edges_tensor = torch_tensor_from_array(std::move(num_edges), num_edges_dims, 1, c10::ScalarType::Int); + + std::vector stage_return_data; + // Reserve space for one extra, for the data offsets tensor later + stage_return_data.reserve((total_num_cols > 0) ? 5 : 4); + stage_return_data.push_back(std::move(smiles_tensor)); + stage_return_data.push_back(std::move(smiles_offsets_tensor)); + stage_return_data.push_back(std::move(num_nodes_tensor)); + stage_return_data.push_back(std::move(num_edges_tensor)); + per_stage_return_data.insert(std::make_pair(stages[stage_index], std::move(stage_return_data))); + } + + return per_stage_return_data; +} + +// Called by `prepare_and_save_data` to save the label data to cache files for each stage +// (train/val/test.) Each file will contain the label data for `num_mols_per_file` molecules. +// In the case of equivalent molecules, this also reorders node-level, edge-level, and +// node-pair-level label data to be consistent with the node or edge order in the first +// encountered equivalent molecule, as well as normalizing the data as specified. +static void save_label_data( + std::unordered_map>& per_stage_return_data, + const std::filesystem::path* stage_paths, + const size_t num_tasks, + const size_t*const task_mol_start, + const size_t*const task_col_starts, + const size_t total_num_cols, + const MolKey*const keys, + PyArrayObject*const*const labels_numpy_arrays, + PyArrayObject*const*const label_offsets_numpy_arrays, + const NormalizationOptions*const task_normalization_options, + const double*const all_task_stats, + const size_t*const task_bytes_per_float, + const FeatureLevel*const task_levels, + const std::vector& smiles_strings, + const bool explicit_H) { + + // mol_data_offsets will only need one entry for each unique molecule, + // plus one per file, but we can preallocate an upper bound. + std::vector mol_data_offsets; + size_t upper_bound_num_files = (task_mol_start[num_tasks] + num_mols_per_file-1) / num_mols_per_file; + mol_data_offsets.reserve(task_mol_start[num_tasks] + upper_bound_num_files); + + // temp_data is used for normalization + std::vector temp_data; + temp_data.resize(total_num_cols*sizeof(double)); + + std::vector data; + data.reserve(num_mols_per_file*(total_num_cols*sizeof(double) + (1+2*num_tasks)*sizeof(uint64_t))); + + // These are for reordering label data at node, edge, or nodepair level + // when the same molecule may appear in multiple tasks with different + // atom orders. + std::vector first_atom_order; + std::vector current_atom_order; + std::vector inverse_atom_order; + std::vector first_bond_atoms; + std::vector current_bond_atoms; + std::vector current_bond_order; + + // Now, deal with label data + for (size_t stage_index = 0; stage_index < num_stages; ++stage_index) { + mol_data_offsets.resize(0); + assert(data.size() == 0); + uint64_t num_unique_mols = 0; + const size_t stage_begin_index = task_mol_start[stage_index*num_tasks]; + const size_t stage_end_index = task_mol_start[(stage_index+1)*num_tasks]; + for (size_t sorted_index = stage_begin_index; sorted_index < stage_end_index; ) { + if (keys[sorted_index].isInvalid()) { + ++sorted_index; + continue; + } + size_t data_offset = data.size(); + mol_data_offsets.push_back(data_offset); + + const size_t first_sorted_index = sorted_index; + const uint64_t id0 = keys[sorted_index].id0; + const uint64_t id1 = keys[sorted_index].id1; + + uint64_t prev_task_index = keys[sorted_index].task_index; + uint64_t mol_num_tasks = 1; + ++sorted_index; + while (sorted_index < stage_end_index && keys[sorted_index].id0 == id0 && keys[sorted_index].id1 == id1) { + // The same molecule can occur multiple times in a single dataset, + // but we only want to keep one copy for each task. + if (keys[sorted_index].task_index != prev_task_index) { + ++mol_num_tasks; + prev_task_index = keys[sorted_index].task_index; + } + ++sorted_index; + } + assert(mol_num_tasks <= num_tasks); + assert(!merge_equivalent_mols || mol_num_tasks == 1); + + // TODO: Double data capacity as needed if resizing is slow + assert(data.size() == data_offset); + data.resize(data_offset + sizeof(uint64_t)*(1+2*mol_num_tasks)); + + // Copy in the number of tasks for this molecule, followed by a list of the task indices and their end offsets. + memcpy(data.data() + data_offset, &mol_num_tasks, sizeof(uint64_t)); + data_offset += sizeof(uint64_t); + uint64_t task_offset = 0; + // Start with an invalid prev_task_index to pick up the first task + prev_task_index = uint64_t(int64_t(-1)); + for (size_t i = first_sorted_index; i < sorted_index; ++i) { + const uint64_t task_index = keys[i].task_index; + // The same molecule can occur multiple times in a single dataset, + // but we only want to keep one copy for each task. + if (task_index == prev_task_index) { + continue; + } + prev_task_index = task_index; + size_t num_cols = task_col_starts[task_index+1] - task_col_starts[task_index]; + PyArrayObject*const label_offsets_numpy_array = label_offsets_numpy_arrays[task_index]; + if (label_offsets_numpy_array != nullptr) { + const size_t task_mol_index = keys[i].task_mol_index; + const char* offsets_raw_data = (const char*)PyArray_DATA(label_offsets_numpy_array); + const intptr_t offsets_stride = PyArray_STRIDES(label_offsets_numpy_array)[0]; + const int64_t begin_offset = *reinterpret_cast(offsets_raw_data + offsets_stride*task_mol_index); + const int64_t end_offset = *reinterpret_cast(offsets_raw_data + offsets_stride*(task_mol_index+1)); + const size_t current_rows = size_t(end_offset - begin_offset); + num_cols *= current_rows; + } + task_offset += task_bytes_per_float[task_index]*num_cols; + memcpy(data.data() + data_offset, &task_index, sizeof(uint64_t)); + data_offset += sizeof(uint64_t); + memcpy(data.data() + data_offset, &task_offset, sizeof(uint64_t)); + data_offset += sizeof(uint64_t); + } + + // TODO: Double data capacity as needed if resizing is slow + assert(data.size() == data_offset); + data.resize(data_offset + task_offset); + + auto&& store_single_row = [&data_offset, &data, &temp_data]( + const char* col_data, + const size_t task_num_cols, + const intptr_t col_stride, + const size_t in_bytes_per_float, + const size_t out_bytes_per_float, + const NormalizationMethod normalization_method, + const bool do_clipping, + const double* task_stats) { + + if (size_t(col_stride) == in_bytes_per_float) { + memcpy(temp_data.data(), col_data, in_bytes_per_float*task_num_cols); + } + else { + for (size_t col = 0; col < task_num_cols; ++col) { + memcpy(temp_data.data() + col*in_bytes_per_float, col_data, in_bytes_per_float); + col_data += col_stride; + } + } + for (size_t col = 0; col < task_num_cols; ++col) { + double value; + if (in_bytes_per_float == sizeof(double)) { + value = ((const double*)(temp_data.data()))[col]; + } + else if (in_bytes_per_float == sizeof(float)) { + value = ((const float*)(temp_data.data()))[col]; + } + else { + assert(in_bytes_per_float == sizeof(uint16_t)); + value = c10::detail::fp16_ieee_to_fp32_value(((const uint16_t*)(temp_data.data()))[col]); + } + if (do_clipping) { + value = std::max(value, task_stats[stat_min_offset]); + value = std::min(value, task_stats[stat_max_offset]); + } + if (normalization_method == NormalizationMethod::NORMAL) { + if (task_stats[stat_std_offset] != 0) { + value = (value - task_stats[stat_mean_offset])/task_stats[stat_std_offset]; + } + else { + value = 0; + } + } + else if (normalization_method == NormalizationMethod::UNIT) { + // TODO: Cache 1/(max-min) or 0 to avoid check + if (task_stats[stat_max_offset] - task_stats[stat_min_offset] != 0) { + value = (value - task_stats[stat_min_offset])/(task_stats[stat_max_offset] - task_stats[stat_min_offset]); + } + else { + value = 0; + } + } + + // NOTE: The code below writes to temp_data, which is still being read from above, + // so this relies on that we're not writing to a larger data type than we're reading, + // else we'll overwrite data. + assert(out_bytes_per_float <= in_bytes_per_float); + if (out_bytes_per_float == sizeof(double)) { + ((double*)(temp_data.data()))[col] = value; + } + else if (out_bytes_per_float == sizeof(float)) { + ((float*)(temp_data.data()))[col] = float(value); + } + else { + assert(out_bytes_per_float == sizeof(uint16_t)); + ((uint16_t*)(temp_data.data()))[col] = c10::detail::fp16_ieee_from_fp32_value(value); + } + task_stats += num_stats; + } + + memcpy(data.data() + data_offset, temp_data.data(), out_bytes_per_float*task_num_cols); + data_offset += out_bytes_per_float*task_num_cols; + }; + + // Copy in the task data, with optional normalization + // Start with an invalid prev_task_index to pick up the first task + prev_task_index = uint64_t(int64_t(-1)); + std::unique_ptr first_mol; + for (size_t i = first_sorted_index; i < sorted_index; ++i) { + const uint64_t task_index = keys[i].task_index; + // The same molecule can occur multiple times in a single dataset, + // but we only want to keep one copy for each task. + if (task_index == prev_task_index) { + continue; + } + prev_task_index = task_index; + + const uint64_t task_mol_index = keys[i].task_mol_index; + + const size_t task_first_col = task_col_starts[task_index]; + const size_t task_num_cols = task_col_starts[task_index+1] - task_first_col; + const NormalizationOptions& normalization = task_normalization_options[task_index]; + const bool do_clipping = + (normalization.min_clipping > -std::numeric_limits::infinity()) && + (normalization.max_clipping < std::numeric_limits::infinity()); + const double* task_stats = all_task_stats + num_stats*task_first_col; + + const size_t bytes_per_float = task_bytes_per_float[task_index]; + + // Before copying this task's label data, check whether the atom order + // is different from the representative SMILES string's atom order. + bool same_order_as_first = true; + if (i != first_sorted_index && task_levels[task_index] != FeatureLevel::GRAPH) { + const std::string& first_string = smiles_strings[keys[first_sorted_index].mol_index]; + const std::string& current_string = smiles_strings[keys[i].mol_index]; + if (first_string != current_string) { + // Different string, so get first and current atom orders + if (first_atom_order.size() == 0) { + first_mol = parse_mol(first_string, explicit_H); + get_canonical_atom_order(*first_mol, first_atom_order); + } + if (first_bond_atoms.size() == 0 && task_levels[task_index] == FeatureLevel::EDGE) { + const unsigned int num_bonds = first_mol->getNumBonds(); + for (unsigned int bond_index = 0; bond_index < num_bonds; ++bond_index) { + auto bond = first_mol->getBondWithIdx(bond_index); + unsigned int a = bond->getBeginAtomIdx(); + unsigned int b = bond->getEndAtomIdx(); + first_bond_atoms.push_back(a); + first_bond_atoms.push_back(b); + } + } + std::unique_ptr mol = parse_mol(current_string, explicit_H); + get_canonical_atom_order(*mol, current_atom_order); + assert(first_atom_order.size() == current_atom_order.size()); + + // first_atom_order maps from the first order to the canonical order. + // current_atom_order maps from the first order to the canonical order. + // We need the inverse current map, to go from the first order to the + // canonical order, and then from there to the current order. + inverse_atom_order.resize(first_atom_order.size()); + for (unsigned int current_index = 0; current_index < current_atom_order.size(); ++current_index) { + unsigned int canon_index = current_atom_order[current_index]; + assert(canon_index < inverse_atom_order.size()); + inverse_atom_order[canon_index] = current_index; + } + for (unsigned int first_index = 0; first_index < first_atom_order.size(); ++first_index) { + unsigned int canon_index = first_atom_order[first_index]; + assert(canon_index < inverse_atom_order.size()); + unsigned int current_index = inverse_atom_order[canon_index]; + assert(first_index < current_atom_order.size()); + current_atom_order[first_index] = current_index; + if (current_index != first_index) { + same_order_as_first = false; + } + } + + if (task_levels[task_index] == FeatureLevel::EDGE) { + const unsigned int num_bonds = mol->getNumBonds(); + for (unsigned int bond_index = 0; bond_index < num_bonds; ++bond_index) { + auto bond = mol->getBondWithIdx(bond_index); + unsigned int a = bond->getBeginAtomIdx(); + unsigned int b = bond->getEndAtomIdx(); + current_bond_atoms.push_back(a); + current_bond_atoms.push_back(b); + } + assert(first_bond_atoms.size() == current_bond_atoms.size()); + + // Create a mapping from the first bond order to the current bond order + same_order_as_first = true; + for (size_t i = 0; i < first_bond_atoms.size(); i += 2) { + const unsigned int first_a = current_atom_order[first_bond_atoms[i]]; + const unsigned int first_b = current_atom_order[first_bond_atoms[i + 1]]; + + // TODO: If this search ever becomes a bottleneck, do it properly by sorting both arrays and using a binary search. + bool found = false; + for (size_t j = 0; j < current_bond_atoms.size(); j += 2) { + const unsigned int a = current_bond_atoms[j]; + const unsigned int b = current_bond_atoms[j + 1]; + // Check both orders + if ((first_a == a && first_b == b) || (first_b == a && first_a == b)) { + if (current_bond_order.size() != (j / 2)) { + same_order_as_first = false; + } + current_bond_order.push_back(j / 2); + found = true; + break; + } + } + assert(found); + if (!found) { + // The bond should be found, but in case it isn't, fall back to current order + // to avoid crashing. This could happen if there's an InChI key collision. + same_order_as_first = true; + break; + } + } + } + } + } + + PyArrayObject*const labels_numpy_array = labels_numpy_arrays[task_index]; + if (labels_numpy_array != nullptr) { + const char* raw_data = (const char*)PyArray_DATA(labels_numpy_array); + const intptr_t* strides = PyArray_STRIDES(labels_numpy_array); + PyArrayObject*const label_offsets_numpy_array = label_offsets_numpy_arrays[task_index]; + const char* offsets_raw_data = label_offsets_numpy_array ? (const char*)PyArray_DATA(label_offsets_numpy_array) : nullptr; + const intptr_t offsets_stride = label_offsets_numpy_array ? PyArray_STRIDES(label_offsets_numpy_array)[0] : 0; + if (offsets_raw_data == nullptr) { + const char* row_data = raw_data + strides[0]*task_mol_index; + store_single_row(row_data, task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, do_clipping, task_stats); + } + else { + size_t begin_offset = *reinterpret_cast(offsets_raw_data + offsets_stride*task_mol_index); + size_t end_offset = *reinterpret_cast(offsets_raw_data + offsets_stride*(task_mol_index+1)); + const char* row_data = raw_data + strides[0]*begin_offset; + if (same_order_as_first) { + for (size_t row = begin_offset; row < end_offset; ++row, row_data += strides[0]) { + store_single_row(row_data, task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, do_clipping, task_stats); + } + } + else if (task_levels[task_index] == FeatureLevel::NODE) { + assert(end_offset - begin_offset == current_atom_order.size()); + for (unsigned int current_index : current_atom_order) { + store_single_row(row_data + current_index*strides[0], task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, do_clipping, task_stats); + } + } + else if (task_levels[task_index] == FeatureLevel::NODEPAIR) { + const size_t n = current_atom_order.size(); + assert(end_offset - begin_offset == n*n); + for (unsigned int current_index0 : current_atom_order) { + for (unsigned int current_index1 : current_atom_order) { + store_single_row(row_data + (current_index0*n + current_index1)*strides[0], task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, do_clipping, task_stats); + } + } + } + else { + assert(task_levels[task_index] == FeatureLevel::EDGE); + for (unsigned int current_index : current_bond_order) { + store_single_row(row_data + current_index*strides[0], task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, do_clipping, task_stats); + } + } + } + } + } + first_atom_order.resize(0); + current_atom_order.resize(0); + inverse_atom_order.resize(0); + first_mol.reset(); + first_bond_atoms.resize(0); + current_bond_atoms.resize(0); + current_bond_order.resize(0); + + ++num_unique_mols; + if (num_unique_mols % num_mols_per_file == 0 || sorted_index == stage_end_index) { + // Write out the data to a file + + // First, construct the filename + char filename[20+4+1]; + size_t file_num = ((num_unique_mols-1) / num_mols_per_file); + get_mol_label_filename(filename, file_num); + + std::filesystem::path file_path(stage_paths[stage_index] / filename); + FileType file = fopen_write_wrapper(file_path); + if (file == INVALID_FILE) { + return; + } +#if GRAPHIUM_CPP_DEBUGGING + printf("Writing file %s\n", file_path.string().c_str()); +#endif + size_t num_bytes_written = fwrite_wrapper(data.data(), data_offset, file); + fclose_wrapper(file); + if (num_bytes_written != data_offset) { + return; + } + data.resize(0); + + // One extra data offset to mark the end of each file. + // data_offset is automatically reset to 0 on the next iteration + // due to data.size() being 0 now. + mol_data_offsets.push_back(data_offset); + } + } + + // Write out the molecule data offsets to a separate file, + // so that only one file read is needed per molecule when data loading + // if the offsets are all loaded once and kept in memory. + // Note the one extra entry per file. +#if GRAPHIUM_CPP_DEBUGGING + printf("Stage %s has %zu unique mols from %zu original\n", stages[stage_index].c_str(), size_t(num_unique_mols), size_t(stage_end_index - stage_begin_index)); +#endif + assert(mol_data_offsets.size() == num_unique_mols + (num_unique_mols + num_mols_per_file-1)/num_mols_per_file); + std::filesystem::path file_path(stage_paths[stage_index] / "mol_offsets.tmp"); + FileType file = fopen_write_wrapper(file_path); + if (file == INVALID_FILE) { + return; + } + size_t num_bytes_written = fwrite_wrapper(&num_unique_mols, sizeof(num_unique_mols), file); + if (num_bytes_written != sizeof(num_unique_mols)) { + fclose_wrapper(file); + return; + } + size_t num_offsets = mol_data_offsets.size(); + size_t data_offsets_size = num_offsets*sizeof(mol_data_offsets[0]); + num_bytes_written = fwrite_wrapper(mol_data_offsets.data(), data_offsets_size, file); + fclose_wrapper(file); + if (num_bytes_written != data_offsets_size) { + return; + } + + static_assert(sizeof(int64_t) == sizeof(mol_data_offsets[0])); + save_array_to_file(stage_paths[stage_index], file_data_offsets_filename, mol_data_offsets.data(), num_offsets); + std::unique_ptr temp_data_offsets(new int64_t[num_offsets]); + memcpy(temp_data_offsets.get(), mol_data_offsets.data(), data_offsets_size); + const int64_t data_offsets_dims[1] = { int64_t(num_offsets) }; + at::Tensor data_offsets_tensor = torch_tensor_from_array(std::move(temp_data_offsets), data_offsets_dims, 1, c10::ScalarType::Long); + + per_stage_return_data[stages[stage_index]].push_back(std::move(data_offsets_tensor)); + mol_data_offsets.resize(0); + } +} + +// Merges label data for equivalent molecules from separate datasets, +// computes statistics, and caches the label data to files for efficient loading later. +// +// Returns: +// stage -> [ +// unique mol smiles strings all concatenated, +// unique mol smiles string offsets (including one extra for the end), +// unique mol num_nodes, +// unique mol num_edges, +// mol_file_data_offsets +// ] +// task -> 4 stats tensors each +// task index -> label num columns +// task index -> label torch data type enum +// +// See the declaration in labels.h for more details. +std::tuple< + std::unordered_map>, + std::unordered_map>, + std::vector, + std::vector +> prepare_and_save_data( + const pybind11::list& task_names, + pybind11::dict& task_dataset_args, + const pybind11::dict& task_label_normalization, + const std::string processed_graph_data_path, + const std::string data_hash, + const pybind11::dict& task_train_indices, + const pybind11::dict& task_val_indices, + const pybind11::dict& task_test_indices, + bool add_self_loop, + bool explicit_H, + int max_threads, + bool merge_equivalent_mols) { + + ensure_numpy_array_module_initialized(); + + const size_t num_tasks = task_names.size(); + std::vector return_label_num_cols(num_tasks, 0); + std::vector return_label_data_types(num_tasks, -1); + std::unique_ptr task_col_starts(new size_t[num_tasks+1]); + std::unique_ptr task_bytes_per_float(new size_t[num_tasks]); + std::unique_ptr task_normalization_options(new NormalizationOptions[num_tasks]); + std::unique_ptr smiles_numpy_arrays(new PyArrayObject*[num_tasks]); + std::unique_ptr labels_numpy_arrays(new PyArrayObject*[num_tasks]); + std::unique_ptr label_offsets_numpy_arrays(new PyArrayObject*[num_tasks]); + std::unique_ptr task_levels(new FeatureLevel[num_tasks]); + + // Figure out the task bounds first, so that everything can be parallelized perfectly. + get_task_data( + task_names, + task_dataset_args, + task_label_normalization, + return_label_num_cols.data(), + return_label_data_types.data(), + task_col_starts.get(), + task_bytes_per_float.get(), + task_normalization_options.get(), + smiles_numpy_arrays.get(), + labels_numpy_arrays.get(), + label_offsets_numpy_arrays.get(), + task_levels.get()); + + const size_t total_num_cols = task_col_starts[num_tasks]; + + std::filesystem::path base_path{processed_graph_data_path}; + std::filesystem::create_directories(base_path); + std::filesystem::path common_path(base_path / data_hash); + std::filesystem::create_directories(common_path); + + if (total_num_cols > 0) { + save_num_cols_and_dtypes(common_path, return_label_num_cols, return_label_data_types); + } + + std::unique_ptr task_mol_start(new size_t[num_stages*num_tasks + 1]); + std::vector task_mol_indices; + std::vector smiles_strings; + get_indices_and_strings( + task_names, + task_train_indices, + task_val_indices, + task_test_indices, + task_mol_start.get(), + task_mol_indices, + smiles_numpy_arrays.get(), + smiles_strings); + const size_t total_num_mols = task_mol_indices.size(); + + // Compute all InChI keys for all molecules, in parallel if applicable. + std::unique_ptr keys(new MolKey[total_num_mols]); + compute_mol_keys( + keys.get(), + total_num_mols, + num_tasks, + max_threads, + task_mol_start.get(), + add_self_loop, + explicit_H, + merge_equivalent_mols, + task_mol_indices.data(), + smiles_strings); + + std::unique_ptr all_task_stats; + auto all_stats_return_data = compute_stats( + common_path, + total_num_cols, + task_names, + task_mol_start.get(), + task_col_starts.get(), + task_bytes_per_float.get(), + task_normalization_options.get(), + labels_numpy_arrays.get(), + label_offsets_numpy_arrays.get(), + keys.get(), + all_task_stats); + + if (merge_equivalent_mols) { + // Sort train, val, and test separately, since they need to be stored separately. + // Don't sort until after accumulating stats, because the code above currently assumes that the tasks + // aren't interleaved. + std::sort(keys.get(), keys.get() + task_mol_start[num_tasks]); + std::sort(keys.get() + task_mol_start[num_tasks], keys.get() + task_mol_start[2*num_tasks]); + std::sort(keys.get() + task_mol_start[2*num_tasks], keys.get() + total_num_mols); + } + + std::filesystem::path stage_paths[num_stages] = { + base_path / (stages[0] + "_" + data_hash), + base_path / (stages[1] + "_" + data_hash), + base_path / (stages[2] + "_" + data_hash) + }; + std::filesystem::create_directories(stage_paths[0]); + std::filesystem::create_directories(stage_paths[1]); + std::filesystem::create_directories(stage_paths[2]); + + // Deal with non-label data first (smiles, num_nodes, num_edges) + auto per_stage_return_data = save_non_label_data( + stage_paths, + num_tasks, + task_mol_start.get(), + keys.get(), + smiles_strings, + total_num_cols); + + if (total_num_cols == 0) { + // No label data, so all done + return std::make_tuple( + std::move(per_stage_return_data), + std::move(all_stats_return_data), + std::move(return_label_num_cols), + std::move(return_label_data_types)); + } + + save_label_data( + per_stage_return_data, + stage_paths, + num_tasks, + task_mol_start.get(), + task_col_starts.get(), + total_num_cols, + keys.get(), + labels_numpy_arrays.get(), + label_offsets_numpy_arrays.get(), + task_normalization_options.get(), + all_task_stats.get(), + task_bytes_per_float.get(), + task_levels.get(), + smiles_strings, + explicit_H); + + return std::make_tuple( + std::move(per_stage_return_data), + std::move(all_stats_return_data), + std::move(return_label_num_cols), + std::move(return_label_data_types)); +} + +// Loads label data associated with the molecule with index `mol_index` from the corresponding +// file in the directory `stage_directory`. +// See the declaration in labels.h for more details. +void load_labels_from_index( + const std::string stage_directory, + const int64_t mol_index, + const at::Tensor& mol_file_data_offsets, + const pybind11::list& label_names, + const pybind11::list& label_num_cols, + const pybind11::list& label_data_types, + pybind11::dict& labels +) { + const std::filesystem::path stage_path{stage_directory}; + if (mol_index < 0) { + printf("Error: In load_labels_from_index, mol_index = %lld\n", (long long)mol_index); + return; + } + const uint64_t file_num = uint64_t(mol_index) / num_mols_per_file; + const size_t index_into_offsets = file_num*(num_mols_per_file+1) + (uint64_t(mol_index) % num_mols_per_file); + + const size_t num_data_offsets = (mol_file_data_offsets.scalar_type() == c10::ScalarType::Long && mol_file_data_offsets.ndimension() == 1) ? mol_file_data_offsets.size(0) : 0; + if (index_into_offsets+1 >= num_data_offsets) { + printf("Error: In load_labels_from_index, mol_index = %zu, index_into_offsets = %zu, num_data_offsets = %zu\n", + size_t(mol_index), size_t(index_into_offsets), size_t(num_data_offsets)); + return; + } + // NOTE: If TensorBase::data_ptr is ever removed, change it to TensorBase::const_data_ptr. + // Some torch version being used doesn't have const_data_ptr yet. + const int64_t* const data_offsets = mol_file_data_offsets.data_ptr(); + const int64_t file_begin_offset = data_offsets[index_into_offsets]; + const int64_t file_end_offset = data_offsets[index_into_offsets+1]; + if (file_end_offset < 0 || file_end_offset-file_begin_offset < 8) { + printf("Error: In load_labels_from_index, mol_index = %zu, file_begin_offset = %lld, file_end_offset = %lld\n", + size_t(mol_index), (long long)(index_into_offsets), (long long)(num_data_offsets)); + return; + } + const size_t file_read_size = size_t(file_end_offset - file_begin_offset); + + std::unique_ptr data(new char[file_read_size]); + + { + char filename[25]; + get_mol_label_filename(filename, file_num); + + const std::filesystem::path file_path{stage_path / filename}; + FileType file = fopen_read_wrapper(file_path); + if (file == INVALID_FILE) { + printf("Error: In load_labels_from_index, failed to open \"%s\" for molecule %zu\n", + file_path.string().c_str(), size_t(mol_index)); + return; + } + int seek_failed = fseek_wrapper(file, file_begin_offset); + if (seek_failed) { + printf("Error: In load_labels_from_index, failed to seek to offset %zu in \"%s\" for molecule %zu\n", + size_t(file_begin_offset), file_path.string().c_str(), size_t(mol_index)); + fclose_wrapper(file); + return; + } + size_t num_bytes_read = fread_wrapper(data.get(), file_read_size, file); + fclose_wrapper(file); + if (num_bytes_read != file_read_size) { + printf("Error: In load_labels_from_index, read only %zu/%zu bytes from \"%s\" for molecule %zu\n", + size_t(num_bytes_read), size_t(file_read_size), file_path.string().c_str(), size_t(mol_index)); + return; + } + } + + uint64_t mol_num_tasks = 0; + memcpy(&mol_num_tasks, data.get(), sizeof(uint64_t)); + size_t data_offset = sizeof(uint64_t); + if (mol_num_tasks == 0 || mol_num_tasks > label_names.size() || file_read_size < (1+2*mol_num_tasks)*sizeof(uint64_t)) { + printf("Error: In load_labels_from_index, mol_index = %zu, mol_num_tasks = %zu, file_read_size = %zu\n", + size_t(mol_index), size_t(mol_num_tasks), size_t(file_read_size)); + return; + } + const size_t base_offset = (1+2*mol_num_tasks)*sizeof(uint64_t); + const char* base_task_data = data.get() + base_offset; + uint64_t task_offset = 0; + for (size_t data_task_index = 0; data_task_index < mol_num_tasks; ++data_task_index) { + uint64_t task_index = 0; + memcpy(&task_index, data.get() + data_offset, sizeof(uint64_t)); + data_offset += sizeof(uint64_t); + if (task_index >= label_names.size() || task_index >= label_data_types.size() || task_index >= label_num_cols.size()) { + printf("Error: In load_labels_from_index, mol_index = %zu, task_index = %zu\n", + size_t(mol_index), size_t(task_index)); + return; + } + + uint64_t task_end_offset = 0; + memcpy(&task_end_offset, data.get() + data_offset, sizeof(uint64_t)); + data_offset += sizeof(uint64_t); + if (task_end_offset < task_offset || task_end_offset > file_read_size-base_offset) { + printf("Error: In load_labels_from_index, mol_index = %zu, task_offset = %zu, task_end_offset = %zu, file_read_size = %zu, base_offset = %zu\n", + size_t(mol_index), size_t(task_offset), size_t(task_end_offset), size_t(file_read_size), size_t(base_offset)); + return; + } + + const size_t task_num_bytes = task_end_offset - task_offset; + if (!pybind11::isinstance(label_data_types[task_index]) || + !pybind11::isinstance(label_num_cols[task_index])) { + printf("Error: In load_labels_from_index, mol_index = %zu, task_index = %zu, label_data_type = \"%s\", label_num_cols = \"%s\"\n", + size_t(mol_index), size_t(task_index), + std::string(pybind11::str(label_data_types[task_index])).c_str(), + std::string(pybind11::str(label_num_cols[task_index])).c_str()); + return; + } + const c10::ScalarType torch_type = c10::ScalarType(size_t(label_data_types[task_index].cast())); + const size_t num_cols = size_t(label_num_cols[task_index].cast()); + if (num_cols == 0) { + printf("Error: In load_labels_from_index, mol_index = %zu, task_index = %zu, label_data_type = %zu, label_num_cols = %zu\n", + size_t(mol_index), size_t(task_index), + size_t(torch_type), num_cols); + return; + } + const size_t supported_type_index = torch_type_index(torch_type); + if (supported_type_index >= num_supported_types) { + printf("Error: In load_labels_from_index, mol_index = %zu, task_index = %zu, label_data_type = %zu, label_num_cols = %zu\n", + size_t(mol_index), size_t(task_index), + size_t(torch_type), num_cols); + } + const size_t bytes_per_float = supported_types[supported_type_index].size; + const size_t num_floats = task_num_bytes / bytes_per_float; + const size_t num_rows = num_floats / num_cols; + + if (num_floats != num_rows*num_cols) { + printf("Error: In load_labels_from_index, mol_index = %zu, task data bytes = %zu (not a multiple of %zu*%zu)\n", + size_t(mol_index), size_t(task_num_bytes), bytes_per_float, num_cols); + return; + } + + const std::string label_name{pybind11::str(label_names[task_index])}; + const bool is_graph_level = (std::strncmp(label_name.c_str(), "graph", 5) == 0); + if (is_graph_level && num_rows != 1) { + printf("Error: In load_labels_from_index, mol_index = %zu, num_rows = %zu for task \"%s\"\n", + size_t(mol_index), num_rows, label_name.c_str()); + return; + } + size_t num_label_dims = is_graph_level ? 1 : 2; + const int64_t label_dims[2] = { (is_graph_level ? int64_t(num_floats) : int64_t(num_rows)), int64_t(num_cols) }; + at::Tensor label_tensor; + + if (bytes_per_float == 2) { + std::unique_ptr label_data(new uint16_t[num_floats]); + memcpy(label_data.get(), base_task_data + task_offset, task_num_bytes); + label_tensor = torch_tensor_from_array(std::move(label_data), label_dims, num_label_dims, torch_type); + } + else if (bytes_per_float == 4) { + std::unique_ptr label_data(new float[num_floats]); + memcpy(label_data.get(), base_task_data + task_offset, task_num_bytes); + label_tensor = torch_tensor_from_array(std::move(label_data), label_dims, num_label_dims, torch_type); + } + else if (bytes_per_float == 8) { + std::unique_ptr label_data(new double[num_floats]); + memcpy(label_data.get(), base_task_data + task_offset, task_num_bytes); + label_tensor = torch_tensor_from_array(std::move(label_data), label_dims, num_label_dims, torch_type); + } + + PyDict_SetItem(labels.ptr(), label_names[task_index].ptr(), THPVariable_Wrap(std::move(label_tensor))); + + task_offset = task_end_offset; + } +} + +// Extracts a single string from `concat_strings`, a Tensor of contatenated strings, +// using offsets at the specified `index` in `string_offsets`. +// See the declaration in labels.h. +std::string extract_string( + const at::Tensor& concat_strings, + const at::Tensor& string_offsets, + const int64_t index) { + + const size_t data_size = (concat_strings.scalar_type() == c10::ScalarType::Char && concat_strings.ndimension() == 1) ? concat_strings.size(0) : 0; + const size_t num_data_offsets = (string_offsets.scalar_type() == c10::ScalarType::Long && string_offsets.ndimension() == 1) ? string_offsets.size(0) : 0; + if (index < 0 || size_t(index) >= num_data_offsets) { + return std::string(); + } + const char* const data = reinterpret_cast(concat_strings.data_ptr()); + const int64_t* const data_offsets = string_offsets.data_ptr(); + int64_t offset = data_offsets[index]; + int64_t end_offset = data_offsets[index+1]; + int64_t size = (end_offset - offset) - 1; + if (offset < 0 || size < 0 || end_offset > int64_t(data_size)) { + return std::string(); + } + return std::string(data + offset, size_t(size)); +} diff --git a/graphium/graphium_cpp/labels.h b/graphium/graphium_cpp/labels.h new file mode 100644 index 000000000..d1434a459 --- /dev/null +++ b/graphium/graphium_cpp/labels.h @@ -0,0 +1,207 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This file declares functions for preprocessing and looking up label data, +//! for exporting to Python, defined in labels.cpp + +#pragma once + +#include +#include +#include +#include + +// Torch tensor headers +#include +#include +#include + +// PyBind and Torch headers +#include +#include +#include + +//! Reads the number of columns and data type for each task, from the common label metadata +//! file that was already saved by `prepare_and_save_data`, possibly on a previous run, in the +//! directory `processed_graph_data_path/data_hash`. Returns empty lists on failure. +//! +//! This is implemented in labels.cpp, and declared here so that graphium_cpp.cpp +//! can expose it to Python via pybind. +std::tuple< + std::vector, + std::vector +> load_num_cols_and_dtypes( + const std::string& processed_graph_data_path, + const std::string& data_hash); + +//! Reads data from the stage-specific label metadata files that were already saved by +//! `prepare_and_save_data`, possibly on a previous run, in the directory +//! `processed_graph_data_path/stage_data_hash`. Returns an empty list on failure. +//! +//! On success, the returned tensors are: +//! 0) All SMILES strings concatenated, +//! 1) The beginning offsets of each SMILES string in the first tensor, and +//! one extra at the end equal to the length of the first tensor +//! 2) The number of nodes (atoms) in each molecule +//! 3) The number of edges (bonds) in each molecule +//! 4) (Optional if only inference) The offset of each molecule's label data within the +//! label data files, plus an extra for the end of each file +//! The first two tensors are used by `extract_string`. The optional last tensor is +//! used by `load_labels_from_index`. +//! +//! This is implemented in labels.cpp, and declared here so that graphium_cpp.cpp +//! can expose it to Python via pybind. +std::vector load_metadata_tensors( + const std::string processed_graph_data_path, + const std::string stage, + const std::string data_hash); + +//! Reads data from the task-specific stats file that was already saved by +//! `prepare_and_save_data`, possibly on a previous run, in the directory +//! `processed_graph_data_path/data_hash`. Returns an empty list on failure. +//! +//! Each tensor's length is the number of columns for this task, and there are 4 +//! tensors total: minimum, maximum, mean, standard deviation. +//! +//! This is implemented in labels.cpp, and declared here so that graphium_cpp.cpp +//! can expose it to Python via pybind. +std::vector load_stats( + const std::string processed_graph_data_path, + const std::string data_hash, + const std::string task_name); + +//! Accepts a Numpy array of strings or Python list of strings, and returns a PyTorch tensor +//! of all of the characters and another tensor containing indices into the other tensor +//! indicating where each string begins, plus one extra index indicating the end. +//! +//! This is implemented in labels.cpp, and declared here so that graphium_cpp.cpp +//! can expose it to Python via pybind. +std::pair concatenate_strings(pybind11::handle handle); + +//! Merges label data for equivalent molecules from separate datasets, +//! computes statistics, and caches the label data to files for efficient loading later. +//! +//! This is implemented in labels.cpp, and declared here so that graphium_cpp.cpp +//! can expose it to Python via pybind. +//! +//! @param task_names Python list of the names of the datasets to process. These are used for +//! looking up into the other parameters starting with `task_`, and the +//! beginning of each name must be `graph_`, `node_`, `edge_`, or `nodepair_` +//! to determine the level of the label data. +//! @param task_dataset_args Python dict mapping task names to Python dicts for each dataset. +//! Each task's dict must contain a mapping from `"smiles"` to a 1D +//! Numpy array of objects, each of which is a Python string with a +//! molecule's SMILES text. If doing inference, each task's dict must +//! also map from `"labels"` to a 2D Numpy array of float16, float32, +//! or float64 type. For node, edge, or node-pair level label data, +//! the dict must also map from `"label_offsets"` to a 1D Numpy array +//! of type int64, indicating the row in the `"labels"` array where +//! each molecule's data begins, plus an extra for the end. If +//! `"label_offsets"` is not present, the `"labels"` array has one row +//! per molecule, and if it is present, the `"labels"` array has one +//! row per atom, bond, or pair of atoms, according to the label level. +//! @param task_label_normalization Python dict mapping task names to Python dicts for each +//! dataset's normalization options. Each task's dict must +//! contain a mapping from `"method"` to either `"none"`, +//! `"normal"`, or `"unit"`, and can optionally contain a +//! mapping from `"min_clipping"` and/or `"max_clipping"` to a +//! Python float or int to explicitly clip the range. +//! @param processed_graph_data_path String containing the base directory to create +//! subdirectories for cached files in. It can exist already, +//! or will be created if it does not already exist. +//! @param data_hash String representing a hash of the label data options. It will be used in +//! the names of all subdirectories created under `processed_graph_data_path`. +//! @param task_train_indices Python dict mapping task names to Python lists of ints, indicating +//! indices into `task_dataset_args[task_name]["smiles"]` and other +//! per-molecule-per-task arrays. Only these molecules will be used +//! for the "train" stage. +//! @param task_val_indices Python dict mapping task names to Python lists of ints, indicating +//! indices into `task_dataset_args[task_name]["smiles"]` and other +//! per-molecule-per-task arrays. Only these molecules will be used +//! for the "val" stage. +//! @param task_test_indices Python dict mapping task names to Python lists of ints, indicating +//! indices into `task_dataset_args[task_name]["smiles"]` and other +//! per-molecule-per-task arrays. Only these molecules will be used +//! for the "test" stage. +//! @param add_self_loop If true (default is false), `num_atoms` is added to the number of +//! directed edges (twice the number of bonds). This is for consistency +//! with `featurize_smiles` later. +//! @param explicit_H If true (default is false), any implicit hydrogens will be made explicit, +//! possibly increasing the number of atoms. +//! @param max_threads If greater than zero, at most this many threads will be created for +//! processing in parallel. If zero (the default), at most one thread per +//! logical CPU core will be created. If less than zero, the limit is +//! reduced by adding this negative amount to the number of logical CPU +//! cores. +//! @param merge_equivalent_mols If true (the default), label data for the same molecule in +//! different datasets are collected together, even if the atoms +//! or bonds are in a different order. Duplicates of the same +//! molecule within a single dataset will be ignored. This is very +//! slow, and changes the number and order of the molecules, so it +//! can be set to false for inference, where there is no label data +//! or only one dataset. +//! @return Four objects: +//! - A dict mapping the stage names ("train", "val", "test") to a list of five 1D +//! PyTorch tensors: +//! 0) SMILES strings all concatenated, one per unique molecule +//! 1) Offsets into the previous tensor where the strings begin, one per unique +//! molecule, plus one extra for the end +//! 2) Number of nodes (atoms) in each unique molecule +//! 3) Number of edges (2*bonds) in each unique molecule +//! 4) (Only if there is label data) `mol_file_data_offsets` to be passed to calls +//! to `load_labels_from_index` +//! - A dict mapping task names to a list of four 1D PyTorch tensors for column +//! normalization: minimum, maximum, mean, standard deviation +//! - A list of the number of columns in each task, in the same order as `task_names` +//! - A list of integers representing the Torch data type of each task, in the same +//! order as `task_names` +std::tuple< + std::unordered_map>, + std::unordered_map>, + std::vector, + std::vector +> prepare_and_save_data( + const pybind11::list& task_names, + pybind11::dict& task_dataset_args, + const pybind11::dict& task_label_normalization, + const std::string processed_graph_data_path, + const std::string data_hash, + const pybind11::dict& task_train_indices, + const pybind11::dict& task_val_indices, + const pybind11::dict& task_test_indices, + bool add_self_loop = false, + bool explicit_H = false, + int max_threads = 0, + bool merge_equivalent_mols = true); + +//! Loads label data associated with the molecule with index `mol_index` from the corresponding +//! file in the directory `stage_directory`, and adds the data to `labels` dictionary using +//! the strings from `label_names` to map to tensors. The label data must be previously saved +//! by `prepare_and_save_data`. `mol_file_data_offsets` is used to determine how to find the +//! data in the file, `label_data_types` is used for the type and size of each float, and +//! `label_num_cols` is used to determine the layout of each output tensor, especially ones with +//! multiple rows, such as node-level, edge-level, or node-pair-level label data. +//! +//! This is implemented in labels.cpp, and declared here so that graphium_cpp.cpp +//! can expose it to Python via pybind. +void load_labels_from_index( + const std::string stage_directory, + const int64_t mol_index, + const at::Tensor& mol_file_data_offsets, + const pybind11::list& label_names, + const pybind11::list& label_num_cols, + const pybind11::list& label_data_types, + pybind11::dict& labels); + +//! Extracts a single string from `concat_strings`, a Tensor of contatenated strings, +//! using offsets at the specified `index` in `string_offsets`. +//! +//! The tensors can be returned by `load_metadata_tensors`, `concatenate_strings`, or +//! `prepare_and_save_data`. +//! +//! This is implemented in labels.cpp, and declared here so that graphium_cpp.cpp +//! can expose it to Python via pybind. +std::string extract_string( + const at::Tensor& concat_strings, + const at::Tensor& string_offsets, + const int64_t index); diff --git a/graphium/graphium_cpp/one_hot.cpp b/graphium/graphium_cpp/one_hot.cpp new file mode 100644 index 000000000..16ad235e6 --- /dev/null +++ b/graphium/graphium_cpp/one_hot.cpp @@ -0,0 +1,375 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This file defines functions for one-hot atom and bond features, +//! declared in one_hot.h and called from features.cpp + +#include "one_hot.h" +#include "features.h" +#include "float_features.h" + +#include +#include + +#include +#include +#include +#include + +// Helper class to automatically generates a reverse lookup table at compile time, +// with `MAX_OUT` used as a sentinel to indicate that a value wasn't present +// in the original list. +template +class OneHotLookup { + size_t indices[NUM_IN]; +public: + constexpr OneHotLookup(const size_t list[MAX_OUT]) : indices() { + std::fill(indices, indices + NUM_IN, MAX_OUT); + for (size_t i = 0; i < MAX_OUT; ++i) { + indices[list[i]] = i; + } + } + constexpr size_t operator[](size_t i) const { + return (i < NUM_IN) ? indices[i] : MAX_OUT; + } +}; + +// This list of elements matches ATOM_LIST in older file graphium/features/nmp.py +constexpr size_t atomicNumList[] = { + 6 -1, // C + 7 -1, // N + 8 -1, // O + 16-1,// S + 9 -1, // F + 14-1,// Si + 15-1,// P + 17-1,// Cl + 35-1,// Br + 12-1,// Mg + 11-1,// Na + 20-1,// Ca + 26-1,// Fe + 33-1,// As + 13-1,// Al + 53-1,// I + 5 -1,// B + 23-1,// V + 19-1,// K + 81-1,// Tl + 70-1,// Yb + 51-1,// Sb + 50-1,// Sn + 47-1,// Ag + 46-1,// Pd + 27-1,// Co + 34-1,// Se + 22-1,// Ti + 30-1,// Zn + 1 -1,// H + 3 -1,// Li + 32-1,// Ge + 29-1,// Cu + 79-1,// Au + 28-1,// Ni + 48-1,// Cd + 49-1,// In + 25-1,// Mn + 40-1,// Zr + 24-1,// Cr + 78-1,// Pt + 80-1,// Hg + 82-1,// Pb +}; +constexpr size_t atomicNumCount = std::extent::value; +constexpr OneHotLookup<118, atomicNumCount> atomicNumLookup(atomicNumList); + +constexpr size_t degreeCount = 5; +constexpr size_t valenceCount = 7; + +// Reverse alphabetical order, excluding "OTHER", +// matching HYBRIDIZATION_LIST in older file graphium/features/nmp.py +constexpr size_t hybridizationList[] = { + RDKit::Atom::HybridizationType::UNSPECIFIED, + RDKit::Atom::HybridizationType::SP3D2, + RDKit::Atom::HybridizationType::SP3D, + RDKit::Atom::HybridizationType::SP3, + RDKit::Atom::HybridizationType::SP2D, + RDKit::Atom::HybridizationType::SP2, + RDKit::Atom::HybridizationType::SP, + RDKit::Atom::HybridizationType::S, +}; +constexpr size_t hybridizationCount = std::extent::value; +constexpr OneHotLookup<8, hybridizationCount> hybridizationLookup(hybridizationList); + +static const std::string chiralityRString("R"); + +enum ElementPhase { + GAS, + ARTIFICIAL, + LIQ, + SOLID +}; +// This table is from the Phase column of graphium/features/periodic_table.csv +constexpr ElementPhase atomicNumToPhase[] = { + GAS, GAS, + SOLID, SOLID, SOLID, SOLID, GAS, GAS, GAS, GAS, + SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, GAS, GAS, + SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, LIQ, GAS, + SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, ARTIFICIAL, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, GAS, + SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, ARTIFICIAL, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, LIQ, SOLID, SOLID, SOLID, SOLID, SOLID, GAS, + SOLID, SOLID, SOLID, SOLID, SOLID, SOLID, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, ARTIFICIAL, +}; +constexpr size_t phaseCount = 4; + +enum ElementType { + NOBLE_GAS, + ALKALI_METAL, + METAL, HALOGEN, + LANTHANIDE, + ALKALINE_EARTH_METAL, + TRANSITION_METAL, + ACTINIDE, + METALLOID, + NONE, + TRANSACTINIDE, + NONMETAL, + + NUM_ELEMENT_TYPES +}; +// This table is from the Type column of graphium/features/periodic_table.csv +constexpr ElementType atomicNumToType[] = { + NONMETAL, NOBLE_GAS, + ALKALI_METAL, ALKALINE_EARTH_METAL, METALLOID, NONMETAL, NONMETAL, NONMETAL, HALOGEN, NOBLE_GAS, + ALKALI_METAL, ALKALINE_EARTH_METAL, METAL, METALLOID, NONMETAL, NONMETAL, HALOGEN, NOBLE_GAS, + ALKALI_METAL, ALKALINE_EARTH_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, METAL, METALLOID, METALLOID, NONMETAL, HALOGEN, NOBLE_GAS, + ALKALI_METAL, ALKALINE_EARTH_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, METAL, METAL, METALLOID, METALLOID, HALOGEN, NOBLE_GAS, + ALKALI_METAL, ALKALINE_EARTH_METAL, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, LANTHANIDE, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, TRANSITION_METAL, METAL, METAL, METAL, METALLOID, NOBLE_GAS, + ALKALI_METAL, ALKALINE_EARTH_METAL, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, ACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, TRANSACTINIDE, NONE, TRANSACTINIDE, NONE, TRANSACTINIDE, NONE, NOBLE_GAS +}; +constexpr size_t typeCount = ElementType::NUM_ELEMENT_TYPES; + +// This matches BOND_TYPES in older file graphium/features/nmp.py +constexpr size_t bondTypeList[] = { + RDKit::Bond::BondType::SINGLE, + RDKit::Bond::BondType::DOUBLE, + RDKit::Bond::BondType::TRIPLE, + RDKit::Bond::BondType::AROMATIC, +}; +constexpr size_t bondTypeCount = std::extent::value; +constexpr OneHotLookup<22, bondTypeCount> bondTypeLookup(bondTypeList); + +// This matches BOND_STEREO in older file graphium/features/nmp.py +constexpr size_t bondStereoList[] = { + RDKit::Bond::BondStereo::STEREONONE, + RDKit::Bond::BondStereo::STEREOANY, + RDKit::Bond::BondStereo::STEREOZ, + RDKit::Bond::BondStereo::STEREOE, + RDKit::Bond::BondStereo::STEREOCIS, + RDKit::Bond::BondStereo::STEREOTRANS, +}; +constexpr size_t bondStereoCount = std::extent::value; +constexpr OneHotLookup<6, bondStereoCount> bondStereoLookup(bondStereoList); + +// Returns the number of values per atom, required by `feature` in `get_one_hot_atom_feature`'s +// `data` argument. +size_t get_one_hot_atom_feature_size(AtomOneHotFeature feature) { + switch (feature) { + case AtomOneHotFeature::ATOMIC_NUM: return atomicNumCount + 1; + case AtomOneHotFeature::DEGREE: return degreeCount + 1; + case AtomOneHotFeature::VALENCE: return valenceCount + 1; + case AtomOneHotFeature::IMPLICIT_VALENCE: return valenceCount + 1; + case AtomOneHotFeature::HYBRIDIZATION: return hybridizationCount + 1; + // "R", anything else ("S" or no value), bool for if other property present + case AtomOneHotFeature::CHIRALITY: return 3; + case AtomOneHotFeature::PHASE: return phaseCount + 1; + case AtomOneHotFeature::TYPE: return typeCount + 1; + case AtomOneHotFeature::GROUP: return groupCount + 1; + case AtomOneHotFeature::PERIOD: return periodCount + 1; + default: + // Missing implementation + assert(0); + return 0; + } +} + +// Fills in a particular atom `feature`'s one-hot encoding into `data`, for all atoms. +// See the declaration in one_hot.h for more details. +template +size_t get_one_hot_atom_feature(const GraphData& graph, T* data, AtomOneHotFeature feature, size_t stride) { + const size_t num_atoms = graph.num_atoms; + const RDKit::ROMol& mol = *graph.mol.get(); + const size_t feature_size = get_one_hot_atom_feature_size(feature); + const size_t total_feature_size = feature_size * num_atoms; + if (total_feature_size == 0) { + return feature_size; + } + { + T* current_data = data; + for (size_t i = 0; i < num_atoms; ++i) { + memset(current_data, 0, sizeof(data[0]) * feature_size); + current_data += stride; + } + } + switch (feature) { + case AtomOneHotFeature::ATOMIC_NUM: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + size_t atomicNum = graph.atoms[atomIndex].atomicNum; + data[atomicNumLookup[atomicNum-1]] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::DEGREE: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + auto degree = mol.getAtomWithIdx(atomIndex)->getDegree(); + size_t dataIndex = (degree < degreeCount) ? degree : degreeCount; + data[dataIndex] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::VALENCE: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + auto valence = mol.getAtomWithIdx(atomIndex)->getTotalValence(); + size_t dataIndex = (size_t(valence) < valenceCount) ? size_t(valence) : valenceCount; + data[dataIndex] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::IMPLICIT_VALENCE: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + auto valence = mol.getAtomWithIdx(atomIndex)->getImplicitValence(); + size_t dataIndex = (size_t(valence) < valenceCount) ? size_t(valence) : valenceCount; + data[dataIndex] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::HYBRIDIZATION: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + auto hybridization = mol.getAtomWithIdx(atomIndex)->getHybridization(); + data[hybridizationLookup[hybridization]] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::CHIRALITY: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + std::string chirality; + const RDKit::Atom* atom = mol.getAtomWithIdx(atomIndex); + bool isPresent = atom->getPropIfPresent(RDKit::common_properties::_CIPCode, chirality); + data[(isPresent && chirality == chiralityRString) ? 0 : 1] = FeatureValues::one; + if (atom->hasProp(RDKit::common_properties::_ChiralityPossible)) { + data[2] = FeatureValues::one; + } + } + return feature_size; + case AtomOneHotFeature::PHASE: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + size_t atomicNum = graph.atoms[atomIndex].atomicNum; + size_t dataIndex = phaseCount; + if (atomicNum - 1 < std::extent::value) { + ElementPhase phase = atomicNumToPhase[atomicNum - 1]; + // Group numbers are 1-based, but the array indices aren't. + dataIndex = phase - 1; + } + data[dataIndex] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::TYPE: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + size_t atomicNum = graph.atoms[atomIndex].atomicNum; + size_t dataIndex = typeCount; + if (atomicNum - 1 < std::extent::value) { + ElementType type = atomicNumToType[atomicNum - 1]; + // Group numbers are 1-based, but the array indices aren't. + dataIndex = type - 1; + } + data[dataIndex] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::GROUP: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + size_t atomicNum = graph.atoms[atomIndex].atomicNum; + size_t dataIndex = groupCount; + if (atomicNum - 1 < std::extent::value) { + uint8_t group = atomicNumToGroupTable[atomicNum - 1]; + // Group numbers are 1-based, but the array indices aren't. + dataIndex = group - 1; + } + data[dataIndex] = FeatureValues::one; + } + return feature_size; + case AtomOneHotFeature::PERIOD: + for (size_t atomIndex = 0; atomIndex < num_atoms; ++atomIndex, data += stride) { + size_t atomicNum = graph.atoms[atomIndex].atomicNum; + size_t dataIndex = periodCount; + if (atomicNum - 1 < std::extent::value) { + uint8_t period = atomicNumToPeriodTable[atomicNum - 1]; + // Period numbers are 1-based, but the array indices aren't. + dataIndex = period - 1; + } + data[dataIndex] = FeatureValues::one; + } + return feature_size; + default: + // Missing implementation + assert(0); + return feature_size; + } +} + +// Explicit instantiations, so that the function can be templated +// but still be used from other cpp files. +template size_t get_one_hot_atom_feature(const GraphData& graph, int16_t* data, AtomOneHotFeature feature, size_t stride); +template size_t get_one_hot_atom_feature(const GraphData& graph, float* data, AtomOneHotFeature feature, size_t stride); +template size_t get_one_hot_atom_feature(const GraphData& graph, double* data, AtomOneHotFeature feature, size_t stride); + + +// Returns the number of values per bond, required by `feature` in `get_one_hot_bond_feature`'s +// `data` argument. +size_t get_one_hot_bond_feature_size(BondFeature feature) { + switch (feature) { + case BondFeature::TYPE_ONE_HOT: return bondTypeCount + 1; + case BondFeature::STEREO_ONE_HOT: return bondStereoCount + 1; + default: + break; + } + // Missing implementation + assert(0); + return 0; +} + +// Fills in a particular bond `feature`'s one-hot encoding into `data`, for all bonds. +// See the declaration in one_hot.h for more details. +template +size_t get_one_hot_bond_feature(const GraphData& graph, T* data, BondFeature feature, size_t stride) { + const size_t num_bonds = graph.num_bonds; + const size_t feature_size = get_one_hot_bond_feature_size(feature); + const size_t total_feature_size = feature_size * num_bonds; + if (total_feature_size == 0) { + return 0; + } + { + T* current_data = data; + for (size_t i = 0; i < num_bonds; ++i) { + memset(current_data, 0, sizeof(data[0]) * feature_size); + current_data += stride; + } + } + switch (feature) { + case BondFeature::TYPE_ONE_HOT: + for (size_t i = 0; i < num_bonds; ++i, data += stride) { + auto type = graph.bonds[i].bondType; + data[bondTypeLookup[type]] = FeatureValues::one; + } + return feature_size; + case BondFeature::STEREO_ONE_HOT: + for (size_t i = 0; i < num_bonds; ++i, data += stride) { + auto stereo = graph.bonds[i].stereo; + data[bondStereoLookup[stereo]] = FeatureValues::one; + } + return feature_size; + default: + // Missing implementation + assert(0); + return feature_size; + } +} + +// Explicit instantiations, so that the function can be templated +// but still be used from other cpp files. +template size_t get_one_hot_bond_feature(const GraphData& graph, int16_t* data, BondFeature feature, size_t stride); +template size_t get_one_hot_bond_feature(const GraphData& graph, float* data, BondFeature feature, size_t stride); +template size_t get_one_hot_bond_feature(const GraphData& graph, double* data, BondFeature feature, size_t stride); diff --git a/graphium/graphium_cpp/one_hot.h b/graphium/graphium_cpp/one_hot.h new file mode 100644 index 000000000..adc02333f --- /dev/null +++ b/graphium/graphium_cpp/one_hot.h @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This header file declares functions for one-hot atom and bond features, +//! defined in one_hot.cpp and called from features.cpp + +#pragma once + +#include "features.h" + +#include + +#include + + +//! Returns the number of values per atom, required by `feature` in `get_one_hot_atom_feature`'s +//! `data` argument. Implementation is in one_hot.cpp +size_t get_one_hot_atom_feature_size(AtomOneHotFeature feature); + +//! Fills in a particular atom `feature`'s one-hot encoding into `data`, for all atoms. +//! Template type `T` can be `int16_t` (FP16), `float`, or `double`. +//! Implementation is in one_hot.cpp +//! +//! @param graph Molecule containing the source data to one-hot encode +//! @param data Destination array, pointing to the first atom's one-hot values for this +//! feature to be filled in. Each atom's data for this feature is +//! `get_one_hot_atom_feature_size(feature)` values long, but because different +//! features are interleaved, the beginnings of the data for each atom are spaced +//! `stride` values apart, which will be greater if there are other features. +//! @param feature The atom feature to one-hot encode (i.e. all zeros except a single one +//! whose index represents the feature value) into `data` +//! @param stride The number of values from the beginning of one atom's data to the beginning +//! of the next atom's data, which may include values for other features +//! @return The number of values per atom, i.e. `get_one_hot_atom_feature_size(feature)` +//! @see AtomOneHotFeature +//! @see get_one_hot_atom_feature_size +template +size_t get_one_hot_atom_feature(const GraphData& graph, T* data, AtomOneHotFeature feature, size_t stride); + +// Instantiation declarations of `get_one_hot_atom_feature` for `int16_t` (FP16), +// `float` (FP32), and `double` (FP64). The explicit instantiations are in one_hot.cpp +extern template size_t get_one_hot_atom_feature(const GraphData& graph, int16_t* data, AtomOneHotFeature feature, size_t stride); +extern template size_t get_one_hot_atom_feature(const GraphData& graph, float* data, AtomOneHotFeature feature, size_t stride); +extern template size_t get_one_hot_atom_feature(const GraphData& graph, double* data, AtomOneHotFeature feature, size_t stride); + +//! Returns the number of values required by `feature` in `get_one_hot_bond_feature`'s +//! `data` argument. Implementation is in one_hot.cpp +size_t get_one_hot_bond_feature_size(BondFeature feature); + +//! Fills in a particular bond `feature`'s one-hot encoding into `data`, for all bonds. +//! Template type `T` can be `int16_t` (FP16), `float`, or `double`. +//! Implementation is in one_hot.cpp +//! +//! @param graph Molecule containing the source data to one-hot encode +//! @param data Destination array, pointing to the first bond's one-hot values for this +//! feature to be filled in. Each bond's data for this feature is +//! `get_one_hot_bond_feature_size(feature)` values long, but because different +//! features are interleaved, the beginnings of the data for each bond are spaced +//! `stride` values apart, which will be greater if there are other features. +//! @param feature The bond feature to one-hot encode (i.e. all zeros except a single one +//! whose index represents the feature value) into `data` +//! @param stride The number of values from the beginning of one bond's data to the beginning +//! of the next bond's data, which may include values for other features +//! @return The number of values per bond, i.e. `get_one_hot_bond_feature_size(feature)` +//! @see BondFeature +//! @see get_one_hot_bond_feature_size +template +size_t get_one_hot_bond_feature(const GraphData& graph, T* data, BondFeature feature, size_t stride); + +// Instantiation declarations of `get_one_hot_bond_feature` for `int16_t` (FP16), +// `float` (FP32), and `double` (FP64). The explicit instantiations are in one_hot.cpp +extern template size_t get_one_hot_bond_feature(const GraphData& graph, int16_t* data, BondFeature feature, size_t stride); +extern template size_t get_one_hot_bond_feature(const GraphData& graph, float* data, BondFeature feature, size_t stride); +extern template size_t get_one_hot_bond_feature(const GraphData& graph, double* data, BondFeature feature, size_t stride); + diff --git a/graphium/graphium_cpp/pybind11 b/graphium/graphium_cpp/pybind11 new file mode 160000 index 000000000..ccefee4c3 --- /dev/null +++ b/graphium/graphium_cpp/pybind11 @@ -0,0 +1 @@ +Subproject commit ccefee4c3187c2892fcf4590b1bbc850134b84bb diff --git a/graphium/graphium_cpp/random_walk.cpp b/graphium/graphium_cpp/random_walk.cpp new file mode 100644 index 000000000..6da82ed48 --- /dev/null +++ b/graphium/graphium_cpp/random_walk.cpp @@ -0,0 +1,146 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This file defines and instantiates the `compute_rwse` function, +//! declared in random_walk.h and called from features.cpp + +#include "random_walk.h" + +#include +#include +#include +#include +#include + +//! Multiplies the dense `n` by `n` matrix `in_matrix` by the sparse `n` by `n` matrix in CSC +//! format (transpose of CSR format) represented by `neighbor_starts`, `neighbors`, and +//! `col_major_weights`, writing the results into `out_matrix`. +template +void multiply_dense_by_sparse(uint32_t n, T* out_matrix, const T* in_matrix, const uint32_t* neighbor_starts, const uint32_t* neighbors, const T* col_major_weights) { + for (uint32_t row = 0; row < n; ++row) { + T* out_row_start = out_matrix + row * n; + const T* in_row_start = in_matrix + row * n; + for (uint32_t col = 0; col < n; ++col) { + T sum = T(0); + // The adjacency is symmetric, so rows and cols are swappable there, + // but the weights might not be, so for fast access, we want column major weights. + const uint32_t* neighbors_start = neighbors + neighbor_starts[col]; + const uint32_t* neighbors_end = neighbors + neighbor_starts[col+1]; + const T* weights_start = col_major_weights + neighbor_starts[col]; + for (; neighbors_start != neighbors_end; ++neighbors_start, ++weights_start) { + sum += *weights_start * in_row_start[*neighbors_start]; + } + out_row_start[col] = sum; + } + } +} + +// Computes random walk data about the graph, either probabilities or transfer amounts +// after certain numbers of steps, outputting the values to `output`. +// See the declaration in random_walk.h for more details. +template +void compute_rwse( + const uint32_t num_powers, + const uint64_t* powers, + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + RandomWalkDataOption option, + std::vector& output, + int space_dim) { + + // Cast one n to size_t to avoid integer overflow if n >= 65536 + if (option == RandomWalkDataOption::PROBABILITIES) { + output.resize(num_powers * size_t(n)); + } + else { + output.resize(num_powers * size_t(n) * n); + } + + if (num_powers == 0) { + return; + } + if (n == 1) { + // Special case: All ones for single node, matching original code + for (uint32_t i = 0; i < output.size(); ++i) { + output[i] = T(1); + } + return; + } + + // Initialize this to represent column major D^-1 * adj + std::vector col_major_weights; + col_major_weights.resize(neighbor_starts[n]); + for (uint32_t col = 0, i = 0; col < n; ++col) { + const uint32_t* neighbor_start = neighbors + neighbor_starts[col]; + const uint32_t* neighbor_end = neighbors + neighbor_starts[col+1]; + for (; neighbor_start != neighbor_end; ++neighbor_start, ++i) { + const uint32_t neighbor = *neighbor_start; + uint32_t neighbor_degree = neighbor_starts[neighbor + 1] - neighbor_starts[neighbor]; + T degree_inv = (neighbor_degree == 0) ? T(0) : T(1) / T(neighbor_degree); + col_major_weights[i] = degree_inv; + } + } + + // Space for 2 matrices, to alternate between them + std::vector matrix; + matrix.resize(2 * size_t(n) * n, T(0)); + T* matrix0 = matrix.data(); + T* matrix1 = matrix.data() + size_t(n) * n; + uint64_t current_power = 0; + // Initialize current matrix to identity matrix + for (size_t i = 0, diag_index = 0; i < n; ++i, diag_index += (n+1)) { + matrix0[diag_index] = T(1); + } + + for (uint32_t power_index = 0; power_index < num_powers; ++power_index) { + const uint64_t target_power = powers[power_index]; + assert(target_power >= current_power); + while (target_power > current_power) { + std::swap(matrix0, matrix1); + multiply_dense_by_sparse(n, matrix0, matrix1, neighbor_starts, neighbors, col_major_weights.data()); + ++current_power; + } + + // Copy results to output + if (option == RandomWalkDataOption::PROBABILITIES) { + const T scale_factor = (space_dim == 0) ? T(1) : T(std::pow(T(target_power), T(0.5) * T(space_dim))); + // Just copy the diagonal values + for (size_t i = 0, diag_index = 0; i < n; ++i, diag_index += (n + 1)) { + output[i * num_powers + power_index] = scale_factor * matrix0[diag_index]; + } + } + else { + // Copy transition probabilities, making sure the dimensions are correct, because matrix0 isn't symmetric. + // Least significant dimension is num_powers + // Middle dimension is the columns across a single row of matrix0 + // Most significant dimension is the rows of the matrix0 + const size_t row_stride = num_powers * size_t(n); + for (size_t row = 0, i = 0; row < n; ++row) { + for (size_t col = 0; col < n; ++col, ++i) { + output[row * row_stride + col * num_powers + power_index] = matrix0[i]; + } + } + } + } +} + +// Explicit instantiations of `compute_rwse` for `float` and `double` +template void compute_rwse( + const uint32_t num_powers, + const uint64_t* powers, + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + RandomWalkDataOption option, + std::vector& output, + int space_dim); +template void compute_rwse( + const uint32_t num_powers, + const uint64_t* powers, + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + RandomWalkDataOption option, + std::vector& output, + int space_dim); diff --git a/graphium/graphium_cpp/random_walk.h b/graphium/graphium_cpp/random_walk.h new file mode 100644 index 000000000..8f604a507 --- /dev/null +++ b/graphium/graphium_cpp/random_walk.h @@ -0,0 +1,70 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This header file declares the `compute_rwse` function, +//! defined in random_walk.cpp and called from features.cpp + +#pragma once + +#include +#include + +//! Options for the `option` parameter of `compute_rwse` function +enum class RandomWalkDataOption { + PROBABILITIES, + MATRIX +}; + +//! Computes random walk data about the graph, either probabilities or transfer amounts +//! after certain numbers of steps, outputting the values to `output`. +//! Template type `T` can be `float` or `double`. Implementation is in random_walk.cpp +//! +//! The adjacency (neighbor_starts and neighbors) must be symmetric. +//! +//! @param num_powers The length of `powers` +//! @param powers Array of `num_powers` integers, with each one indicating how many steps at +//! which to output the probabilities or transfer amounts. +//! This *must* be in increasing order. +//! @param n Number of nodes +//! @param neighbor_starts Array of `n+1` indices into `neighbors`, indicating where each node's +//! neighbors start, plus one at the end to indicate the full length of +//! `neighbors` +//! @param neighbors Concatenated array of all neighbors of all nodes, in order +//! @param option Whether to output `n` probabilities for each power or a matrix of `n^2` +//! transfer amounts for each power +//! @param output Array of values to be filled with `n * num_powers` probabilities or +//! `n^2 * num_powers` transfer amounts, as if this is a 2D array +//! `[n][num_powers]` or 3D array `[rows][cols][num_powers]`, respectively +//! @param space_dim Optional parameter to scale probabilities by `power^(0.5*space_dim)`. +//! Default of zero corresponds with not scaling the probabilities. +template +void compute_rwse( + const uint32_t num_powers, + const uint64_t* powers, + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + RandomWalkDataOption option, + std::vector& output, + int space_dim = 0); + +// Instantiation declarations of `compute_rwse` for `float` and `double` +// The explicit instantiations are in random_walk.cpp +extern template void compute_rwse( + const uint32_t num_powers, + const uint64_t* powers, + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + RandomWalkDataOption option, + std::vector& output, + int space_dim); +extern template void compute_rwse( + const uint32_t num_powers, + const uint64_t* powers, + const uint32_t n, + const uint32_t* neighbor_starts, + const uint32_t* neighbors, + RandomWalkDataOption option, + std::vector& output, + int space_dim); diff --git a/graphium/graphium_cpp/setup.py b/graphium/graphium_cpp/setup.py new file mode 100755 index 000000000..c1fb1e3fb --- /dev/null +++ b/graphium/graphium_cpp/setup.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Setup script that builds graphium_cpp. +At time of writing, this has only been tested with GCC 10.5.0. +To build, git clone pybind11 into this directory, then run: +rm -r build/* +export PYTHONPATH=$PYTHONPATH:./pybind11 +pip install . +""" + +from distutils.core import setup +from pybind11.setup_helpers import Pybind11Extension, build_ext +import torch, rdkit, os +import numpy + +torch_dir = torch.__path__[0] +rdkit_lib_index = rdkit.__path__[0].split("/").index("lib") +rdkit_prefix = "/".join(rdkit.__path__[0].split("/")[:rdkit_lib_index]) + +ext_modules = [ + Pybind11Extension( + "graphium_cpp", + sources=[ + "graphium_cpp.cpp", + "features.cpp", + "labels.cpp", + "commute.cpp", + "electrostatic.cpp", + "float_features.cpp", + "graphormer.cpp", + "one_hot.cpp", + "random_walk.cpp", + "spectral.cpp", + ], + language="c++", + cxx_std=20, + include_dirs=[ + os.path.join(torch_dir, "include"), + os.path.join(torch_dir, "include/torch/csrc/api/include"), + os.path.join(rdkit_prefix, "include/rdkit"), + os.path.join(rdkit_prefix, "include/boost"), + numpy.get_include(), + ], + libraries=[ + "RDKitAlignment", + "RDKitDataStructs", + "RDKitDistGeometry", + "RDKitDistGeomHelpers", + "RDKitEigenSolvers", + "RDKitForceField", + "RDKitForceFieldHelpers", + "RDKitGenericGroups", + "RDKitGraphMol", + "RDKitInchi", + "RDKitRDInchiLib", + "RDKitRDBoost", + "RDKitRDGeneral", + "RDKitRDGeometryLib", + "RDKitRingDecomposerLib", + "RDKitSmilesParse", + "RDKitSubstructMatch", + "torch_cpu", + "torch_python", + ], + library_dirs=[os.path.join(rdkit_prefix, "lib"), os.path.join(torch_dir, "lib")], + extra_compile_args=[ + "-O3", + "-Wall", + "-Wmissing-field-initializers", + "-Wmaybe-uninitialized", + "-Wuninitialized", + ], + ) +] + +setup( + name="graphium_cpp", + version="0.1", + author="N. Dickson", + author_email="ndickson@nvidia.com", + license="Apache 2.0", + description="C++ extension for graphium", + ext_modules=ext_modules, + cmdclass={"build_ext": build_ext}, +) diff --git a/graphium/graphium_cpp/spectral.cpp b/graphium/graphium_cpp/spectral.cpp new file mode 100644 index 000000000..60cc7ee07 --- /dev/null +++ b/graphium/graphium_cpp/spectral.cpp @@ -0,0 +1,346 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This file defines and instantiates the `compute_laplacian_eigendecomp` +//! and `find_components` functions, declared in spectral.h +//! and called from features.cpp + +#include "spectral.h" + +#include +#include +#include +#include + +#include "features.h" +#include + +// Finds all connected components of the graph, assigning nodes to components, +// outputting to `components` and returning the number of components. +// See the declaration in spectral.h for more details. +size_t find_components( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + std::vector& components) { + + int32_t num_components = (n <= 1) ? 1 : 0; + std::vector queue; + if (n > 1) { + // First, find which nodes are in which component. + components.resize(n, -1); + queue.reserve(n); + for (uint32_t starti = 0; starti < n; ++starti) { + if (components[starti] >= 0) { + continue; + } + const int32_t component = num_components; + ++num_components; + queue.push_back(starti); + components[starti] = component; + while (queue.size() != 0) { + uint32_t current = queue[queue.size()-1]; + queue.resize(queue.size()-1); + const uint32_t* neighbor_begin = neighbors + row_starts[current]; + const uint32_t* neighbor_end = neighbors + row_starts[current+1]; + for ( ; neighbor_begin != neighbor_end; ++neighbor_begin) { + uint32_t neighbor = *neighbor_begin; + if (neighbor > starti && components[neighbor] < 0) { + components[neighbor] = component; + queue.push_back(neighbor); + } + } + } + } + } + return size_t(num_components); +} + +// Computes the eigendecomposition of the graph Laplacian matrix for a single +// connected component of the graph. Called from compute_laplacian_eigendecomp. +template +void compute_laplacian_eigendecomp_single(const uint32_t n, LaplacianData& data, Normalization normalization) { + T* matrix = data.matrix_temp.data(); + std::unique_ptr matrix_alloc(new T[n * n]); + std::copy(matrix, matrix + n * n, matrix_alloc.get()); + + int64_t dims[2] = { n, n }; + at::Tensor torch_matrix = torch_tensor_from_array(std::move(matrix_alloc), dims, 2, c10::ScalarType::Double); + + // Using linalg_eigh should ensure we get all real eigenvalues and eigenvectors. + // Arbitrarily choose lower-triangular portion (L) + auto tuple = at::linalg_eigh(torch_matrix, c10::string_view("L",1)); + at::Tensor eigenvalue_tensor = std::move(std::get<0>(tuple)); + at::Tensor eigenvector_tensor = std::move(std::get<1>(tuple)); + assert(eigenvalue_tensor.ndimension() == 1); + assert(eigenvector_tensor.ndimension() == 2); + assert(eigenvalue_tensor.size(0) == n); + assert(eigenvector_tensor.size(0) == n); + assert(eigenvector_tensor.size(1) == n); + + // Copy eigenvectors first, because normalization values are in eigenvalues_temp + data.vectors.clear(); + data.vectors.resize(size_t(n) * n, 0); + T* vectors = data.vectors.data(); + if (eigenvector_tensor.scalar_type() == c10::ScalarType::Double) { + const double* const eigenvector_data = eigenvector_tensor.data_ptr(); + for (size_t i = 0; i < size_t(n) * n; ++i) { + vectors[i] = T(eigenvector_data[i]); + } + + if (normalization == Normalization::INVERSE) { + // Convert symmetric case eigenvectors to asymmetric case eigenvectors + + // Scale each row by the factor in eigenvalues_temp + for (size_t row = 0, i = 0; row < n; ++row) { + const T factor = data.eigenvalues_temp[row]; + for (size_t col = 0; col < n; ++col, ++i) { + vectors[i] *= factor; + } + + // Clear to zero for the summing below + data.eigenvalues_temp[row] = 0; + } + + // Find each column length + for (size_t row = 0, i = 0; row < n; ++row) { + for (size_t col = 0; col < n; ++col, ++i) { + const T v = vectors[i]; + data.eigenvalues_temp[col] += v*v; + } + } + for (size_t col = 0; col < n; ++col) { + data.eigenvalues_temp[col] = T(1)/std::sqrt(data.eigenvalues_temp[col]); + } + + // Normalize each column + for (size_t row = 0, i = 0; row < n; ++row) { + for (size_t col = 0; col < n; ++col, ++i) { + vectors[i] *= data.eigenvalues_temp[col]; + } + } + } + } + else { + assert(0); + } + + // Copy eigenvalues + data.eigenvalues_temp.resize(n); + if (eigenvalue_tensor.scalar_type() == c10::ScalarType::Double) { + const double* const eigenvalue_data = eigenvalue_tensor.data_ptr(); + for (size_t i = 0; i < n; ++i) { + // No adjustment needed to eigenvalues between symmetric and asymmetric + data.eigenvalues_temp[i] = T(eigenvalue_data[i]); + } + } + else { + assert(0); + } + + // Find the sorted order of the eigenvalues + data.order_temp.resize(n); + std::iota(data.order_temp.begin(), data.order_temp.end(), 0); + std::stable_sort(data.order_temp.begin(), data.order_temp.end(), + [&data](uint32_t i, uint32_t j) -> bool { + return data.eigenvalues_temp[i] < data.eigenvalues_temp[j]; + } + ); + + // Copy the eigenvalues into the sorted order + data.eigenvalues.resize(n); + for (size_t i = 0; i < n; ++i) { + data.eigenvalues[i] = data.eigenvalues_temp[data.order_temp[i]]; + } + + // Copy the eigenvectors into the sorted order + std::swap(data.matrix_temp, data.vectors); + for (size_t row = 0, i = 0; row < n; ++row) { + const size_t source_row = data.order_temp[row]; + const size_t source_row_start = source_row * n; + for (size_t col = 0; col < n; ++col, ++i) { + data.vectors[i] = data.matrix_temp[source_row_start + col]; + } + } +} + +// Computes the eigendecomposition of the graph Laplacian matrix, outputting to +// data.eigenvalues and data.vectors. +// See the declaration in spectral.h for more details. +template +void compute_laplacian_eigendecomp( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + Normalization normalization, + LaplacianData& data, + size_t num_components, + const int32_t* components, + const T* weights) { + + // Compute the weight row sums, if applicable, for the diagonal of the laplacian + if (weights != nullptr) { + data.eigenvalues_temp.clear(); + data.eigenvalues_temp.resize(n, 0); + for (uint32_t i = 0; i < n; ++i) { + const T* weights_begin = weights + row_starts[i]; + const T* weights_end = weights + row_starts[i + 1]; + T sum = T(0); + for (; weights_begin != weights_end; ++weights_begin) { + sum += *weights_begin; + } + data.eigenvalues_temp[i] = sum; + } + } + data.normalization = normalization; + + // Prepare the laplacian matrix of the graph + data.matrix_temp.clear(); + data.matrix_temp.resize(size_t(n) * n, 0); + T* matrix = data.matrix_temp.data(); + if (normalization == Normalization::NONE) { + for (uint32_t i = 0, outi = 0; i < n; ++i, outi += n) { + const uint32_t* neighbor_begin = neighbors + row_starts[i]; + const uint32_t* neighbor_end = neighbors + row_starts[i + 1]; + if (weights == nullptr) { + const uint32_t degree = row_starts[i + 1] - row_starts[i]; + matrix[outi + i] = T(degree); + for (; neighbor_begin < neighbor_end; ++neighbor_begin) { + uint32_t neighbor = *neighbor_begin; + matrix[outi + neighbor] = T(-1); + } + } + else { + matrix[outi + i] = data.eigenvalues_temp[i]; + const T* weights_begin = weights + row_starts[i]; + for (; neighbor_begin < neighbor_end; ++neighbor_begin, ++weights_begin) { + uint32_t neighbor = *neighbor_begin; + matrix[outi + neighbor] = -(*weights_begin); + } + } + } + } + else { + // The diagonalization of the asymmetric normalization can be computed from the + // diagonalization of the symmetric normalization, which is faster, so always use symmetric. + + // Find the normalization factor for each node (row or col) + // These values in eigenvalues_temp are also used inside compute_laplacian_eigendecomp_single + for (uint32_t node = 0; node < n; ++node) { + const uint32_t row_degree = row_starts[node + 1] - row_starts[node]; + const T denominator = (weights == nullptr) ? T(row_degree) : data.eigenvalues_temp[node]; + data.eigenvalues_temp[node] = T(1) / std::sqrt(denominator); + } + + for (uint32_t i = 0, outi = 0; i < n; ++i, outi += n) { + const uint32_t* neighbor_begin = neighbors + row_starts[i]; + const uint32_t* neighbor_end = neighbors + row_starts[i + 1]; + if (neighbor_begin == neighbor_end) { + continue; + } + + // Diagonal is always exactly 1 when normalized (after skipping zero-degree nodes) + matrix[outi + i] = T(1); + + const T row_factor = data.eigenvalues_temp[i]; + for (; neighbor_begin < neighbor_end; ++neighbor_begin) { + uint32_t neighbor = *neighbor_begin; + const T col_factor = data.eigenvalues_temp[neighbor]; + matrix[outi + neighbor] = -row_factor * col_factor; + } + } + } + + if (num_components == 1 || components == nullptr) { + compute_laplacian_eigendecomp_single(n, data, normalization); + return; + } + + // There are multiple components. + // To match the original code, handle them separately and + // pack them into the output. + + // data.eigenvalues is length n for the single component case, + // but to be able to handle this, it needs to be larger, so go with n by n + data.eigenvalues.clear(); + data.eigenvalues.resize(size_t(n) * n, 0); + data.vectors.clear(); + data.vectors.resize(size_t(n) * n, 0); + + LaplacianData sub_data; + std::vector queue; + for (int32_t component = 0; component < num_components; ++component) { + // Reuse queue for the indices + queue.resize(0); + for (uint32_t i = 0; i < n; ++i) { + if (components[i] == component) { + queue.push_back(i); + } + } + + // Extract the sub-matrix + const uint32_t sub_n = queue.size(); + sub_data.matrix_temp.resize(size_t(sub_n) * sub_n); + T* sub_matrix = sub_data.matrix_temp.data(); + for (uint32_t row_index = 0; row_index < sub_n; ++row_index) { + const uint32_t row = queue[row_index]; + const T*const source_row = matrix + row*size_t(n); + for (uint32_t col_index = 0; col_index < sub_n; ++col_index) { + const uint32_t col = queue[col_index]; + *sub_matrix = source_row[col]; + ++sub_matrix; + } + } + + // Find its eigenvalues and eigenvectors + compute_laplacian_eigendecomp_single(sub_n, sub_data, normalization); + + // Copy the eigenvalues to the output. The excess is already zeroed out. + // Unlike the eigenvectors, below, might as well switch to using columns + // for the eigenvalues, because the caller can handle this case more + // easily with the single component case this way. + for (uint32_t row_index = 0; row_index < sub_n; ++row_index) { + const uint32_t row = queue[row_index]; + T*const dest_row = data.eigenvalues.data() + row*size_t(n); + for (uint32_t col_index = 0; col_index < sub_n; ++col_index) { + // Destination data within the row is left justified, + // NOT distributed based on the component. + dest_row[col_index] = sub_data.eigenvalues[col_index]; + } + } + + // Copy the (row) eigenvectors to the output. The excess is already zeroed out. + // The caller changes them to column eigenvectors. + for (uint32_t row_index = 0; row_index < sub_n; ++row_index) { + // Destination data is top-aligned, NOT distributed + // based on the component. + T*const dest_row = data.vectors.data() + row_index*size_t(n); + const T*const source_row = sub_data.vectors.data() + row_index*size_t(sub_n); + for (uint32_t col_index = 0; col_index < sub_n; ++col_index) { + // Columns ARE distributed based on the component. + const uint32_t col = queue[col_index]; + dest_row[col] = source_row[col_index]; + } + } + } +} + +// Explicit instantiations of `compute_laplacian_eigendecomp` for `float` and `double` +template void compute_laplacian_eigendecomp( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + Normalization normalization, + LaplacianData& data, + size_t num_components, + const int32_t* components, + const float* weights); +template void compute_laplacian_eigendecomp( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + Normalization normalization, + LaplacianData& data, + size_t num_components, + const int32_t* components, + const double* weights); diff --git a/graphium/graphium_cpp/spectral.h b/graphium/graphium_cpp/spectral.h new file mode 100644 index 000000000..03628fd6f --- /dev/null +++ b/graphium/graphium_cpp/spectral.h @@ -0,0 +1,103 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! @file This header file declares the `compute_laplacian_eigendecomp` +//! and `find_components` functions, defined in spectral.cpp +//! and called from features.cpp + +#pragma once + +#include "features.h" + +#include +#include + +//! Structure for caching eigendecomposition of the graph Laplacian matrix +template +struct LaplacianData { + //! Normalization of the previous eigendecomposition, if computed + Normalization normalization; + + //! Output/cached eigenvectors of the decomposition, if computed + std::vector vectors; + //! Output/cached eigenvalues of the decomposition, if computed + std::vector eigenvalues; + + //! Temporary arrays used during decomposition + std::vector matrix_temp; + std::vector eigenvalues_temp; + std::vector order_temp; +}; + +//! Finds all connected components of the graph, assigning nodes to components. +//! +//! @param n Number of nodes +//! @param row_starts Array of `n+1` indices into `neighbors`, indicating where each node's +//! neighbors start, plus one at the end to indicate the full length of +//! `neighbors` +//! @param neighbors Concatenated array of all neighbors of all nodes, in order +//! @param components Output array to assign each node an integer indicating which +//! component it's in, in the range `[0, num_components)`. Unused if +//! `n < 2`. +//! @return The number of separate connected components found +size_t find_components( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + std::vector& components); + +//! Computes the eigendecomposition of the graph Laplacian matrix. +//! This outputs the eigenvalues in `data.eigenvalues` and the eigenvectors in `data.vectors`. +//! The Laplacian matrix is the positive semi-definite matrix `L = D - adj`, where `D` is the +//! diagonal matrix of node degrees, and `adj` is the adjacency matrix. +//! If `normalization` is not `Normalization::NONE`, `L_s = (D^-0.5) L (D^-0.5)` is +//! diagonalized, instead of `L`, and if it's `Normalization::INVERSE`, this is used to compute +//! the decomposition of `L_i = (D^-1) L = (D^-0.5) L_s (D^0.5)`. +//! Template type `T` can be `float` or `double`. Implementation is in spectral.cpp +//! +//! @param n Number of nodes +//! @param row_starts Array of `n+1` indices into `neighbors`, indicating where each node's +//! neighbors start, plus one at the end to indicate the full length of +//! `neighbors` +//! @param neighbors Concatenated array of all neighbors of all nodes, in order +//! @param normalization Whether and how to normalize the Laplacian matrix before diagonalizing. +//! This is recorded into `data.normalization` for caching use. +//! @param data Output and temporary arrays for the eigendecomposition of the graph Laplacian +//! matrix +//! @param num_components The number of connected components, if separately diagonalizing +//! each component, else 1 to diagonalize the entire graph as a whole +//! @param components Optional array of length `n`, where each node's integer indicates which +//! component it is in, in the range `[0, num_components)` +//! @param weights Optional array of edge weights, in the order corresponding with neighbors. +//! If null, the edge weights are all 1. +template +void compute_laplacian_eigendecomp( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + Normalization normalization, + LaplacianData& data, + size_t num_components, + const int32_t* components, + const T* weights = nullptr); + +// Instantiation declarations of `compute_laplacian_eigendecomp` for `float` and `double` +// The explicit instantiations are in spectral.cpp +extern template void compute_laplacian_eigendecomp( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + Normalization normalization, + LaplacianData& data, + size_t num_components, + const int32_t* components, + const float* weights); +extern template void compute_laplacian_eigendecomp( + const uint32_t n, + const uint32_t* row_starts, + const uint32_t* neighbors, + Normalization normalization, + LaplacianData& data, + size_t num_components, + const int32_t* components, + const double* weights); diff --git a/graphium/ipu/README.md b/graphium/ipu/README.md deleted file mode 100644 index 3f592d7a8..000000000 --- a/graphium/ipu/README.md +++ /dev/null @@ -1,15 +0,0 @@ -
- -

The Graph Of LIfe Library.

-
- - -## What is in this folder? - -code for IPU acceleration support - -- `ipu_dataloader.py`: code for handling dataloader on IPU -- `ipu_losses.py`: code for computing losses on IPU -- `ipu_simple_lightning.py`: code for pytorch lightning support on IPU -- `ipu_utils.py`: utils functions for IPU -- `ipu_wrapper.py`: wrapper code for IPU support \ No newline at end of file diff --git a/graphium/ipu/ipu_dataloader.py b/graphium/ipu/ipu_dataloader.py deleted file mode 100644 index 5aa7828f4..000000000 --- a/graphium/ipu/ipu_dataloader.py +++ /dev/null @@ -1,434 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Callable, Iterable, Optional, List, Tuple, Dict, Any, Union -from copy import deepcopy -from dataclasses import dataclass -import numpy as np -from loguru import logger -from torch import Tensor - -import torch -from torch_geometric.data import Data, Batch, Dataset -from torch_geometric.transforms import BaseTransform - -from graphium.data.utils import get_keys -from graphium.ipu.ipu_utils import import_poptorch -from graphium.utils.packing import ( - fast_packing, - hybrid_packing, - get_pack_sizes, - node_to_pack_indices_mask, - estimate_max_pack_node_size, -) - - -@dataclass -class IPUDataloaderOptions: - r""" - This data class stores the arguments necessary to instantiate a model for the Predictor. - - Parameters: - model_class: - pytorch module used to create a model - - model_kwargs: - Key-word arguments used to initialize the model from `model_class`. - """ - - batch_size: int - max_num_nodes: Optional[int] = None - max_num_nodes_per_graph: Optional[int] = None - max_num_edges: Optional[int] = None - max_num_edges_per_graph: Optional[int] = None - mode: "poptorch.DataLoaderMode" = "Sync" - - def set_kwargs(self): - # Get the maximum number of nodes - if self.max_num_nodes is not None: - assert ( - self.max_num_nodes_per_graph is None - ), "Cannot use `max_num_nodes` and `max_num_nodes_per_graph` simultaneously" - elif self.max_num_nodes_per_graph is not None: - assert ( - self.max_num_nodes is None - ), "Cannot use `max_num_nodes` and `max_num_nodes_per_graph` simultaneously" - self.max_num_nodes = self.max_num_nodes_per_graph * self.batch_size - else: - raise ValueError("Must provide either `max_num_nodes` or `max_num_nodes_per_graph`") - - # Get the maximum number of edges - if self.max_num_edges is not None: - assert ( - self.max_num_edges_per_graph is None - ), "Cannot use `max_num_edges` and `max_num_edges_per_graph` simultaneously" - elif self.max_num_edges_per_graph is not None: - assert ( - self.max_num_edges is None - ), "Cannot use `max_num_edges` and `max_num_edges_per_graph` simultaneously" - self.max_num_edges = self.max_num_edges_per_graph * self.batch_size - else: - raise ValueError("Must provide either `max_num_nodes` or `max_num_nodes_per_graph`") - - # poptorch mode - poptorch = import_poptorch() - if isinstance(self.mode, str): - if self.mode.lower() == "sync": - self.mode = poptorch.DataLoaderMode.Sync - elif self.mode.lower() == "async": - self.mode = poptorch.DataLoaderMode.Async - elif self.mode.lower() == "asyncrebatched": - self.mode = poptorch.DataLoaderMode.AsyncRebatched - else: - raise ValueError(f"`{self.mode}` not a valid parameter.") - - -class CombinedBatchingCollator: - """ - Collator object that manages the combined batch size defined as: - - combined_batch_size = batch_size * device_iterations - * replication_factor * gradient_accumulation - - This is intended to be used in combination with the poptorch.DataLoader - """ - - def __init__( - self, - batch_size: int, - max_num_nodes: int, - max_num_edges: int, - dataset_max_nodes_per_graph: int, - dataset_max_edges_per_graph: int, - collate_fn: Optional[Callable] = None, - ): - """ - Parameters: - batch_size: mini batch size used by the model - max_num_nodes: Maximum number of nodes in the batched padded graph - max_num_edges: Maximum number of edges in the batched padded graph - dataset_max_nodes_per_graph: Maximum number of nodes per graph in the full dataset - dataset_max_edges_per_graph: Maximum number of edges per graph in the full dataset - collate_fn: Function used to collate (or batch) the single data or graphs together - """ - super().__init__() - self.batch_size = batch_size - self.collate_fn = collate_fn - self.max_num_nodes = max_num_nodes - self.max_num_edges = max_num_edges - self.dataset_max_nodes_per_graph = dataset_max_nodes_per_graph - self.dataset_max_edges_per_graph = dataset_max_edges_per_graph - - def __call__( - self, batch: List[Dict[str, Union[Data, Dict[str, Tensor]]]] - ) -> Dict[str, Union[Batch, Dict[str, Tensor], Any]]: - """ - Stack tensors, batch the pyg graphs, and pad each tensor to be same size. - - Parameters: - batch: The batch of data, including pyg-graphs `Data` and labels `Dict[str, Tensor]` to be padded - - Returns: - out_batch: A dictionary where the graphs are batched and the labels or other Tensors are stacked - """ - - # Sort the batch such that large graphs are paired with small graphs - num_nodes = [b["features"].num_nodes for b in batch] - packed_indices = hybrid_packing(num_nodes, batch_size=self.batch_size) - packs = [[batch[idx] for idx in pack] for pack in packed_indices] - - # Loop all mini-batches within the global batch - all_batches = [] - for pack in packs: - if self.collate_fn != None: - local_batch = self.collate_fn(pack) - - transform = Pad( - max_num_nodes=self.max_num_nodes, - max_num_edges=self.max_num_edges, - dataset_max_nodes_per_graph=self.dataset_max_nodes_per_graph, - dataset_max_edges_per_graph=self.dataset_max_edges_per_graph, - ) - - local_batch["features"] = transform(local_batch["features"]) - local_batch["labels"] = transform(local_batch["labels"]) - all_batches.append(local_batch) - - out_batch = {} - - # Stack tensors in the first dimension to allow IPUs to differentiate between local and global graph - all_keys = get_keys(all_batches[0]["labels"]) - out_batch["labels"] = { - key: torch.stack([this_batch["labels"][key] for this_batch in all_batches], 0) for key in all_keys - } - out_graphs = [this_batch["features"] for this_batch in all_batches] - stacked_features = deepcopy(out_graphs[0]) - for key, val in out_graphs[0].items(): - if isinstance(val, torch.Tensor): - stacked_features[key] = torch.stack([this_graph[key] for this_graph in out_graphs], dim=0) - - out_batch["features"] = stacked_features - for key in all_batches[0].keys(): - if key not in ("features", "labels"): - out_batch[key] = [this_batch[key] for this_batch in all_batches] - - # - for data_key, data_val in out_batch.items(): - if isinstance(data_val, Batch): - for sub_key, sub_val in data_val.items(): - if isinstance(sub_val, Tensor) and sub_val.dtype == torch.int64: - out_batch[data_key][sub_key] = sub_val.to(torch.int32) - - return out_batch - - -def create_ipu_dataloader( - dataset: Dataset, - ipu_dataloader_options: IPUDataloaderOptions, - ipu_options: Optional["poptorch.Options"] = None, - batch_size: Optional[int] = 1, - collate_fn=None, - num_workers: Optional[int] = 0, - **kwargs, -) -> "poptorch.DataLoader": - """ - Creates a poptorch.DataLoader for graph datasets - Applies the mini-batching method of concatenating multiple graphs into a - single graph with multiple disconnected subgraphs. See: - https://pytorch-geometric.readthedocs.io/en/2.0.2/notes/batching.html - - Parameters: - - dataset: The torch_geometric.data.Dataset instance from which to - load the graph examples for the IPU. - ipu_dataloader_options: The options to initialize the Dataloader for IPU - ipu_options: The poptorch.Options used by the - poptorch.DataLoader. Will use the default options if not provided. - batch_size: How many graph examples to load in each batch - (default: 1). - collate_fn: The function used to collate batches - **kwargs (optional): Additional arguments of :class:`poptorch.DataLoader`. - - Returns: - The dataloader - """ - poptorch = import_poptorch() - - if ipu_options is None: - # Create IPU default options - ipu_options = poptorch.Options() - - # Define the collater function - collater = CombinedBatchingCollator( - batch_size, - collate_fn=collate_fn, - max_num_nodes=ipu_dataloader_options.max_num_nodes, - max_num_edges=ipu_dataloader_options.max_num_edges, - dataset_max_nodes_per_graph=dataset.max_num_nodes_per_graph, - dataset_max_edges_per_graph=dataset.max_num_edges_per_graph, - ) - - # Get the global batch size - num_nodes = np.asarray(dataset.num_nodes_list) - accum = ipu_options.Training.gradient_accumulation - repli = ipu_options._values["replication_factor"] - device_iter = ipu_options._values["device_iterations"] - combined_batch_size = batch_size * accum * repli * device_iter - num_batches = len(dataset) // combined_batch_size - num_workers = min(num_batches, num_workers) - buffer_size = num_batches // num_workers if num_workers > 0 else None - buffer_size = 3 if buffer_size is None else buffer_size - async_options = { - "sharing_strategy": poptorch.SharingStrategy.ForkServer, - "early_preload": True, - "buffer_size": buffer_size, - "load_indefinitely": True, - "miss_sleep_time_in_ms": 0, - } - - # Estimate the packing size needed - max_pack_size, max_pack_size_per_graph = 0, 0 - for _ in range(4): - this_max_pack_size, this_max_pack_size_per_graph = estimate_max_pack_node_size( - num_nodes=num_nodes, - batch_size=batch_size, - combined_batch_size=combined_batch_size, - ) - max_pack_size = max(max_pack_size, this_max_pack_size) - max_pack_size_per_graph = max(max_pack_size_per_graph, this_max_pack_size_per_graph) - - max_num_nodes = collater.max_num_nodes - # Log the estimated pack size, with warnings if too big or too small - logger.info( - f"Estimating pack max_pack_size={max_pack_size} or max_pack_size_per_graph={max_pack_size_per_graph}" - ) - logger.info(f"Provided `max_num_nodes={max_num_nodes}`") - if max_pack_size > max_num_nodes - 10: - logger.warning( - f"The value of `max_num_nodes={max_num_nodes}` seems to be insufficient compared to `max_pack_size={max_pack_size}` and will likely crash" - ) - elif max_pack_size < max_num_nodes - 20: - logger.warning( - f"The value of `max_num_nodes={max_num_nodes}` seems to be large compared to `max_pack_size={max_pack_size}` and will likely waste memory" - ) - - return poptorch.DataLoader( - options=deepcopy(ipu_options), - dataset=dataset, - batch_size=batch_size, - num_workers=num_workers, - collate_fn=collater, - async_options=async_options, - **kwargs, - ) - - -class Pad(BaseTransform): - """ - Data transform that applies padding to enforce consistent tensor shapes. - """ - - def __init__( - self, - max_num_nodes: int, - dataset_max_nodes_per_graph, - dataset_max_edges_per_graph, - max_num_edges: Optional[int] = None, - node_value: float = 0, - edge_value: float = 0, - ): - """ - Parameters: - max_num_nodes: The maximum number of nodes for the total padded graph - dataset_max_nodes_per_graph: the maximum number of nodes per graph in the dataset - dataset_max_edges_per_graph: the maximum number of edges per graph in the dataset - max_num_edges: The maximum number of edges for the total padded graph - node_value: Value to add to the node padding - edge_value: Value to add to the edge padding - """ - super().__init__() - self.max_num_nodes = max_num_nodes - self.dataset_max_nodes_per_graph = dataset_max_nodes_per_graph - self.dataset_max_edges_per_graph = dataset_max_edges_per_graph - - if max_num_edges: - self.max_num_edges = max_num_edges - else: - # Assume fully connected graph - self.max_num_edges = max_num_nodes * (max_num_nodes - 1) - - self.node_value = node_value - self.edge_value = edge_value - - def validate(self, data): - """ - Validates that the input graph does not exceed the constraints that: - - * the number of nodes must be <= max_num_nodes - * the number of edges must be <= max_num_edges - - Returns: - Tuple containing the number nodes and the number of edges - """ - num_nodes = data.num_nodes - num_edges = data.num_edges - - assert num_nodes <= self.max_num_nodes, ( - f"Too many nodes. Graph has {num_nodes} nodes " f"and max_num_nodes is {self.max_num_nodes}." - ) - - assert num_edges <= self.max_num_edges, ( - f"Too many edges. Graph has {num_edges} edges defined " - f"and max_num_edges is {self.max_num_edges}." - ) - - return num_nodes, num_edges - - def __call__(self, batch: Batch) -> Batch: - return self._call(batch) - - def forward(self, batch: Batch) -> Batch: - return self._call(batch) - - def _call(self, batch: Batch) -> Batch: - """ - Pad the batch with a fake graphs that has the desired - number of nodes and edges. - """ - num_nodes, num_edges = self.validate(batch) - num_pad_nodes = self.max_num_nodes - num_nodes - num_pad_edges = self.max_num_edges - num_edges - # Create a copy to update with padded features - new_batch = deepcopy(batch) - - real_graphs = new_batch.to_data_list() - - for g in real_graphs: - g.graph_is_true = torch.tensor([1], dtype=bool) - g.node_is_true = torch.full([g.num_nodes], True, dtype=bool) - g.edge_is_true = torch.full([g.num_edges], True, dtype=bool) - - # create fake graph with the needed # of nodes and edges - fake = Data() - fake.num_nodes = num_pad_nodes - fake.num_edges = num_pad_edges - fake.graph_is_true = torch.tensor([False], dtype=bool) - fake.node_is_true = torch.full([num_pad_nodes], False, dtype=bool) - fake.edge_is_true = torch.full([num_pad_edges], False, dtype=bool) - - for key, value in real_graphs[0]: - if not torch.is_tensor(value): - continue - - if key == "graph_is_true" or key == "node_is_true" or key == "edge_is_true": - continue - - dim = real_graphs[0].__cat_dim__(key, value) - pad_shape = list(value.shape) - - if batch.is_node_attr(key): - pad_shape[dim] = num_pad_nodes - pad_value = self.node_value - elif batch.is_edge_attr(key): - pad_shape[dim] = num_pad_edges - if key == "edge_index": - # Padding edges are self-loops on the first padding node - pad_value = 0 - else: - pad_value = self.edge_value - # identify graph attributes, pad nan label for the fake graph - elif key.startswith("graph_"): - num_pad_graphs = 1 # we pad with one big fake graph - pad_shape[dim] = num_pad_graphs - pad_value = float("nan") - else: - continue - - pad_value = value.new_full(pad_shape, pad_value) - fake[key] = torch.cat([pad_value], dim=dim) - real_graphs.append(fake) - new_batch = Batch.from_data_list(real_graphs) - - if "num_nodes" in new_batch: - new_batch.num_nodes = self.max_num_nodes - - return new_batch - - def __repr__(self) -> str: - s = f"{self.__class__.__name__}(" - s += f"max_num_nodes={self.max_num_nodes}, " - s += f"max_num_edges={self.max_num_edges}, " - s += f"node_value={self.node_value}, " - s += f"edge_value={self.edge_value})" - return s diff --git a/graphium/ipu/ipu_losses.py b/graphium/ipu/ipu_losses.py deleted file mode 100644 index 6bc434ae4..000000000 --- a/graphium/ipu/ipu_losses.py +++ /dev/null @@ -1,196 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -import torch -from torch import Tensor -from torch.nn import BCELoss, BCEWithLogitsLoss, MSELoss, L1Loss -from torch._C import _infer_size -from loguru import logger -from graphium.trainer.losses import HybridCELoss - - -class BCEWithLogitsLossIPU(BCEWithLogitsLoss): - """ - A modified version of the `torch.nn.BCEWithLogitsLoss` that can ignore NaNs - by giving them a weight of `0`. This allows it to work with compilation - and IPUs since it doesn't modify the tensor's shape. - """ - - def forward(self, input: Tensor, target: Tensor) -> Tensor: - prev_weight = None - - target = target.clone().to(input.dtype) - weight = self.weight - - # Get the original weight matrix. If None, set all weights = 1 - if weight is not None: - prev_weight = self.weight.clone() - new_size = _infer_size(target.size(), weight.size()) - weight = weight.expand(new_size).clone() - else: - weight = torch.ones(target.shape, dtype=input.dtype, device=input.device) - - # Replace the nan-targets by 0 or 1. Take the value closest to the input. - # Give a weight of 0 where there are nan-targets - nan_targets = target.isnan() - nan_targets_0 = (input < 0.5) & nan_targets - nan_targets_1 = (input >= 0.5) & nan_targets - target[nan_targets_0] = 0.0 - target[nan_targets_1] = 1.0 - weight[nan_targets] = 0.0 - - # Compute the loss, and rescale by the number of nan elements - self.weight = weight - loss = super().forward(input, target) - - num_real_targets = (~nan_targets).sum() - factor1 = torch.where(num_real_targets > 0, 1, 0) - factor2 = torch.where(num_real_targets > 0, 0, 1) - loss = factor1 * loss * nan_targets.numel() / (num_real_targets + factor2) - - # Reset the self.weight to its original value - self.weight = prev_weight - - return loss - - -class BCELossIPU(BCELoss): - """ - A modified version of the `torch.nn.BCELoss` that can ignore NaNs - by giving them a weight of `0`. This allows it to work with compilation - and IPUs since it doesn't modify the tensor's shape. - """ - - def forward(self, input: Tensor, target: Tensor) -> Tensor: - prev_weight = None - - target = target.clone().to(input.dtype) - weight = self.weight - - # Get the original weight matrix. If None, set all weights = 1 - if weight is not None: - prev_weight = self.weight.clone() - new_size = _infer_size(target.size(), weight.size()) - weight = weight.expand(new_size).clone() - else: - weight = torch.ones(target.shape, dtype=input.dtype, device=input.device) - - # Replace the nan-targets by 0 or 1. Take the value closest to the input. - # Give a weight of 0 where there are nan-targets - nan_targets = target.isnan() - nan_targets_0 = (input < 0.5) & nan_targets - nan_targets_1 = (input >= 0.5) & nan_targets - target[nan_targets_0] = 0.0 - target[nan_targets_1] = 1.0 - weight[nan_targets] = 0.0 - - # Compute the loss, and rescale by the number of nan elements - self.weight = weight - loss = super().forward(input, target) - - num_real_targets = (~nan_targets).sum() - factor1 = torch.where(num_real_targets > 0, 1, 0) - factor2 = torch.where(num_real_targets > 0, 0, 1) - loss = factor1 * loss * nan_targets.numel() / (num_real_targets + factor2) - - # Reset the self.weight to its original value - self.weight = prev_weight - - return loss - - -class MSELossIPU(MSELoss): - """ - A modified version of the `torch.nn.MSELoss` that can ignore NaNs - by giving them the same value for both `input` and `target`. - This allows it to work with compilation - and IPUs since it doesn't modify the tensor's shape. - """ - - def forward(self, input: Tensor, target: Tensor) -> Tensor: - target = target.clone().to(input.dtype) - input = input.clone() - - # Replace the nan-targets in the input/target tensors by 0 - nan_targets = target.isnan() - input[nan_targets] = 0.0 - target[nan_targets] = 0.0 - - # Compute the loss, and rescale by the number of nan elements - loss = super().forward(input, target) - - num_real_targets = (~nan_targets).sum() - factor1 = torch.where(num_real_targets > 0, 1, 0) - factor2 = torch.where(num_real_targets > 0, 0, 1) - loss = factor1 * loss * nan_targets.numel() / (num_real_targets + factor2) - - return loss - - -class L1LossIPU(L1Loss): - """ - A modified version of the `torch.nn.L1Loss` that can ignore NaNs - by giving them the same value for both `input` and `target`. - This allows it to work with compilation - and IPUs since it doesn't modify the tensor's shape. - """ - - def forward(self, input: Tensor, target: Tensor) -> Tensor: - target = target.clone().to(input.dtype) - input = input.clone() - - # Replace the nan-targets in the input/target tensors by 0 - nan_targets = target.isnan() - input[nan_targets] = 0.0 - target[nan_targets] = 0.0 - - # Compute the loss, and rescale by the number of nan elements - loss = super().forward(input, target) - num_real_targets = (~nan_targets).sum() - factor1 = torch.where(num_real_targets > 0, 1, 0) - factor2 = torch.where(num_real_targets > 0, 0, 1) - loss = factor1 * loss * nan_targets.numel() / (num_real_targets + factor2) - - return loss - - -class HybridCELossIPU(HybridCELoss): - def __init__( - self, - n_brackets, - alpha: float = 0.5, - ) -> None: - """ - Parameters: - n_brackets: the number of brackets that will be used to group the regression targets. - Expected to have the same size as the number of classes in the transformed regression task. - """ - super().__init__(n_brackets=n_brackets, alpha=alpha) - - def forward(self, input: Tensor, target: Tensor) -> Tensor: - """ - Parameters: - input: (batch_size x n_classes) tensor of logits predicted for each bracket. - target: (batch_size) or (batch_size, 1) tensor of target brackets in {0, 1, ..., self.n_brackets}. - """ - - target = target.clone().to(input.dtype) - input = input.clone() - - # Replace the nan-targets in the input/target tensors by 0 - nan_targets = target.isnan() - - # Compute the loss, and rescale by the number of nan elements - loss = super().forward(input, target, nan_targets) - return loss diff --git a/graphium/ipu/ipu_metrics.py b/graphium/ipu/ipu_metrics.py deleted file mode 100644 index 9029d3e00..000000000 --- a/graphium/ipu/ipu_metrics.py +++ /dev/null @@ -1,907 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Optional, Tuple, Sequence, Literal - -import torch -from torch import BoolTensor, IntTensor, Tensor -from torchmetrics.functional import auroc, average_precision, pearson_corrcoef, r2_score -from torchmetrics.utilities.checks import _input_squeeze -from torchmetrics.functional.classification.accuracy import ( - _mode, - _check_subset_validity, - _accuracy_compute, - _accuracy_update, -) -from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute -from torchmetrics.functional.classification.f_beta import _fbeta_compute -from torchmetrics.functional import mean_squared_error, mean_absolute_error -from torchmetrics.utilities.checks import _input_squeeze -from torchmetrics.utilities.enums import AverageMethod - -from graphium.utils.tensor import nan_mean -from graphium.ipu.ipu_utils import import_poptorch - - -def auroc_ipu( - preds: Tensor, - target: Tensor, - num_classes: Optional[int] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, - pos_label: Optional[int] = None, - average: Optional[str] = "macro", - max_fpr: Optional[float] = None, - sample_weights: Optional[Sequence] = None, -): - """ - A modified version of the `torchmetrics.functional.auroc` that can ignore NaNs - by giving them the same value for both `preds` and `target`. - This allows it to work with compilation - and IPUs since it doesn't modify the tensor's shape. - """ - - target = target.clone() - preds = preds.clone() - - # Replace the nan-targets in the preds/target tensors by 0 - nan_targets = target.isnan() - preds[nan_targets] = 0.0 - target[nan_targets] = 0.0 - - # Get the original weight matrix. If None, set all weights = 1 - if sample_weights is None: - sample_weights = torch.ones(target.shape[0], dtype=preds.dtype, device=preds.device) - sample_weights[nan_targets] = 0.0 - - # Compute the loss, and rescale by the number of nan elements - score = auroc( - preds=preds, - target=target.to(int), - num_classes=num_classes, - task=task, - pos_label=pos_label, - average=average, - max_fpr=max_fpr, - sample_weights=sample_weights, - ) - - return score - - -def average_precision_ipu( - preds: Tensor, - target: Tensor, - num_classes: Optional[int] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, - ignore_index: Optional[int] = None, - pos_label: Optional[int] = None, - average: Optional[str] = "macro", - sample_weights: Optional[Sequence] = None, -): - """ - A modified version of the `torchmetrics.functional.average_precision` that can ignore NaNs - by giving them the same value for both `preds` and `target`. - This allows it to work with compilation - and IPUs since it doesn't modify the tensor's shape. - """ - - target = target.clone() - preds = preds.clone() - - # Replace the nan-targets in the preds/target tensors by 0 - # Average precision is not sensitive to true negatives - nan_targets = target.isnan() - preds[nan_targets] = 0.0 - target[nan_targets] = 0.0 - - # No need to use sample weights (which is no longer supported in torchmetrics >=0.10) - # # Get the original weight matrix. If None, set all weights = 1 - # if sample_weights is None: - # sample_weights = torch.ones(target.shape[0], dtype=preds.dtype, device=preds.device) - # sample_weights[nan_targets] = 0.0 - - # Compute the loss, and rescale by the number of nan elements - score = average_precision( - preds=preds, - target=target, - num_classes=num_classes, - task=task, - ignore_index=ignore_index, - pos_label=pos_label, - average=average, - # sample_weights=sample_weights, - ) - - return score - - -def precision_ipu( - preds: Tensor, - target: Tensor, - average: Optional[str] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - num_classes: Optional[int] = None, - threshold: float = 0.5, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, -): - """ - A modified version of the `torchmetrics.functional.precision` that can ignore NaNs - by giving them the same value for both `preds` and `target`. - This allows it to work with compilation - and IPUs since it doesn't modify the tensor's shape. - """ - - (tp, fp, tn, fn), mode = get_confusion_matrix( - preds=preds, - target=target, - average=average, - mdmc_average=mdmc_average, - threshold=threshold, - top_k=top_k, - subset_accuracy=False, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, - ) - - return _precision_compute(tp, fp, fn, average, mdmc_average) - - -def recall_ipu( - preds: Tensor, - target: Tensor, - average: Optional[str] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - num_classes: Optional[int] = None, - threshold: float = 0.5, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, -): - """ - A modified version of the `torchmetrics.functional.recall` that can ignore NaNs - by giving them the same value for both `preds` and `target`. - This allows it to work with compilation - and IPUs since it doesn't modify the tensor's shape. - """ - - (tp, fp, tn, fn), mode = get_confusion_matrix( - preds=preds, - target=target, - average=average, - mdmc_average=mdmc_average, - threshold=threshold, - top_k=top_k, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, - ) - - return _recall_compute(tp, fp, fn, average, mdmc_average) - - -def accuracy_ipu( - preds: Tensor, - target: Tensor, - average: Optional[str] = "micro", - mdmc_average: Optional[str] = "global", - threshold: float = 0.5, - top_k: Optional[int] = None, - subset_accuracy: bool = False, - num_classes: Optional[int] = None, - multiclass: Optional[bool] = None, - ignore_index: Optional[int] = None, -) -> Tensor: - """ - A modified version of the `torchmetrics.functional.accuracy` that can ignore NaNs - by giving them the same value for both `preds` and `target`. - This allows it to work with compilation - and IPUs since it doesn't modify the tensor's shape. - - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth labels - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, - the value for the class will be ``nan``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`pages/classification:input types`) - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - top_k: - Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class inputs. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - subset_accuracy: - Whether to compute subset accuracy for multi-label and multi-dimensional - multi-class inputs (has no effect for other input types). - - - For multi-label inputs, if the parameter is set to ``True``, then all labels for - each sample must be correctly predicted for the sample to count as correct. If it - is set to ``False``, then all labels are counted separately - this is equivalent to - flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``). - - - For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all - sub-sample (on the extra axis) must be correct for the sample to be counted as correct. - If it is set to ``False``, then all sub-samples are counter separately - this is equivalent, - in the case of label predictions, to flattening the inputs beforehand (i.e. - ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter - still applies in both cases, if set. - - Raises: - ValueError: - If ``top_k`` parameter is set for ``multi-label`` inputs. - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. - ValueError: - If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. - ValueError: - If ``average`` is set but ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set - and ``ignore_index`` is not in the range ``[0, num_classes)``. - ValueError: - If ``top_k`` is not an ``integer`` larger than ``0``. - """ - - (tp, fp, tn, fn), mode = get_confusion_matrix( - preds=preds, - target=target, - average=average, - mdmc_average=mdmc_average, - threshold=threshold, - top_k=top_k, - subset_accuracy=subset_accuracy, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, - ) - - return _accuracy_compute(tp, fp, tn, fn, average, mdmc_average, mode) - - -def get_confusion_matrix( - preds: Tensor, - target: Tensor, - average: Optional[str] = "micro", - mdmc_average: Optional[str] = "global", - threshold: float = 0.5, - top_k: Optional[int] = None, - subset_accuracy: bool = False, - num_classes: Optional[int] = None, - multiclass: Optional[bool] = None, - ignore_index: Optional[int] = None, -) -> Tuple[Tuple[Tensor], Tensor]: - """ - Calculates the confusion matrix according to the specified average method. - - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth labels - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, - the value for the class will be ``nan``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`pages/classification:input types`) - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - top_k: - Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class inputs. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - """ - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): - raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") - - allowed_mdmc_average = [None, "samplewise", "global"] - if mdmc_average not in allowed_mdmc_average: - raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") - - if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): - raise ValueError( - f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes" - ) - - if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): - raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") - - #### ADDED #### - # Put all the NaNs as the 0-class - nans = torch.isnan(target) - target[nans] = 0 - preds[nans] = 0 - if (preds.ndim > 1) and (preds.shape[1] > 1): - preds[nans, 0] = 1 - target = target.to(int) - #### END ADDED #### - - preds, target = _input_squeeze(preds, target) - mode = _mode(preds, target, threshold, top_k, num_classes, multiclass, ignore_index) - reduce = "macro" if average in ["weighted", "none", None] else average - - if subset_accuracy and _check_subset_validity(mode): - # correct, total = _subset_accuracy_update(preds, target, threshold, top_k, ignore_index) - # return _subset_accuracy_compute(correct, total) - raise NotImplementedError("subset_accuracy not implemented") - tp, fp, tn, fn = _accuracy_update( - preds, target, reduce, mdmc_average, threshold, num_classes, top_k, multiclass, ignore_index, mode - ) - - #### ADDED #### - num_nans = nans.sum(0) - if tp.numel() > 1: - tp[0] = tp[0] - num_nans - tn[1:] = tn[1:] - num_nans - else: - tn = tn - num_nans - if (preds.ndim > 1) and (preds.shape[1] > 1): - tp = tp - num_nans - #### END ADDED #### - - return (tp, fp, tn, fn), mode - - -class NaNTensor(Tensor): - """ - Class to create and manage a NaN tensor along it's properties - - The goal of the class is to override the regular tensor such that the basic - operations (sum, mean, max, etc) ignore the NaNs in the input. - It also supports NaNs in integer tensors (as the lowest integer possible). - """ - - @property - def get_nans(self) -> BoolTensor: - """ - Gets the boolean Tensor containing the location of NaNs. - In the case of an integer tensor, this returns where the tensor is equal to its minimal value - In the case of a boolean tensor, this returns a Tensor filled with `False` - """ - if self.is_floating_point(): - return self.isnan() - elif self.is_signed(): - return self == torch.iinfo(self.dtype).min - else: - return torch.zeros(self.shape, device=self.device, dtype=bool) - - def sum(self, *args, **kwargs) -> Tensor: - """ - Overloads the traditional sum to ignore the NaNs - """ - tensor = self.to(float) - tensor[self.get_nans] = float("nan") - if self.is_floating_point(): - dtype = self.dtype - else: - dtype = torch.int64 - return tensor.nansum(*args, **kwargs).to(dtype) - - def mean(self, *args, **kwargs) -> Tensor: - """ - Overloads the traditional mean to ignore the NaNs - """ - tensor = self.to(float) - tensor[self.get_nans] = float("nan") - return nan_mean(tensor, *args, **kwargs).to(self.dtype) - - def numel(self) -> int: - """ - Returns the number of non-NaN elements. - """ - return super(NaNTensor, ~self.get_nans).sum() - - def min(self, *args, **kwargs) -> Tensor: - """ - Returns the min vale of a tensor whitout NaNs - """ - tensor = self - tensor = tensor[~self.get_nans] - return super(NaNTensor, tensor).min(*args, **kwargs) - - def max(self, *args, **kwargs) -> Tensor: - """ - Returns the max vale of a tensor whitout NaNs - """ - tensor = self - tensor = tensor[~self.get_nans] - return super(NaNTensor, tensor).max(*args, **kwargs) - - def argsort(self, dim=-1, descending=False) -> IntTensor: - """ - Return the indices that sort the tensor, while putting all the NaNs to the end of the sorting. - """ - tensor = self - if descending: - tensor[tensor.get_nans] = float("-inf") - else: - tensor[tensor.get_nans] = float("inf") - return super(NaNTensor, tensor).argsort(dim=dim, descending=descending) - - def size(self, dim) -> Tensor: - """ - Instead of returning the size, return the number of non-NaN elements in - a specific dimension. Useful for the `r2_score` metric. - """ - return (~self.get_nans).sum(dim=dim) - - def __lt__(self, other) -> Tensor: - """ - Stupid fix that allows the code to work with `r2_score`, - since it requires the size to be > 2. But since `self.size` now returns - a Tensor instead of a value, we check that all elements are > 2. - """ - if (not isinstance(other, Tensor)) and (other == 2): - return super().__lt__(other).all() - else: - return super().__lt__(other) - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - """ - This __torch_function__ implementation wraps subclasses such that - methods called on subclasses return a subclass instance instead of - a ``torch.Tensor`` instance. - - One corollary to this is that you need coverage for torch.Tensor - methods if implementing __torch_function__ for subclasses. - - Affects the call torch.sum() as to behave the same way as NaNTensor.sum() - - We recommend always calling ``super().__torch_function__`` as the base - case when doing the above. - - While not mandatory, we recommend making `__torch_function__` a classmethod. - """ - if func.__name__ == "sum": - kwargs = {} if kwargs is None else kwargs - return args[0].sum(*args[1:], **kwargs) - else: - return super().__torch_function__(func, types, args=args, kwargs=kwargs) - - -def pearson_ipu(preds, target): - """Computes pearson correlation coefficient. - - Handles NaNs in the target without reshaping tensors in order to work on IPU. - - Args: - preds: estimated scores - target: ground truth scores - """ - preds = NaNTensor(preds) - target = NaNTensor(target) - preds[target.get_nans] = float("nan") - pearson = pearson_corrcoef(preds, target.to(preds.dtype)) - return Tensor(pearson) - - -def spearman_ipu(preds, target): - """Computes spearman rank correlation coefficient. - - Handles NaNs in the target without reshaping tensors in order to work on IPU. - - Args: - preds: estimated scores - target: ground truth scores - """ - nans = target.isnan() - dtype = preds.dtype - preds[nans] = float("inf") - target[nans] = float("inf") - preds_sort = _rank_data(preds).to(dtype=dtype) - target_sort = _rank_data(target).to(dtype=dtype) - target_sort[nans] = float("nan") - spearman = pearson_ipu(preds_sort, target_sort) - return Tensor(spearman) - - -def _rank_data(data: Tensor) -> Tensor: - """Calculate the rank for each element of a tensor. - - The rank refers to the indices of an element in the corresponding sorted tensor (starting from 1). - Duplicates of the same value will be assigned the mean of their rank. - - Adopted from `Rank of element tensor`_ - """ - n = data.numel() - rank = torch.empty_like(data) - idx = data.argsort() - rank[idx] = torch.arange(1, n + 1, dtype=data.dtype, device=data.device) - - # TODO: Repeats not yet supported - # repeats = _find_repeats(data) - # for r in repeats: - # condition = data == r - # rank[condition] = rank[condition].mean() - return rank - - -def r2_score_ipu(preds, target, *args, **kwargs) -> Tensor: - """ - Computes r2 score also known as `R2 Score_Coefficient Determination`_: - - .. math:: R^2 = 1 - \frac{SS_{res}}{SS_{tot}} - - where :math:`SS_{res}=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and - :math:`SS_{tot}=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate - adjusted r2 score given by - - .. math:: R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1} - - where the parameter :math:`k` (the number of independent regressors) should - be provided as the ``adjusted`` argument. - Handles NaNs without reshaping tensors in order to work on IPU. - - Args: - preds: estimated labels - target: ground truth labels - adjusted: number of independent regressors for calculating adjusted r2 score. - multioutput: Defines aggregation in the case of multiple output scores. Can be one of the following strings: - - * ``'raw_values'`` returns full set of scores - * ``'uniform_average'`` scores are uniformly averaged - * ``'variance_weighted'`` scores are weighted by their individual variances - """ - preds = NaNTensor(preds) - target = NaNTensor(target) - preds[target.get_nans] = float("nan") - score = r2_score(preds, target, *args, **kwargs) - return Tensor(score) - - -def fbeta_score_ipu( - preds: Tensor, - target: Tensor, - beta: float = 1.0, - average: Optional[str] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - num_classes: Optional[int] = None, - threshold: float = 0.5, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, -): - """ - A modified version of the `torchmetrics.functional.classification.f_beta._fbeta_compute` - that can ignore NaNs by giving them the same value for both `preds` and `target`. - This allows it to work with compilation - and IPUs since it doesn't modify the tensor's shape. - - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth labels - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, - the value for the class will be ``nan``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - (see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - (see :ref:`pages/classification:input types`) - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - top_k: - Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class inputs. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. See the parameter's - :ref:`documentation section ` - for a more detailed explanation and examples. - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - subset_accuracy: - Whether to compute subset accuracy for multi-label and multi-dimensional - multi-class inputs (has no effect for other input types). - - - For multi-label inputs, if the parameter is set to ``True``, then all labels for - each sample must be correctly predicted for the sample to count as correct. If it - is set to ``False``, then all labels are counted separately - this is equivalent to - flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``). - - - For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all - sub-sample (on the extra axis) must be correct for the sample to be counted as correct. - If it is set to ``False``, then all sub-samples are counter separately - this is equivalent, - in the case of label predictions, to flattening the inputs beforehand (i.e. - ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter - still applies in both cases, if set. - - Raises: - ValueError: - If ``top_k`` parameter is set for ``multi-label`` inputs. - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. - ValueError: - If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. - ValueError: - If ``average`` is set but ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set - and ``ignore_index`` is not in the range ``[0, num_classes)``. - ValueError: - If ``top_k`` is not an ``integer`` larger than ``0``. - """ - - (tp, fp, tn, fn), mode = get_confusion_matrix( - preds=preds, - target=target, - average=average, - mdmc_average=mdmc_average, - ignore_index=ignore_index, - num_classes=num_classes, - threshold=threshold, - top_k=top_k, - multiclass=multiclass, - ) - - b2 = beta**2 - fbeta = ((1 + b2) * tp) / ((1 + b2) * tp + b2 * fn + fp) - - if average in (None, "none", AverageMethod.NONE): - pass - elif average == AverageMethod.MICRO: - pass - elif average == AverageMethod.MACRO: - fbeta = fbeta.mean() - elif average == AverageMethod.WEIGHTED: - weights = tp + fn - fbeta = (weights * fbeta).sum() / weights.sum() - else: - raise ValueError( - f"`average={average}` not yet supported. Chose between None, Micro, Macro, or Weighted" - ) - - return fbeta - - -def f1_score_ipu( - preds: Tensor, - target: Tensor, - beta: float = 1.0, - average: Optional[str] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - num_classes: Optional[int] = None, - threshold: float = 0.5, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, -): - """ - A modified version of the `torchmetrics.functional.classification.f_beta._fbeta_compute` - that can ignore NaNs by giving them the same value for both `preds` and `target`. - Used to calculate the f1_score on IPU with beta parameter equal to 1.0 - This allows it to work with compilation and IPUs since it doesn't modify the tensor's shape. - - Computes f_beta metric from stat scores: true positives, false positives, true negatives, false negatives. - - Args: - tp: True positives - fp: False positives - tn: True negatives - fn: False negatives - beta: The parameter `beta` (which determines the weight of recall in the combined score) - ignore_index: Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method - average: Defines the reduction that is applied - mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter) - """ - - return fbeta_score_ipu( - preds, - target, - beta=beta, - average=average, - mdmc_average=mdmc_average, - ignore_index=ignore_index, - num_classes=num_classes, - threshold=threshold, - top_k=top_k, - multiclass=multiclass, - ) - - -def mean_squared_error_ipu(preds: Tensor, target: Tensor, squared: bool) -> Tensor: - """Computes mean squared error. - - Handles NaNs without reshaping tensors in order to work on IPU. - - Args: - preds: estimated labels - target: ground truth labels - squared: returns RMSE value if set to False - - Return: - Tensor with MSE - """ - target = target.clone() - preds = preds.clone() - - # Replace the nan-targets in the preds/target tensors by 0 - nan_targets = target.isnan() - preds[nan_targets] = 0.0 - target[nan_targets] = 0.0 - - # Compute the loss, and rescale by the number of nan elements - loss = mean_squared_error(preds, target, squared) - - if squared: - factor = nan_targets.numel() / ((~nan_targets).sum()) - else: - factor = (nan_targets.numel() / ((~nan_targets).sum())).sqrt() - - loss = loss * factor - - return loss - - -def mean_absolute_error_ipu(preds: Tensor, target: Tensor) -> Tensor: - """Computes mean absolute error. - - Handles NaNs without reshaping tensors in order to work on IPU. - - Args: - preds: estimated labels - target: ground truth labels - - Return: - Tensor with MAE - """ - target = target.clone() - preds = preds.clone() - - # Replace the nan-targets in the preds/target tensors by 0 - nan_targets = target.isnan() - preds[nan_targets] = 0.0 - target[nan_targets] = 0.0 - - # Compute the loss, and rescale by the number of nan elements - loss = mean_absolute_error(preds, target) - loss = loss * nan_targets.numel() / ((~nan_targets).sum()) - - return loss diff --git a/graphium/ipu/ipu_simple_lightning.py b/graphium/ipu/ipu_simple_lightning.py deleted file mode 100644 index b2fca086e..000000000 --- a/graphium/ipu/ipu_simple_lightning.py +++ /dev/null @@ -1,169 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -import lightning -from lightning_graphcore import IPUStrategy -from lightning.pytorch.loggers import WandbLogger - -import torch -from torch import nn - -import torchvision -import torchvision.transforms as transforms - -import mup - -from graphium.nn.base_layers import FCLayer -from graphium.utils.mup import set_base_shapes - - -ON_IPU = True # Change this line to run on CPU -SEED = 42 - - -# The simple PyTorch model used in each of these examples -class SimpleTorchModel(torch.nn.Module): - def __init__(self, in_dim, hidden_dim, kernel_size, num_classes): - super().__init__() - self.in_dim = in_dim - self.hidden_dim = hidden_dim - self.kernel_size = kernel_size - self.num_classes = num_classes - - conv_block = nn.Sequential( - nn.Conv2d(in_channels=in_dim, out_channels=hidden_dim, kernel_size=kernel_size), - nn.BatchNorm2d(hidden_dim), - nn.ReLU(), - nn.MaxPool2d(kernel_size), - nn.MaxPool2d(kernel_size), - ) - - self.the_network = nn.Sequential( - conv_block, - torch.nn.Flatten(), - FCLayer(4 * hidden_dim, hidden_dim), - FCLayer(hidden_dim, hidden_dim), - FCLayer(hidden_dim, num_classes, activation=None, is_readout_layer=True), - nn.LogSoftmax(1), - ) - - def make_mup_base_kwargs(self, divide_factor: float = 2.0): - return dict( - in_dim=self.in_dim, - hidden_dim=round(self.hidden_dim / divide_factor), - kernel_size=self.kernel_size, - num_classes=self.num_classes, - ) - - def forward(self, x): - return self.the_network(x) - - -# This class shows a minimal lightning example. This example uses our own -# SimpleTorchModel which is a basic 2 conv, 2 FC torch network. It can be -# found in simple_torch_model.py. -class SimpleLightning(lightning.LightningModule): - def __init__(self, in_dim, hidden_dim, kernel_size, num_classes, on_ipu): - super().__init__() - self.model = SimpleTorchModel( - in_dim=in_dim, hidden_dim=hidden_dim, kernel_size=kernel_size, num_classes=num_classes - ) - self.on_ipu = on_ipu - - def training_step(self, batch, _): - x, label = batch - prediction = self.model(x) - loss = torch.nn.functional.nll_loss(prediction, label) - return loss - - def validation_step(self, batch, _): - x, label = batch - prediction = self.model(x) - preds = torch.argmax(prediction, dim=1) - acc = torch.sum(preds == label).float() / len(label) - loss = torch.nn.functional.nll_loss(prediction, label) - return loss, acc - - # PopTorch doesn't currently support logging within steps. Use the Lightning - # callback hooks instead. - def on_train_batch_end(self, outputs, batch, batch_idx): - self.log("StepLoss", outputs["loss"]) - - def validation_epoch_end(self, outputs): - loss = [out[0] for out in outputs] - self.log("val_loss", torch.stack(loss).mean(), prog_bar=True) - - acc = [out[1] for out in outputs] - self.log("val_acc", torch.stack(acc).mean(), prog_bar=True) - - def configure_optimizers(self): - adam = torch.optim.Adam - - if self.on_ipu: - import poptorch - - adam = poptorch.optim.Adam - - optimizer = mup.MuAdam(self.parameters(), lr=0.01, impl=adam) - return optimizer - - -if __name__ == "__main__": - torch.manual_seed(SEED) - - # Create the model as usual. - predictor = SimpleLightning(in_dim=1, hidden_dim=32, kernel_size=3, num_classes=10, on_ipu=ON_IPU) - model = predictor.model - base = model.__class__(**model.make_mup_base_kwargs(divide_factor=2)) - predictor.model = set_base_shapes(model, base, rescale_params=False) - - torch.manual_seed(SEED) - # Normal PyTorch dataset. - train_set = torchvision.datasets.FashionMNIST( - "out/FashionMNIST", train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]) - ) - val_set = torchvision.datasets.FashionMNIST( - "out/FashionMNIST", train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]) - ) - - # Normal PyTorch dataloader. - train_loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True) - val_loader = torch.utils.data.DataLoader(val_set, batch_size=16, shuffle=False) - - torch.manual_seed(SEED) - - ipus = None - plugins = None - if ON_IPU: - import poptorch - - training_opts = poptorch.Options() - inference_opts = poptorch.Options() - - # Set the seeds - training_opts.randomSeed(SEED) - inference_opts.randomSeed(SEED) - ipus = 1 - strategy = IPUStrategy(training_opts=training_opts, inference_opts=inference_opts) - - trainer = lightning.Trainer( - logger=WandbLogger(), - ipus=ipus, - max_epochs=3, - log_every_n_steps=1, - plugins=plugins, - ) - - # When fit is called the model will be compiled for IPU and will run on the available IPU devices. - trainer.fit(predictor, train_dataloaders=train_loader, val_dataloaders=val_loader) diff --git a/graphium/ipu/ipu_utils.py b/graphium/ipu/ipu_utils.py deleted file mode 100644 index c5140ecb5..000000000 --- a/graphium/ipu/ipu_utils.py +++ /dev/null @@ -1,162 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -import os -import tempfile -from datetime import datetime -from copy import deepcopy -from types import ModuleType -from typing import Optional, Tuple, List -import torch - - -def import_poptorch(raise_error=True) -> Optional[ModuleType]: - """ - Import poptorch and returns it. - It is wrapped in a function to avoid breaking the code - for non-IPU devices which did not install poptorch. - - Parameters: - raise_error: Whether to raise an error if poptorch is unavailable. - If `False`, return `None` - - Returns: - The poptorch module - - """ - try: - import poptorch - - return poptorch - except ImportError as e: - if raise_error: - raise e - return - - -def is_running_on_ipu() -> bool: - """ - Returns whether the current module is running on ipu. - Needs to be used in the `forward` or `backward` pass. - """ - poptorch = import_poptorch(raise_error=False) - on_ipu = (poptorch is not None) and (poptorch.isRunningOnIpu()) - return on_ipu - - -def load_ipu_options( - ipu_opts: List[str], - seed: Optional[int] = None, - model_name: Optional[str] = None, - gradient_accumulation: Optional[int] = None, - precision: Optional[int] = None, - ipu_inference_opts: Optional[List[str]] = None, -) -> Tuple["poptorch.Options", "poptorch.Options"]: - """ - Load the IPU options from the config file. - - Parameters: - ipu_cfg: The list configurations for the IPU, written as a list of strings to make use of `poptorch.Options.loadFromFile` - - write a temporary config gile, and read it. See `Options.loadFromFile` - #? see the tutorial for IPU options here - # https://github.com/graphcore/tutorials/tree/sdk-release-2.6/tutorials/pytorch/efficient_data_loading - #? see the full documentation for ipu options here - # https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/reference.html?highlight=options#poptorch.Options - - ***minibatch size***: The number of samples processed by one simple fwd/bwd pass. - = # of samples in a minibatch - - ***device iterations***: A device iteration corresponds to one iteration of the training loop executed on the IPU, starting with data-loading and ending with a weight update. - In this simple case, when we set n deviceIterations, the host will prepare n mini-batches in an infeed queue so the IPU can perform efficiently n iterations. - = # of minibatches to be processed at a time - = # of training / backward pass in this call - - ***gradient accumulation factor***: After each backward pass the gradients are accumulated together for K mini-batches. set K in the argument - = # of minibatches to accumulate gradients from - - ***replication factor***: Replication describes the process of running multiple instances of the same model simultaneously on different IPUs to achieve data parallelism. - If the model requires N IPUs and the replication factor is M, N x M IPUs will be necessary. - = # of times the model is copied to speed up computation, each replica of the model is sent a different subset of the dataset - - ***global batch size***: In a single device iteration, many mini-batches may be processed and the resulting gradients accumulated. - We call this total number of samples processed for one optimiser step the global batch size. - = total number of samples processed for *one optimiser step* - = (minibatch size x Gradient accumulation factor) x Number of replicas - - seed: random seed for the IPU - model_name: Name of the model, to be used for ipu profiling - ipu_inference_opts: optional IPU configuration overrides for inference. - If this is provided, options in this file override those in `ipu_file` for inference. - - Returns: - - training_opts: IPU options for the training set. - - inference_opts: IPU options for inference. - It differs from the `training_opts` by enforcing `gradientAccumulation` to 1 - - """ - - poptorch = import_poptorch() - ipu_options = poptorch.Options() - ipu_opts_file = ipu_options_list_to_file(ipu_opts) - ipu_options.loadFromFile(ipu_opts_file.name) - ipu_opts_file.close() - - ipu_options.outputMode(poptorch.OutputMode.All) - if seed is not None: - ipu_options.randomSeed(seed) - if model_name is not None: - ipu_options.modelName(f"{model_name}_train") - if gradient_accumulation is not None: - current = ipu_options.Training.gradient_accumulation - assert (current == 1) or ( - current == gradient_accumulation - ), f"Received inconsistent gradient accumulation `{current}` and `{gradient_accumulation}" - ipu_options.Training.gradientAccumulation(gradient_accumulation) - - if precision == "16-true": - # IPUOptions.loadFromFile currently doesn't support setting half partials, doing it here - ipu_options.Precision.setPartialsType(torch.half) - training_opts = ipu_options - - # Change the inference options to remove gradient accumulation - inference_opts = deepcopy(ipu_options) - inference_opts.Training.gradientAccumulation(1) - if ipu_inference_opts is not None: - ipu_inference_opts_file = ipu_options_list_to_file(ipu_inference_opts) - inference_opts.loadFromFile(ipu_inference_opts_file.name) - ipu_inference_opts_file.close() - - return training_opts, inference_opts - - -def ipu_options_list_to_file(ipu_opts: Optional[List[str]]) -> tempfile._TemporaryFileWrapper: - """ - Create a temporary file from a list of ipu configs, such that it can be read by `poptorch.Options.loadFromFile` - - Parameters: - ipu_opts: The list configurations for the IPU, written as a list of strings to make use of `poptorch.Options.loadFromFile` - Returns: - tmp_file: The temporary file of ipu configs - """ - if ipu_opts is None: - return - - tmp_file = tempfile.NamedTemporaryFile("w", delete=True) - for s in ipu_opts: - tmp_file.write(s + "\n") - tmp_file.flush() - return tmp_file diff --git a/graphium/ipu/ipu_wrapper.py b/graphium/ipu/ipu_wrapper.py deleted file mode 100644 index 0ac04b883..000000000 --- a/graphium/ipu/ipu_wrapper.py +++ /dev/null @@ -1,235 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Dict, Any, Optional, Callable, Union, Type, Tuple, Iterable - -from torch_geometric.data import Batch -from torch import Tensor -from lightning_graphcore import IPUStrategy -from lightning.pytorch.utilities.types import STEP_OUTPUT -from lightning.pytorch.trainer.states import RunningStage - -from graphium.trainer.predictor import PredictorModule -from graphium.ipu.ipu_utils import import_poptorch - -import torch -from torch_geometric.data import Data, Batch -from torch_geometric.data.data import BaseData -from loguru import logger -import functools -import collections -from graphium.data.utils import get_keys - -poptorch = import_poptorch() - - -class PyGArgsParser(poptorch.ICustomArgParser): - """ - This class is responsible for converting a PyG Batch from and to - a tensor of tuples. This allows PyG Batch to be used as inputs to - IPU programs. Copied from poppyg repo, in the future import from - the repo directly. - """ - - @staticmethod - def sortedTensorKeys(struct: BaseData) -> Iterable[str]: - """ - Find all the keys that map to a tensor value in struct. The keys - are returned in sorted order. - """ - all_keys = sorted(get_keys(struct)) - - def isTensor(k: str) -> bool: - return isinstance(struct[k], torch.Tensor) - - return filter(isTensor, all_keys) - - def yieldTensors(self, struct: BaseData): - """ - yield every torch.Tensor in struct in sorted order - """ - for k in self.sortedTensorKeys(struct): - yield struct[k] - - def reconstruct(self, original_structure: BaseData, tensor_iterator: Iterable[Tensor]): - """ - Create a new instance with the same class type as the - original_structure. This new instance will be initialized with tensors - from the provided iterator and uses the same sorted keys from the - yieldTensors() implementation. - """ - tensor_keys = self.sortedTensorKeys(original_structure) - kwargs = {k: next(tensor_iterator) for k in tensor_keys} - - for k in get_keys(original_structure): - if k not in kwargs: - # copy non-tensor properties to the new instance - kwargs[k] = original_structure[k] - - cls = original_structure.__class__ - - if issubclass(cls, Batch): - kwargs["_base_cls"] = Data - return Batch(**kwargs) - - return cls(**kwargs) - - -# PyG uses the BaseData object as the root for data and batch objects -poptorch.registerCustomArgParser(BaseData, PyGArgsParser()) - - -class PredictorModuleIPU(PredictorModule): - """ - This class wraps around the `PredictorModule` to make it work with IPU and the `IPUPluginGraphium`. - """ - - def __init__(self, *args, **kwargs): - # Import poptorch in a safe way that will work when working with cpu/gpu - self.poptorch = import_poptorch() - super().__init__(*args, **kwargs) - - @staticmethod - def compute_loss( - preds: Dict[str, Tensor], - targets: Dict[str, Tensor], - weights: Optional[Tensor], - loss_fun: Dict[str, Callable], - target_nan_mask: Union[Type, str] = "ignore", - multitask_handling: Optional[str] = None, - ) -> Tuple[Tensor, Dict[str, Tensor]]: - return PredictorModule.compute_loss( - preds, targets, weights, loss_fun, target_nan_mask, multitask_handling - ) - - def on_train_batch_end(self, outputs, batch, batch_idx): - outputs = self.convert_from_fp16(outputs) - outputs["loss"] = outputs["loss"][outputs["loss"] != 0].mean() - super().on_train_batch_end(outputs, batch, batch_idx) - - def training_step(self, batch, batch_idx) -> Dict[str, Any]: - features, labels = batch["features"], batch["labels"] - features, labels = self.squeeze_input_dims(features, labels) - dict_input = {"features": features, "labels": labels} - step_dict = super().training_step(dict_input, to_cpu=False) - - loss = step_dict.pop("loss") - step_dict["loss"] = self.poptorch.identity_loss(loss, reduction="mean") - return step_dict - - def validation_step(self, batch, batch_idx) -> Dict[str, Any]: - features, labels = batch["features"], batch["labels"] - features, labels = self.squeeze_input_dims(features, labels) - dict_input = {"features": features, "labels": labels} - step_dict = super().validation_step(dict_input, to_cpu=False) - - return step_dict - - def test_step(self, batch, batch_idx) -> Dict[str, Any]: - # Build a dictionary from the tuples - features, labels = batch["features"], batch["labels"] - features, labels = self.squeeze_input_dims(features, labels) - dict_input = {"features": features, "labels": labels} - step_dict = super().test_step(dict_input, to_cpu=False) - - return step_dict - - def predict_step(self, **inputs) -> Dict[str, Any]: - # Build a dictionary from the tuples - dict_input = inputs - step_dict = super().predict_step(dict_input, to_cpu=False) - - return step_dict - - def on_validation_batch_end( - self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0 - ) -> None: - # convert data that will be tracked - outputs = self.convert_from_fp16(outputs) - super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx) - - def evaluation_epoch_end(self, outputs: Any): - outputs = self.convert_from_fp16(outputs) - super().evaluation_epoch_end(outputs) - - def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: - outputs = self.convert_from_fp16(outputs) - super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx) - - def configure_optimizers(self, impl=None): - if impl is None: - dtype = self.precision_to_dtype(self.trainer.precision) - impl = functools.partial( - self.poptorch.optim.Adam, - accum_type=dtype, - first_order_momentum_accum_type=dtype, - second_order_momentum_accum_type=torch.float, - ) - return super().configure_optimizers(impl=impl) - - def squeeze_input_dims(self, features, labels): - for key, tensor in features: - if isinstance(tensor, torch.Tensor): - features[key] = features[key].squeeze(0) - - for key in labels: - labels[key] = labels[key].squeeze(0) - - return features, labels - - def convert_from_fp16(self, data: Any) -> Any: - """ - Converts tensors from FP16 to FP32. Useful to convert the IPU program output data - """ - if isinstance(data, collections.Sequence): - for idx in range(len(data)): - data[idx] = self.convert_from_fp16(data[idx]) - elif isinstance(data, collections.Mapping): - for key in data: - data[key] = self.convert_from_fp16(data[key]) - elif isinstance(data, torch.Tensor) and data.dtype == torch.float16: - data = data.float() - return data - - def _convert_features_dtype(self, feats): - """ - Converts features to trainer precision rather than model precision. - Necessary to run IPU on FP16. - """ - dtype = self.precision_to_dtype(self.trainer.precision) - - # Convert features to dtype - if isinstance(feats, torch.Tensor): - feats = feats.to(dtype) - elif isinstance(feats, (Data, Batch, dict)): - for key, val in feats.items(): - if isinstance(val, torch.Tensor) and (val.is_floating_point()): - feats[key] = val.to(dtype=dtype) - else: - raise ValueError(f"Unsupported feats type `{type(feats)}` : {feats}") - return feats - - def precision_to_dtype(self, precision): - return torch.half if precision == "16-true" else torch.float - - def get_num_graphs(self, data: Batch): - """ - IPU specific method to compute the number of graphs in a Batch, - that considers gradient accumulation, multiple IPUs and multiple - device iterations. Essential to estimate throughput in graphs/s. - """ - num_graphs = torch.max(data.batch, dim=-1).values - num_graphs = torch.sum(num_graphs) - - return num_graphs diff --git a/graphium/ipu/to_dense_batch.py b/graphium/ipu/to_dense_batch.py deleted file mode 100644 index 9198ccf3f..000000000 --- a/graphium/ipu/to_dense_batch.py +++ /dev/null @@ -1,186 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import Optional, Tuple - -import torch -from torch import Tensor -from torch_scatter import scatter_add - - -def to_sparse_batch(x: Tensor, mask_idx: Tensor): - """ - Reverse function of `to_dense_batch` - """ - return torch.index_select(x.reshape(-1, x.shape[-1]), 0, mask_idx) - - -def to_sparse_batch_from_packed(x: Tensor, pack_from_node_idx: Tensor): - """ - Reverse function of `to_packed_dense_batch` - """ - return x[pack_from_node_idx[:, 0], pack_from_node_idx[:, 1]] - - -def to_dense_batch( - x: Tensor, - batch: Optional[Tensor] = None, - fill_value: float = 0.0, - max_num_nodes_per_graph: Optional[int] = None, - batch_size: Optional[int] = None, - drop_nodes_last_graph=False, -) -> Tuple[Tensor, Tensor]: - r"""Given a sparse batch of node features - :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}` (with - :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a - dense node feature tensor - :math:`\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F}` (with - :math:`N_{\max} = \max_i^B N_i`). - In addition, a mask of shape :math:`\mathbf{M} \in \{ 0, 1 \}^{B \times - N_{\max}}` is returned, holding information about the existence of - fake-nodes in the dense representation. - - Parameters: - x: Node feature matrix - :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`. - batch: Batch vector - :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each - node to a specific example. Must be ordered. (default: :obj:`None`) - fill_value: The value for invalid entries in the - resulting dense output tensor. (default: :obj:`0`) - max_num_nodes_per_graph: The size of the output node dimension. - (default: :obj:`None`) - batch_size: The batch size. (default: :obj:`None`) - drop_nodes_last_graph: Whether to drop the nodes of the last graphs that exceed - the `max_num_nodes_per_graph`. Useful when the last graph is a padding. - - :rtype: (:class:`Tensor`, :class:`BoolTensor`) - """ - if batch is None and max_num_nodes_per_graph is None: - mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device) - return x.unsqueeze(0), mask - - if batch is None: - batch = x.new_zeros(x.size(0), dtype=torch.long) - - if batch_size is None: - assert x.device.type != "ipu", ( - "When using the IPU the batch size must be " - "provided during compilation instead of determined at runtime" - ) - batch_size = int(batch.max()) + 1 - if x.device not in ["ipu", "xla"]: - num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0, dim_size=batch_size) - else: - # Can't use scatter_add here due to PopTorch bug, will be fixed in SDK 3.3 - arange = torch.arange(batch_size).unsqueeze(-1) - num_nodes = batch.eq(arange).sum(dim=-1) - cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)]) - - if max_num_nodes_per_graph is None: # Must be provided on IPU - max_num_nodes_per_graph = int(num_nodes.max()) - - idx = torch.arange(batch.size(0), dtype=torch.long, device=x.device) - idx = (idx - cum_nodes[batch]) + (batch * max_num_nodes_per_graph) - - size = [batch_size * max_num_nodes_per_graph] + list(x.size())[1:] - - out = x.new_full(size, fill_value) - - ##### CHANGES FROM PYG ##### - - # In case the last graph represents padding. Drop the overflowing nodes. - if drop_nodes_last_graph: - num_nodes = num_nodes[:-1] - idx[idx >= size[0]] = size[0] - 1 - - # Raise error if num_nodes > max_num_nodes - if x.device.type != "ipu": - assert ( - num_nodes <= max_num_nodes_per_graph - ).all(), f"Encountered graphs with {num_nodes.max()} nodes, greater than `max_num_nodes = {max_num_nodes_per_graph}`" - - out[idx] = x - out = out.view([batch_size, max_num_nodes_per_graph] + list(x.size())[1:]) - - # Create a zero-mask on the right device - mask_sz = batch_size * max_num_nodes_per_graph - if x.device.type in ("ipu", "xla"): - mask = torch.zeros(mask_sz, dtype=torch.int32, device="cpu") - mask = mask.to(x.device) - # Can't use mask[idx] here due to PopTorch bug, will be fixed in SDK 3.3 - # mask[idx] = 1 - # mask = mask.bool() - if drop_nodes_last_graph: - num_nodes_with_padding = torch.cat((num_nodes, torch.tensor([0], dtype=torch.int32)), dim=0) - else: - num_nodes_with_padding = num_nodes - - arange = torch.arange(max_num_nodes_per_graph) - mask = num_nodes_with_padding.unsqueeze(-1).gt(arange).flatten() - - else: - mask = torch.zeros(mask_sz, dtype=torch.bool, device=x.device) - mask[idx] = 1 - - ##### END CHANGES FROM PYG ##### - - mask = mask.view(batch_size, max_num_nodes_per_graph) - - return out, mask, idx # Added `idx` as a return - - -def to_packed_dense_batch( - x: Tensor, - pack_from_node_idx: Tensor, - pack_attn_mask: Tensor, - fill_value: float = 0.0, - max_num_nodes_per_pack: Optional[int] = None, -) -> Tuple[Tensor, Tensor]: - r"""Given a sparse batch of node features - :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}` (with - :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a - dense node feature tensor - :math:`\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F}` (with - :math:`N_{\max} = \max_i^B N_i`). - In addition, a mask of shape :math:`\mathbf{M} \in \{ 0, 1 \}^{B \times - N_{\max}}` is returned, holding information about the existence of - fake-nodes in the dense representation. - - Parameters: # TODO: Update docstring - x: Node feature matrix - :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`. - batch: Batch vector - :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each - node to a specific example. Must be ordered. (default: :obj:`None`) - fill_value: The value for invalid entries in the - resulting dense output tensor. (default: :obj:`0`) - max_num_nodes_per_graph: The size of the output node dimension. - (default: :obj:`None`) - batch_size: The batch size. (default: :obj:`None`) - drop_nodes_last_graph: Whether to drop the nodes of the last graphs that exceed - the `max_num_nodes_per_graph`. Useful when the last graph is a padding. - - :rtype: (:class:`Tensor`, :class:`BoolTensor`) - """ - - if max_num_nodes_per_pack is None: # Must be provided on IPU - max_num_nodes_per_pack = pack_attn_mask.shape[-1] - - size = [pack_attn_mask[0], max_num_nodes_per_pack] + list(x.size())[1:] - - out = x.new_full(size, fill_value) - out[pack_from_node_idx[:, 0], pack_from_node_idx[:, 1]] = x - - return out diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index db28adac4..3ee75984b 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -14,7 +14,7 @@ from typing import Iterable, List, Dict, Literal, Tuple, Union, Callable, Any, Optional, Type from torch_geometric.data import Batch -from graphium.ipu.to_dense_batch import to_dense_batch +from torch_geometric.utils import to_dense_batch from loguru import logger # Misc imports @@ -40,10 +40,6 @@ ResidualConnectionRandom, ) from graphium.nn.utils import MupMixin -from graphium.ipu.ipu_utils import import_poptorch, is_running_on_ipu - -poptorch = import_poptorch(raise_error=False) - import collections @@ -1476,8 +1472,6 @@ def __init__( if accelerator_kwargs is not None: accelerator = accelerator_kwargs["_accelerator"] - if accelerator == "ipu": - self._apply_ipu_options(accelerator_kwargs) self._check_bad_arguments() @@ -1529,45 +1523,6 @@ def _check_bad_arguments(self): f"Task heads have graph level tasks {', '.join(graph_level_tasks)}, but pooling is none." ) - def _apply_ipu_options(self, ipu_kwargs): - gnn_layers_per_ipu = ipu_kwargs.get("gnn_layers_per_ipu") - self._apply_ipu_pipeline_split(gnn_layers_per_ipu) - - def _apply_ipu_pipeline_split(self, gnn_layers_per_ipu): - r""" - Apply pipeline split from accelerator options if applicable - """ - - if gnn_layers_per_ipu is None: - return - - if not isinstance(gnn_layers_per_ipu, collections.abc.Sequence): - raise ValueError("gnn_layers_per_ipu must be a Sequence (e.g. a list)") - - valid_ipu_pipeline_lengths = [1, 2, 4, 8, 16] - pipeline_length = len(gnn_layers_per_ipu) - - if pipeline_length not in valid_ipu_pipeline_lengths: - raise ValueError( - f"Length of gnn_layers_per_ipu must be one of {valid_ipu_pipeline_lengths}, " - f"got {gnn_layers_per_ipu} of length {pipeline_length} instead" - ) - - model_depth = len(self.gnn.layers) - - if sum(gnn_layers_per_ipu) != model_depth: - raise ValueError( - f"The values in gnn_layers_per_ipu must add up to the depth of the model, " - f"got {gnn_layers_per_ipu} with total {sum(gnn_layers_per_ipu)} vs model depth " - f"of {model_depth}" - ) - - begin_block_layer_indices = [sum(gnn_layers_per_ipu[:i]) for i in range(1, pipeline_length)] - - for begin_block_layer_index, ipu_id in zip(begin_block_layer_indices, range(1, pipeline_length)): - self.gnn.layers[begin_block_layer_index] = poptorch.BeginBlock( - self.gnn.layers[begin_block_layer_index], ipu_id=ipu_id - ) def _enable_readout_cache(self, module_filter: Optional[Union[str, List[str]]]): """ @@ -1934,7 +1889,6 @@ def forward(self, g: Batch): node_feats=g["feat"], batch=g.batch, max_num_nodes=self.max_num_nodes_per_graph, - drop_nodes_last_graph=is_running_on_ipu(), ) # Check if at least one graph-level task is present if self.task_level == "graph": @@ -2024,7 +1978,6 @@ def compute_nodepairs( max_num_nodes: int = None, fill_value: float = float("nan"), batch_size: int = None, - drop_nodes_last_graph: bool = False, ) -> torch.Tensor: r""" Vectorized implementation of nodepair-level task: @@ -2035,19 +1988,16 @@ def compute_nodepairs( fill_value: The value for invalid entries in the resulting dense output tensor. (default: :obj:`NaN`) batch_size: The batch size. (default: :obj:`None`) - drop_nodes_last_graph: Whether to drop the nodes of the last graphs that exceed - the `max_num_nodes_per_graph`. Useful when the last graph is a padding. Returns: result: concatenated node features of shape B * max_num_nodes * 2*h, where B is number of graphs, max_num_nodes is the chosen maximum number nodes, and h is the feature dim """ - dense_feat, mask, _ = to_dense_batch( + dense_feat, mask = to_dense_batch( node_feats, batch=batch, fill_value=fill_value, batch_size=batch_size, - max_num_nodes_per_graph=max_num_nodes, - drop_nodes_last_graph=drop_nodes_last_graph, + max_num_nodes=max_num_nodes, ) n = dense_feat.size(1) h_X = dense_feat[:, :, None].repeat(1, 1, n, 1) diff --git a/graphium/nn/base_graph_layer.py b/graphium/nn/base_graph_layer.py index 66869f888..0986cb9c6 100644 --- a/graphium/nn/base_graph_layer.py +++ b/graphium/nn/base_graph_layer.py @@ -258,14 +258,14 @@ def out_dim_factor(self) -> int: @property def max_num_nodes_per_graph(self) -> Optional[int]: """ - Get the maximum number of nodes per graph. Useful for reshaping a compiled model (IPU) + Get the maximum number of nodes per graph. Useful for reshaping a compiled model """ return self._max_num_nodes_per_graph @max_num_nodes_per_graph.setter def max_num_nodes_per_graph(self, value: Optional[int]): """ - Set the maximum number of nodes per graph. Useful for reshaping a compiled model (IPU) + Set the maximum number of nodes per graph. Useful for reshaping a compiled model """ if value is not None: assert isinstance(value, int) and ( @@ -276,14 +276,14 @@ def max_num_nodes_per_graph(self, value: Optional[int]): @property def max_num_edges_per_graph(self) -> Optional[int]: """ - Get the maximum number of nodes per graph. Useful for reshaping a compiled model (IPU) + Get the maximum number of nodes per graph. Useful for reshaping a compiled model """ return self._max_num_edges_per_graph @max_num_edges_per_graph.setter def max_num_edges_per_graph(self, value: Optional[int]): """ - Set the maximum number of nodes per graph. Useful for reshaping a compiled model (IPU) + Set the maximum number of nodes per graph. Useful for reshaping a compiled model """ if value is not None: assert isinstance(value, int) and ( diff --git a/graphium/nn/base_layers.py b/graphium/nn/base_layers.py index dbc89f19f..dafba127f 100644 --- a/graphium/nn/base_layers.py +++ b/graphium/nn/base_layers.py @@ -26,7 +26,6 @@ from mup import set_base_shapes, MuReadout from torch.nn.functional import linear -from graphium.ipu.ipu_utils import is_running_on_ipu SUPPORTED_ACTIVATION_MAP = { "ReLU", @@ -156,7 +155,6 @@ def forward( query: Tensor, key: Tensor, value: Tensor, - key_padding_mask: Optional[Tensor] = None, attn_bias: Optional[Tensor] = None, precision: Optional[str] = "32", *args, @@ -186,15 +184,7 @@ def forward( attn_weights = q @ k.transpose(-1, -2) # [batch, num_heads, nodes, nodes] attn_weights += attn_bias - key_padding_mask_value = float("-inf") if precision == "32" else -10000 - # key_padding_mask: [batch, 1, 1, nodes] - if key_padding_mask is not None: - masked_attn_weights = attn_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - key_padding_mask_value, - ) - else: - masked_attn_weights = attn_weights + masked_attn_weights = attn_weights masked_attn_weights = F.softmax(masked_attn_weights, dim=-1) attn_probs = F.dropout(masked_attn_weights, p=self.dropout, training=self.training) # [batch, num_heads, nodes, nodes] * [batch, num_heads, nodes, head_size] -> [batch, num_heads, nodes, head_size] @@ -243,9 +233,7 @@ class MuReadoutGraphium(MuReadout): Not quite a drop-in replacement for `mup.MuReadout` - you need to specify `base_width`. - Set `base_width` to width of base model passed to `mup.set_base_shapes` - to get same results on IPU and CPU. Should still "work" with any other - value, but won't give the same results as CPU + Set `base_width` to width of base model passed to `mup.set_base_shapes`. """ def __init__(self, in_features, *args, **kwargs): @@ -725,32 +713,23 @@ def forward( Parameters: input: `torch.Tensor[total_num_nodes, hidden]` batch: batch attribute of the batch object, batch.batch - batch_size: The batch size. Must be provided when working on IPU + batch_size: The batch size. Returns: torch.Tensor: `torch.Tensor[total_num_nodes, hidde]` """ - on_ipu = is_running_on_ipu() if self.drop_rate > 0: keep_prob = 1 - self.drop_rate # Parse the batch size if batch_size is None: - if on_ipu: - raise ValueError( - "When using the IPU the batch size must be " - "provided during compilation instead of determined at runtime" - ) - else: - batch_size = int(batch_idx.max()) + 1 + batch_size = int(batch_idx.max()) + 1 # mask shape: [num_graphs, 1] mask = input.new_empty(batch_size, 1).bernoulli_(keep_prob) - # if on_ipu, the last graph is a padded fake graph - if on_ipu: - mask[-1] = 0 + # using gather to extend mask to [total_num_nodes, 1] node_mask = mask[batch_idx] if keep_prob == 0: diff --git a/graphium/nn/encoders/gaussian_kernel_pos_encoder.py b/graphium/nn/encoders/gaussian_kernel_pos_encoder.py index 44ddd5578..19ff813ac 100644 --- a/graphium/nn/encoders/gaussian_kernel_pos_encoder.py +++ b/graphium/nn/encoders/gaussian_kernel_pos_encoder.py @@ -2,7 +2,6 @@ from torch_geometric.data import Batch from graphium.nn.pyg_layers.utils import PreprocessPositions -from graphium.ipu.ipu_utils import is_running_on_ipu from graphium.nn.encoders.base_encoder import BaseEncoder @@ -116,13 +115,10 @@ def forward(self, batch: Batch, key_prefix: Optional[str] = None) -> Dict[str, A """ input_keys = self.parse_input_keys_with_prefix(key_prefix) - on_ipu = is_running_on_ipu() max_num_nodes_per_graph = None - if on_ipu: - max_num_nodes_per_graph = self.max_num_nodes_per_graph attn_bias_3d, node_feature_3d = self.preprocess_3d_positions( - batch, max_num_nodes_per_graph, on_ipu, positions_3d_key=input_keys[0] + batch, max_num_nodes_per_graph, positions_3d_key=input_keys[0] ) # Return `attn_bias_3d` if the key starts with 'nodepair_' diff --git a/graphium/nn/pyg_layers/gps_pyg.py b/graphium/nn/pyg_layers/gps_pyg.py index 3d8671c53..24c74a6ec 100644 --- a/graphium/nn/pyg_layers/gps_pyg.py +++ b/graphium/nn/pyg_layers/gps_pyg.py @@ -18,6 +18,7 @@ from torch.nn import Module from torch import Tensor from torch_geometric.data import Batch +from torch_geometric.utils import to_dense_batch from graphium.nn.base_graph_layer import BaseGraphModule from graphium.nn.base_layers import FCLayer, MultiheadAttentionMup, MLP from graphium.nn.pyg_layers import ( @@ -29,13 +30,6 @@ ) from graphium.data.utils import get_keys from graphium.utils.decorators import classproperty -from graphium.ipu.to_dense_batch import ( - to_dense_batch, - to_sparse_batch, - to_packed_dense_batch, - to_sparse_batch_from_packed, -) -from graphium.ipu.ipu_utils import is_running_on_ipu PYG_LAYERS_DICT = { "pyg:gin": GINConvPyg, @@ -286,7 +280,7 @@ def forward(self, batch: Batch) -> Batch: # MLP block, with skip connection feat_mlp = self.mlp(feat) # Add the droppath to the output of the MLP - batch_size = None if feat.device.type != "ipu" else batch.graph_is_true.shape[0] + batch_size = None if self.droppath_ffn is not None: feat_mlp = self.droppath_ffn(feat_mlp, batch.batch, batch_size) feat = feat + feat_mlp @@ -375,50 +369,27 @@ def _to_dense_batch( h: Tensor, batch: Batch, batch_size: Optional[int] = None, - max_num_nodes_per_graph: Optional[int] = None, - on_ipu: bool = False, + max_num_nodes: Optional[int] = None, ) -> Tensor: """ Convert the batch of graphs to a dense batch. """ - if self._use_packing(batch): - attn_mask = batch.pack_attn_mask - key_padding_mask = None - idx = batch.pack_from_node_idx - h_dense = to_packed_dense_batch( - h, - pack_from_node_idx=idx, - pack_attn_mask=attn_mask, - max_num_nodes_per_pack=100, # TODO: This should be a parameter - ) - else: - attn_mask = None - h_dense, key_padding_mask, idx = to_dense_batch( - h, - batch=batch.batch, # The batch index as a vector that indicates for nodes of which graph it belongs to - batch_size=batch_size, - max_num_nodes_per_graph=max_num_nodes_per_graph, - drop_nodes_last_graph=on_ipu, - ) - key_padding_mask = ~key_padding_mask - return h_dense, attn_mask, key_padding_mask, idx - - def _to_sparse_batch(self, batch: Batch, h_dense: Tensor, idx: Tensor) -> Tensor: + h_dense, key_padding_mask = to_dense_batch( + h, + batch=batch.batch, # The batch index as a vector that indicates for nodes of which graph it belongs to + batch_size=batch_size, + max_num_nodes=max_num_nodes, + ) + key_padding_mask = ~key_padding_mask + return h_dense, key_padding_mask + + def _to_sparse_batch(self, batch: Batch, h_dense: Tensor, mask: torch.BoolTensor) -> Tensor: """ Convert the dense batch back to a sparse batch. """ - if self._use_packing(batch): - h = to_sparse_batch_from_packed( - h_dense, - pack_from_node_idx=idx, - ) - else: - h = to_sparse_batch( - h_dense, - mask_idx=idx, - ) - return h + + return h_dense[mask] def _self_attention_block(self, feat: Tensor, feat_in: Tensor, batch: Batch) -> Tensor: """ @@ -429,21 +400,17 @@ def _self_attention_block(self, feat: Tensor, feat_in: Tensor, batch: Batch) -> """ # Multi-head attention. - on_ipu = is_running_on_ipu() max_num_nodes_per_graph = None - if on_ipu: - max_num_nodes_per_graph = self.max_num_nodes_per_graph # Convert the tensor to a dense batch, then back to a sparse batch - batch_size = None if feat.device.type != "ipu" else batch.graph_is_true.shape[0] + batch_size = None # h[num_nodes, hidden_dim] -> h_dense[num_graphs, max_num_nodes, hidden_dim] - feat_dense, attn_mask, key_padding_mask, idx = self._to_dense_batch( + feat_dense, attn_mask = self._to_dense_batch( feat, batch=batch, # The batch index as a vector that indicates for nodes of which graph it belongs to batch_size=batch_size, - max_num_nodes_per_graph=max_num_nodes_per_graph, - on_ipu=on_ipu, + max_num_nodes=max_num_nodes_per_graph, ) attn_bias = None @@ -452,11 +419,11 @@ def _self_attention_block(self, feat: Tensor, feat_in: Tensor, batch: Batch) -> # h_dense[num_graphs, max_num_nodes, hidden_dim] -> feat_attn[num_graphs, max_num_nodes, hidden_dim] feat_attn = self._sa_block( - feat_dense, attn_bias=attn_bias, attn_mask=attn_mask, key_padding_mask=key_padding_mask + feat_dense, attn_bias=attn_bias, attn_mask=attn_mask ) # feat_attn[num_graphs, max_num_nodes, hidden_dim] -> feat_attn[num_nodes, hidden_dim] - feat_attn = self._to_sparse_batch(batch, feat_attn, idx) + feat_attn = self._to_sparse_batch(batch, feat_attn, attn_mask) # Dropout, residual, norm if self.dropout_attn is not None: diff --git a/graphium/nn/pyg_layers/utils.py b/graphium/nn/pyg_layers/utils.py index 83dc4f737..169e1e5a7 100644 --- a/graphium/nn/pyg_layers/utils.py +++ b/graphium/nn/pyg_layers/utils.py @@ -16,13 +16,13 @@ import torch import torch.nn as nn from torch_geometric.data import Batch +from torch_geometric.utils import to_dense_batch from typing import Tuple from torch import Tensor from torch_geometric.typing import SparseTensor from graphium.nn.base_layers import MLP, get_norm -from graphium.ipu.to_dense_batch import to_dense_batch, to_sparse_batch class PreprocessPositions(nn.Module): @@ -74,7 +74,7 @@ def __init__( self.node_proj = nn.Linear(self.num_kernel, self.embed_dim) def forward( - self, batch: Batch, max_num_nodes_per_graph: int, on_ipu: bool, positions_3d_key: str + self, batch: Batch, max_num_nodes_per_graph: int, positions_3d_key: str ) -> Tuple[Tensor, Tensor]: r""" Inputs: @@ -82,8 +82,6 @@ def forward( Batch object. max_num_nodes_per_graph: Maximum number of nodes per graph. - on_ipu: - If model rus on IPU. positions_3d_key: The key of the pyg graph object that contains the 3D positions. @@ -92,17 +90,14 @@ def forward( pos = batch[positions_3d_key] if self.first_normalization is not None: pos = self.first_normalization(pos) - batch_size = None if pos.device.type != "ipu" else batch.graph_is_true.shape[0] - # batch_size = None if batch.feat.device.type != "ipu" else batch.graph_is_true.shape[0] #[Andy] batch.feat is only available after passing through layers, not a good attribute to check + batch_size = None # pos: [batch, nodes, 3] # padding_mask: [batch, nodes] - # idx: [totoal_nodes] - pos, mask, idx = to_dense_batch( + pos, mask = to_dense_batch( pos, batch=batch.batch, batch_size=batch_size, - max_num_nodes_per_graph=max_num_nodes_per_graph, - drop_nodes_last_graph=on_ipu, + max_num_nodes=max_num_nodes_per_graph, ) # check nan with the pos from to_dense_batch, # and generate mask. 1 for nan, 0 for other values. @@ -153,7 +148,7 @@ def forward( # unsqueezed mask size: [batch, 1, 1] apply on tensor [batch, nodes, embed_dim] node_feature.masked_fill_(nan_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), 0.0) # [total_nodes, embed_dim] - node_feature = to_sparse_batch(node_feature, idx) + node_feature = node_feature[mask] return attn_bias, node_feature diff --git a/graphium/trainer/__init__.py b/graphium/trainer/__init__.py index ed2cbf2a4..1e1682e2f 100644 --- a/graphium/trainer/__init__.py +++ b/graphium/trainer/__init__.py @@ -2,3 +2,6 @@ from . import metrics from .predictor import PredictorModule +from .predictor_summaries import SingleTaskSummary +from .predictor_summaries import MultiTaskSummary +from .progress_bar import ProgressBarMetrics diff --git a/graphium/trainer/metrics.py b/graphium/trainer/metrics.py index 22361faa6..7a20e0049 100644 --- a/graphium/trainer/metrics.py +++ b/graphium/trainer/metrics.py @@ -12,16 +12,22 @@ """ -from typing import Union, Callable, Optional, Dict, Any +from typing import Union, Callable, Optional, Dict, Any, Literal, List, Tuple import sys import torch from torch import Tensor +import torch.distributed as dist import operator as op +from copy import deepcopy +from loguru import logger +from torch.nn.modules.loss import _Loss from torchmetrics.utilities.distributed import reduce import torchmetrics.functional.regression.mae +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics import Metric from graphium.utils.tensor import nan_mean @@ -38,7 +44,7 @@ class Thresholder: def __init__( self, threshold: float, - operator: Union[str, Callable] = "greater", + operator: Union[Literal["greater", "gt", ">", "lower", "lt", "<"], Callable] = "greater", th_on_preds: bool = True, th_on_target: bool = False, ): @@ -74,10 +80,10 @@ def _get_operator(operator): """Operator can either be a string, or a callable""" if isinstance(operator, str): op_name = operator.lower() - if op_name in ["greater", "gt"]: + if op_name in ["greater", "gt", ">"]: op_str = ">" operator = op.gt - elif op_name in ["lower", "lt"]: + elif op_name in ["lower", "lt", "<"]: op_str = "<" operator = op.lt else: @@ -129,6 +135,28 @@ def __eq__(self, obj) -> bool: return all(is_eq) +def _filter_nans(preds: Tensor, target: Tensor, target_nan_mask: Union[Literal[None, "none", "ignore"], int]) -> Tuple[Tensor, Tensor]: + """Handle the NaNs according to the chosen options""" + + if target_nan_mask is None: # No NaN handling + return preds, target + + if target.dtype in [torch.int, torch.int16, torch.int32, torch.int64, torch.int8]: + target_nans = (torch.iinfo(target.dtype).min == target) | (torch.iinfo(target.dtype).max == target) + else: + target_nans = torch.isnan(target) + if ~target_nans.any(): # No NaNs + return preds, target + elif isinstance(target_nan_mask, (int, float)): # Replace NaNs + target = target.clone() + target[target_nans] = target_nan_mask + elif target_nan_mask == "ignore": # Remove NaNs + target = target[~target_nans] + preds = preds[~target_nans] + else: + raise ValueError(f"Invalid option `{target_nan_mask}`") + return preds, target + class MetricWrapper: r""" Allows to initialize a metric from a name or Callable, and initialize the @@ -137,10 +165,10 @@ class MetricWrapper: def __init__( self, - metric: Union[str, Callable], + metric: Union[str, torchmetrics.Metric, torch.nn.modules.loss._Loss], threshold_kwargs: Optional[Dict[str, Any]] = None, - target_nan_mask: Optional[Union[str, int]] = None, - multitask_handling: Optional[str] = None, + target_nan_mask: Union[Literal[None, "none", "ignore"], int] = None, + multitask_handling: Literal[None, "none", "flatten", "mean-per-label"] = None, squeeze_targets: bool = False, target_to_int: bool = False, **kwargs, @@ -187,7 +215,7 @@ def __init__( Other arguments to call with the metric """ - self.metric, self.metric_name = self._get_metric(metric) + metric_class, self.metric_name = self._get_metric_class(metric) self.thresholder = None if threshold_kwargs is not None: self.thresholder = Thresholder(**threshold_kwargs) @@ -198,6 +226,34 @@ def __init__( self.target_to_int = target_to_int self.kwargs = kwargs + self.metric, self.kwargs = self._initialize_metric(metric_class, self.target_nan_mask, self.multitask_handling, **self.kwargs) + + @staticmethod + def _initialize_metric(metric, target_nan_mask, multitask_handling, **kwargs): + r""" + Initialize the metric with the provided kwargs + """ + + if not isinstance(metric, type): + if callable(metric): + metric = MetricToConcatenatedTorchMetrics( + metric_fn=metric, + target_nan_mask=target_nan_mask, + multitask_handling=multitask_handling, + **kwargs) + return metric, kwargs + elif all(hasattr(metric, method) for method in ["update", "compute", "reset", "to"]): + return metric, kwargs + else: + raise ValueError(f"metric must be a callable, or a class with 'update', 'compute', 'reset', 'to', provided: `{type(metric)}`") + + metric = metric(**kwargs) + if not all(hasattr(metric, method) for method in ["update", "compute", "reset", "to"]): + raise ValueError(f"metric must be a callable, or a class with 'update', 'compute', 'reset', 'to', provided: `{type(metric)}`") + + return metric, kwargs + + @staticmethod def _parse_target_nan_mask(target_nan_mask): """ @@ -254,20 +310,35 @@ def _parse_multitask_handling(multitask_handling, target_nan_mask): return multitask_handling @staticmethod - def _get_metric(metric): + def _get_metric_class(metric): from graphium.utils.spaces import METRICS_DICT if isinstance(metric, str): - metric_name = metric - metric = METRICS_DICT[metric] + metric_name = MetricWrapper._ipu_metrics_name_conversion(metric) + metric = METRICS_DICT[metric_name] else: metric_name = None metric = metric return metric, metric_name - - def compute(self, preds: Tensor, target: Tensor) -> Tensor: + + @staticmethod + def _ipu_metrics_name_conversion(metric, warning=True): r""" - Compute the metric, apply the thresholder if provided, and manage the NaNs + Convert the metric name from the removed ipu metrics to the regular torchmetrics metrics + """ + metric_name = metric + if metric_name.endswith("_ipu"): # For backward compatibility when loading models with metrics for ipu + metric_name = metric_name[:-4] + if metric_name == "average_precision": # A previous typo in the `spaces.py` + metric_name = "averageprecision" + if warning: + logger.warning(f"Using the metric `{metric_name}` instead of `{metric}`") + return metric_name + + def update(self, preds: Tensor, target: Tensor) -> Tensor: + r""" + Update the parameters of the metric, apply the thresholder if provided, and manage the NaNs. + See `torchmetrics.Metric.update` for more details. """ if preds.ndim == 1: preds = preds.unsqueeze(-1) @@ -279,8 +350,6 @@ def compute(self, preds: Tensor, target: Tensor) -> Tensor: if self.thresholder is not None: preds, target = self.thresholder(preds, target) - target_nans = torch.isnan(target) - # for the classifigression task, cast predictions from # (batch_size, n_targets * n_brackets) to (batch_size, n_targets, n_brackets) # TODO: make this more flexible to the target shape in the future @@ -290,7 +359,7 @@ def compute(self, preds: Tensor, target: Tensor) -> Tensor: else: classifigression = False - if self.multitask_handling is None: + if (self.multitask_handling is None): # In case of no multi-task handling, apply the nan filtering, then compute the metrics assert ( self.target_nan_mask != "ignore" @@ -300,7 +369,9 @@ def compute(self, preds: Tensor, target: Tensor) -> Tensor: target = target.squeeze() if self.target_to_int: target = target.to(int) - metric_val = self.metric(preds, target, **self.kwargs) + self.metric.update(preds, target) + + elif self.multitask_handling == "flatten": # Flatten the tensors, apply the nan filtering, then compute the metrics if classifigression: @@ -313,16 +384,26 @@ def compute(self, preds: Tensor, target: Tensor) -> Tensor: target = target.squeeze() if self.target_to_int: target = target.to(int) - metric_val = self.metric(preds, target, **self.kwargs) + self.metric.update(preds, target) + + elif isinstance(self.metric, MetricToConcatenatedTorchMetrics): + # NaN's and multitask handling are handled by the MetricToConcatenatedTorchMetrics + if self.squeeze_targets: + target = target.squeeze() + if self.target_to_int: + target = target.to(int) + self.metric.update(preds, target) + elif self.multitask_handling == "mean-per-label": # Loop the columns (last dim) of the tensors, apply the nan filtering, compute the metrics per column, then average the metrics - target_list = [target[..., ii][~target_nans[..., ii]] for ii in range(target.shape[-1])] - # TODO: make this more flexible to the target shape in the future + target_list = [target[..., ii] for ii in range(target.shape[-1])] if classifigression: - preds_list = [preds[..., i, :][~target_nans[..., i]] for i in range(preds.shape[1])] + preds_list = [preds[..., ii, :] for ii in range(preds.shape[1])] else: - preds_list = [preds[..., ii][~target_nans[..., ii]] for ii in range(preds.shape[-1])] - metric_val = [] + preds_list = [preds[..., ii] for ii in range(preds.shape[-1])] + + if not isinstance(self.metric, list): + self.metric = [deepcopy(self.metric) for _ in range(len(target_list))] for ii in range(len(target_list)): try: this_preds, this_target = self._filter_nans(preds_list[ii], target_list[ii]) @@ -330,44 +411,78 @@ def compute(self, preds: Tensor, target: Tensor) -> Tensor: this_target = this_target.squeeze() if self.target_to_int: this_target = this_target.to(int) - metric_val.append(self.metric(this_preds, this_target, **self.kwargs)) + self.metric[ii].update(this_preds, this_target) except: pass - # Average the metric - metric_val = nan_mean(torch.stack(metric_val)) else: # Wrong option raise ValueError(f"Invalid option `self.multitask_handling={self.multitask_handling}`") - return metric_val + def compute(self) -> Tensor: + r""" + Compute the metric with the method `self.compute` + """ + if isinstance(self.metric, list): + metrics = [metric.compute() for metric in self.metric] + return nan_mean(torch.stack(metrics)) + + return self.metric.compute() - def _filter_nans(self, preds: Tensor, target: Tensor): - """Handle the NaNs according to the chosen options""" - target_nans = torch.isnan(target) + def update_compute(self, preds: Tensor, target: Tensor) -> Tensor: + r""" + Update the parameters of the metric, apply the thresholder if provided, and manage the NaNs. + Then compute the metric with the method `self.compute` + """ - if self.target_nan_mask is None: - pass - elif isinstance(self.target_nan_mask, (int, float)): - target = target.clone() - target[torch.isnan(target)] = self.target_nan_mask - elif self.target_nan_mask == "ignore": - target = target[~target_nans] - preds = preds[~target_nans] + self.update(preds, target) + return self.compute() + + def reset(self) -> None: + r""" + Reset the metric with the method `self.metric.reset` + """ + if isinstance(self.metric, list): + for metric in self.metric: + metric.reset() else: - raise ValueError(f"Invalid option `{self.target_nan_mask}`") - return preds, target + self.metric.reset() + + def to(self, device: Union[str, torch.device]) -> None: + r""" + Move the metric to the device with the method `self.metric.to` + """ + if isinstance(self.metric, list): + for metric in self.metric: + metric.to(device) + else: + self.metric.to(device) + + @property + def device(self) -> torch.device: + r""" + Return the device of the metric with the method `self.metric.device` or `self.metric[0].device` + """ + if isinstance(self.metric, list): + return self.metric[0].device + return self.metric.device + + + def _filter_nans(self, preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """Handle the NaNs according to the chosen options""" + + return _filter_nans(preds, target, self.target_nan_mask) def __call__(self, preds: Tensor, target: Tensor) -> Tensor: r""" Compute the metric with the method `self.compute` """ - return self.compute(preds, target) + return self.update_compute(preds, target) def __repr__(self): r""" Control how the class is printed """ - full_str = f"{self.metric.__name__}" + full_str = f"{self.metric.__repr__()}" if self.thresholder is not None: full_str += f"({self.thresholder})" @@ -405,10 +520,198 @@ def __getstate__(self): def __setstate__(self, state: dict): """Reload the class from pickling.""" - state["metric"], state["metric_name"] = self._get_metric(state["metric"]) + state["metric"], state["metric_name"] = self._get_metric_class(state["metric"]) thresholder = state.pop("threshold_kwargs", None) if thresholder is not None: thresholder = Thresholder(**thresholder) state["thresholder"] = thresholder + state["metric"], state["at_compute_kwargs"] = self._initialize_metric(state["metric"], state["target_nan_mask"], state["multitask_handling"], **state["kwargs"]) self.__dict__.update(state) + +class LossWrapper(): + r""" + A simple wrapper to convert any metric or loss to an equivalent of `torchmetrics.Metric` + by adding the `update`, `compute`, and `reset` methods to make it compatible with `MetricWrapper`. + However, it is simply limited to computing the average of the metric over all the updates. + """ + + def __init__(self, loss): + self.loss = loss + self.scores: List[Tensor] = [] + + def update(self, preds: Tensor, target: Tensor): + self.scores.append(self.loss(preds, target)) + + def compute(self): + if len(self.scores) == 0: + raise ValueError("No scores to compute") + elif len(self.scores) == 1: + return self.scores[0] + return nan_mean(torch.stack(self.scores)) + + def to(self, device: Union[str, torch.device]): + for ii in range(len(self.scores)): + self.scores[ii] = self.scores[ii].to(device) + + @property + def device(self) -> torch.device: + self.loss.device + + def reset(self): + self.scores = [] + + +class MetricToMeanTorchMetrics(Metric): + r""" + A simple wrapper to convert any metric or loss to an equivalent of `torchmetrics.Metric` + by adding the `update`, `compute`, and `reset` methods to make it compatible with `MetricWrapper`. + + However, it is limited in functionality. At each `.update()`, it computes the metric and stores in a list. + Then at `.compute()` it returns the average of the computed metric, while ignoring NaNs. + """ + scores: List[Tensor] = [] + + def __init__(self, metric_fn): + super().__init__(dist_sync_on_step=False) + self.metric_fn = metric_fn + self.add_state("scores", default=[], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor): + self.scores.append(self.metric_fn(preds.detach(), target)) + + def compute(self): + if len(self.scores) == 0: + raise ValueError("No scores to compute") + elif len(self.scores) == 1: + return self.scores[0] + return nan_mean(torch.stack(self.scores)) + + +class MetricToConcatenatedTorchMetrics(Metric): + + preds: List[Tensor] # Always on CPU + target: List[Tensor] # Always on CPU + + def __init__(self, + metric_fn: Callable, + target_nan_mask: Union[Literal[None, "none", "ignore"], int] = None, + multitask_handling: Literal[None, "none", "flatten", "mean-per-label"] = None, + **kwargs, + ): + r""" + A wrapper around the `torchmetrics.Metric` to handle the saving and syncing of `preds` and `target` tensors, + and moving them to the CPU. + This is useful for certain metrics that require to save all preds and targets, such as auroc and average_precision. + Otherwise, if using `MetricWrapper` with the option `mean-per-label`, the `preds` and `target` would be + duplicated for each label, causing major memory spikes. + On top of that, all preds and targets would be on the GPU, which would cause the memory to increase at every step, + and potentially lead to out-of-memory before the end of the epoch. + + Parameters + ---------- + + metric_fn: + The metric function to use. This function should take `preds` and `target` as input, and return a scalar value. + + target_nan_mask: + - None: Do not change behaviour if there are NaNs + + - int, float: Value used to replace NaNs. For example, if `target_nan_mask==0`, then + all NaNs will be replaced by zeros + + - 'ignore': The NaN values will be removed from the tensor before computing the metrics. + Must be coupled with the `multitask_handling='flatten'` or `multitask_handling='mean-per-label'`. + + multitask_handling: + - None: Do not process the tensor before passing it to the metric. + Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`. + Use either 'flatten' or 'mean-per-label'. + + - 'flatten': Flatten the tensor to produce the equivalent of a single task + + - 'mean-per-label': Loop all the labels columns, process them as a single task, + and average the results over each task + *This option might slow down the computation if there are too many labels* + + """ + + super().__init__(compute_on_cpu=True, dist_sync_on_step=False, sync_on_compute=False) + self.metric_fn = metric_fn + self.target_nan_mask = target_nan_mask + self.multitask_handling = multitask_handling + self.kwargs = kwargs + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + self._to_device_warned: bool = False + super().to("cpu") + + def update(self, preds: Tensor, target: Tensor): + + # If distributed, gather the preds and target tensors + if self.dist_sync_fn is not None: + preds_list = self.dist_sync_fn(preds, self.process_group) + target_list = self.dist_sync_fn(target, self.process_group) + preds = dim_zero_cat(preds_list) + target = dim_zero_cat(target_list) + + # Move the tensors to the CPU after gathering them + self.preds.append(preds.detach().cpu()) + self.target.append(target.cpu()) + + def compute(self): + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + + if (self.multitask_handling is None) or (self.multitask_handling in ["none", "flatten"]): + preds, target = _filter_nans(preds, target, self.target_nan_mask) + value = self.metric_fn(preds, target, **self.kwargs) + + elif self.multitask_handling == "mean-per-label": + value = [] + # Loop the columns (last dim) of the tensors, apply the nan filtering, compute the metrics per column, then average the metrics + target_list = [target[..., ii] for ii in range(target.shape[-1])] + preds_list = [preds[..., ii] for ii in range(preds.shape[-1])] + for ii in range(len(target_list)): + try: + this_preds, this_target = _filter_nans(preds_list[ii], target_list[ii], self.target_nan_mask) + value.append(self.metric_fn(this_preds, this_target, **self.kwargs)) + except: + pass + value = nan_mean(torch.stack(value)) + else: + # Wrong option + raise ValueError(f"Invalid option `self.multitask_handling={self.multitask_handling}`") + return value + + def to(self, device: Union[str, torch.device]): + """ + Disables the moving of the metric to another device. Stays on CPU to avoid overflow. + """ + device = torch.device(device) + if device == torch.device("cpu"): + return + if not self._to_device_warned: + self._to_device_warned = True + logger.warning(f"{self.get_obj_name(self)}({self.get_obj_name(self.metric_fn)}) stays on `{self.device}`, won't move to `{device}`") + + @staticmethod + def get_obj_name(obj): + """ + Returns the name of a function, class, or instance of a class. + + Parameters: + - obj: The object to get the name of. + + Returns: + - The name of the object as a string. + """ + # If the object is a class or function, return its __name__ + if hasattr(obj, '__name__'): + return obj.__name__ + # If the object is an instance of a class, return its class's __name__ + elif hasattr(obj, '__class__'): + return obj.__class__.__name__ + else: + return str(obj) # Fallback to converting the object to string + diff --git a/graphium/trainer/predictor.py b/graphium/trainer/predictor.py index 8cfb1ad28..f15521268 100644 --- a/graphium/trainer/predictor.py +++ b/graphium/trainer/predictor.py @@ -14,7 +14,7 @@ import time from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, Literal, Mapping import lightning import numpy as np @@ -23,20 +23,20 @@ from mup.optim import MuAdam from torch import Tensor, nn from torch_geometric.data import Batch, Data +from torchmetrics import Metric from graphium.config.config_convert import recursive_config_reformating from graphium.data.datamodule import BaseDataModule -from graphium.trainer.metrics import MetricWrapper +from graphium.trainer.metrics import MetricWrapper, LossWrapper from graphium.trainer.predictor_options import ( EvalOptions, FlagOptions, ModelOptions, OptimOptions, ) -from graphium.trainer.predictor_summaries import TaskSummaries +from graphium.trainer.predictor_summaries import MultiTaskSummary, GradientNormMetric from graphium.utils import fs from graphium.utils.moving_average_tracker import MovingAverageTracker -from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT from graphium.utils.tensor import dict_tensor_fp16_to_fp32 @@ -54,15 +54,12 @@ def __init__( scheduler_kwargs: Optional[Dict[str, Any]] = None, target_nan_mask: Optional[Union[str, int]] = None, multitask_handling: Optional[str] = None, - metrics: Dict[str, Callable] = None, + metrics: Dict[str, Dict[str, Union[Metric, "MetricWrapper"]]] = None, metrics_on_progress_bar: Dict[str, List[str]] = [], metrics_on_training_set: Optional[Dict[str, List[str]]] = None, flag_kwargs: Dict[str, Any] = None, task_norms: Optional[Dict[Callable, Any]] = None, metrics_every_n_train_steps: Optional[int] = None, - replicas: int = 1, - gradient_acc: int = 1, - global_bs: Optional[int] = 1, ): """ The Lightning module responsible for handling the predictions, losses, metrics, optimization, etc. @@ -139,11 +136,12 @@ def __init__( # Task-specific evalutation attributes self.loss_fun = {} + loss_names = {} self.metrics = {} self.metrics_on_progress_bar = {} self.metrics_on_training_set = {} for task in self.tasks: - self.loss_fun[task] = EvalOptions.parse_loss_fun(loss_fun[task]) + loss_names[task], self.loss_fun[task] = EvalOptions.parse_loss_fun(loss_fun[task]) self.metrics[task] = ( self._eval_options_dict[task].metrics if self._eval_options_dict[task].metrics is not None @@ -164,36 +162,51 @@ def __init__( # Set the parameters for optimizer options self.optim_options.set_kwargs() + # Add the loss to the metrics + metrics_with_loss = deepcopy(self.metrics) + for task in self.tasks: + metrics_with_loss[task][f"loss_{loss_names[task]}"] = MetricWrapper( + metric=LossWrapper(self.loss_fun[task]), + target_nan_mask=self.target_nan_mask, + multitask_handling=self.multitask_handling, + ) + # Initialize the epoch summary - monitor = self.optim_options.scheduler_kwargs["monitor"].split("/")[0] - mode = self.optim_options.scheduler_kwargs["mode"] - - self.task_epoch_summary = TaskSummaries( - task_loss_fun=self.loss_fun, - task_metrics=self.metrics, - task_metrics_on_training_set=self.metrics_on_training_set, - task_metrics_on_progress_bar=self.metrics_on_progress_bar, - monitor=monitor, - mode=mode, - ) + self.task_epoch_summary = { + "train": MultiTaskSummary( + task_metrics=metrics_with_loss, + step_name="train", + task_metrics_on_progress_bar=None, + task_metrics_on_training_set=self.metrics_on_training_set, + ), + "val": MultiTaskSummary( + task_metrics=metrics_with_loss, + step_name="val", + task_metrics_on_progress_bar=self.metrics_on_progress_bar, + task_metrics_on_training_set=None, + ), + "test": MultiTaskSummary( + task_metrics=metrics_with_loss, + step_name="test", + task_metrics_on_progress_bar=None, + task_metrics_on_training_set=None, + ), + } # This helps avoid a bug when saving hparams to yaml with different dict or str formats self._set_hparams(recursive_config_reformating(self.hparams)) # throughput estimation - self.mean_val_time_tracker = MovingAverageTracker() - self.mean_val_tput_tracker = MovingAverageTracker() - self.validation_step_outputs = [] - self.test_step_outputs = [] - self.epoch_start_time = None + self.mean_time_tracker = MovingAverageTracker() + self.mean_tput_tracker = MovingAverageTracker() + self.epoch_start_time = {} # Decide whether to log every step or once at the end # of the epoch. self.metrics_every_n_train_steps = metrics_every_n_train_steps # Wether save preds and targets for each training step. - self.samples_seen = 0 - self.global_bs = global_bs + self.model_grad = GradientNormMetric() def forward( self, inputs: Dict @@ -234,6 +247,22 @@ def _get_task_key(self, task_level: str, task: str): if not task.startswith(task_prefix): task = task_prefix + task return task + + def _get_average_loss_from_outputs(self, outputs: Dict[Literal["loss", "task_losses"], Tensor], step_name: Literal["train", "val", "test"]) -> Dict[str, Tensor]: + r""" + Averages the loss over the different tasks + """ + global_loss = torch.as_tensor(outputs["loss"]).detach() + if global_loss.numel() > 1: + global_loss = global_loss[global_loss != 0].mean() + average_losses = {f"_global/loss/{step_name}": global_loss} + for task in self.tasks: + this_losses = torch.as_tensor(outputs["task_losses"][task]).detach() + if this_losses.numel() > 1: + this_losses = this_losses[this_losses != 0].mean() + average_losses[f"{task}/loss/{step_name}"] = this_losses + return average_losses + def configure_optimizers(self, impl=None): if impl is None: @@ -306,7 +335,7 @@ def compute_loss( wrapped_loss_fun_dict = { task: MetricWrapper( - metric=loss, + metric=LossWrapper(loss), threshold_kwargs=None, target_nan_mask=target_nan_mask, multitask_handling=multitask_handling, @@ -316,16 +345,18 @@ def compute_loss( if weights is not None: raise NotImplementedError("Weights are no longer supported in the loss") + all_task_losses = { - task: wrapped(preds=preds[task], target=targets[task]) + task: wrapped.update_compute(preds=preds[task], target=targets[task]) for task, wrapped in wrapped_loss_fun_dict.items() } + total_loss = torch.sum(torch.stack(list(all_task_losses.values())), dim=0) num_tasks = len(all_task_losses.keys()) weighted_loss = total_loss / num_tasks return weighted_loss, all_task_losses - def _general_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: bool) -> Dict[str, Any]: + def _general_step(self, batch: Dict[str, Tensor], step_name: Literal["train", "val", "test"]) -> Dict[str, Any]: r"""Common code for training_step, validation_step and testing_step""" preds = self.forward(batch) # The dictionary of predictions @@ -366,7 +397,6 @@ def _general_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: bool) multitask_handling=self.multitask_handling, ) - device = "cpu" if to_cpu else None for task in preds: task_specific_norm = self.task_norms[task] if self.task_norms is not None else None if hasattr(task_specific_norm, "normalize_val_test"): @@ -379,28 +409,18 @@ def _general_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: bool) # if normalize_val_test is true, no denormalization is applied, all losses and metrics are normalized version preds[task] = task_specific_norm.denormalize(preds[task]) targets_dict[task] = task_specific_norm.denormalize(targets_dict[task]) - preds[task] = preds[task].detach().to(device=device) - targets_dict[task] = targets_dict[task].detach().to(device=device) - if weights is not None: - weights = weights.detach().to(device=device) + preds[task] = preds[task].detach() + targets_dict[task] = targets_dict[task].detach() - step_dict = {"preds": preds, "targets": targets_dict, "weights": weights} - # step_dict[f"{self.loss_fun._get_name()}/{step_name}"] = loss.detach().cpu() original - - # step_dict[f"weighted_loss/{step_name}"] = loss.detach().cpu() - # step_dict[f"loss/{step_name}"] = loss.detach().cpu() - for task in self.tasks: - step_dict[ - self.task_epoch_summary.metric_log_name(task, self.loss_fun[task]._get_name(), step_name) - ] = loss.detach() + self.task_epoch_summary[step_name].update(preds, targets_dict) + step_dict = {} step_dict["loss"] = loss - # print("loss ", self.global_step, self.current_epoch, loss) step_dict["task_losses"] = task_losses - step_dict["gradient_norm"] = self.get_gradient_norm() return step_dict - def flag_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: bool) -> Dict[str, Any]: + + def flag_step(self, batch: Dict[str, Tensor], step_name: Literal["train", "val", "test"]) -> Dict[str, Any]: r""" Perform adversarial data agumentation during one training step using FLAG. Paper: https://arxiv.org/abs/2010.09891 @@ -456,56 +476,51 @@ def flag_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: bool) -> D ) loss = loss / n_steps - device = "cpu" if to_cpu else None for key in preds.keys(): - preds[key] = preds[key].detach().to(device=device) - targets[key] = targets[key].detach().to(device=device) + preds[key] = preds[key].detach() + targets[key] = targets[key].detach() if weights is not None: - weights = weights.detach().to(device=device) + weights = weights.detach() - step_dict = {"preds": preds, "targets": targets, "weights": weights} + step_dict = {} step_dict[f"loss/{step_name}"] = loss.detach().cpu() step_dict["loss"] = loss step_dict["task_losses"] = task_losses + self.task_epoch_summary[step_name].update(preds, targets) return step_dict def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]: - self.train_batch_start_time = time.time() + + self.model_grad.reset() + self.task_epoch_summary["train"].reset() + self.batch_start_time = time.time() self.skip_log_train_metrics = (self.metrics_every_n_train_steps is None) or ( (batch_idx % self.metrics_every_n_train_steps) != 0 ) return super().on_train_batch_start(batch, batch_idx) def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> None: - train_batch_time = time.time() - self.train_batch_start_time # To be used for throughput calculation + train_batch_time = time.time() - self.batch_start_time # To be used for throughput calculation # Get the metrics that are logged at every step (loss, grad_norm, batch_time, batch_tput) - concatenated_metrics_logs = {} - concatenated_metrics_logs["train/loss"] = outputs["loss"] - concatenated_metrics_logs["epoch_count"] = self.current_epoch - # Incriment by the batch size - self.samples_seen += self.global_bs - concatenated_metrics_logs["samples_seen"] = self.samples_seen + metrics_logs = {} # report the training loss for each individual tasks - for task in self.tasks: - concatenated_metrics_logs[f"train/loss/{task}"] = outputs["task_losses"][task] - # get the mean loss value for individual tasks as they are a tensor of size --> gradient accumulation * replication * device_iter # filter zeros out for the individual losses - for key in concatenated_metrics_logs: - if isinstance(concatenated_metrics_logs[key], torch.Tensor): - if concatenated_metrics_logs[key].numel() > 1: - concatenated_metrics_logs[key] = concatenated_metrics_logs[key][ - concatenated_metrics_logs[key] != 0 - ].mean() + losses = self._get_average_loss_from_outputs(outputs, step_name="train") + + metrics_logs.update(losses) + metrics_logs.update(self.task_epoch_summary["train"].compute()) # If logging is skipped for this step, then log the important metrics anyway and return if self.skip_log_train_metrics: - if self.logger is not None: - self.logger.log_metrics( - concatenated_metrics_logs, step=self.global_step - ) # This is a pytorch lightning function call + self.log_dict( + dictionary=metrics_logs, + logger=True, + on_step=True, + prog_bar=True, + ) return ### The code below is not executed if the logging is skipped for this step ### @@ -513,155 +528,130 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> None: # Get the throughput of the batch num_graphs = self.get_num_graphs(batch["features"]) tput = num_graphs / train_batch_time - concatenated_metrics_logs["train/batch_time"] = train_batch_time - concatenated_metrics_logs["train/batch_tput"] = tput - - # Compute all the metrics for the training set - self.task_epoch_summary.update_predictor_state( - step_name="train", - targets=outputs["targets"], - preds=outputs["preds"], - loss=outputs["loss"], # This is the weighted loss for now, but change to task-specific loss - task_losses=outputs["task_losses"], - n_epochs=self.current_epoch, - ) - metrics_logs = self.task_epoch_summary.get_metrics_logs() # Dict[task, metric_logs] - metrics_logs["_global"]["grad_norm"] = self.get_gradient_norm() - concatenated_metrics_logs.update(metrics_logs) + metrics_logs["_global/batch_time/train"] = train_batch_time + metrics_logs["_global/batch_tput/train"] = tput + self.mean_time_tracker.update(train_batch_time) + self.mean_tput_tracker.update(tput) + + metrics_computed = self.task_epoch_summary["train"].compute() + self.task_epoch_summary["train"].reset() + metrics_logs.update(metrics_computed) + metrics_logs["_global/grad_norm/train"] = self.model_grad.compute() + self.model_grad.reset() # Log the metrics - if self.logger is not None: - self.logger.log_metrics( - concatenated_metrics_logs, step=self.global_step - ) # This is a pytorch lightning function call + self.log_dict( + dictionary=metrics_logs, + logger=True, + on_step=True, + prog_bar=True, + ) - def training_step(self, batch: Dict[str, Tensor], to_cpu: bool = True) -> Dict[str, Any]: + def training_step(self, batch: Dict[str, Tensor]) -> Dict[str, Any]: step_dict = None # Train using FLAG if self.flag_kwargs["n_steps"] > 0: - step_dict = self.flag_step(batch=batch, step_name="train", to_cpu=to_cpu) + step_dict = self.flag_step(batch=batch, step_name="train") # Train normally, without using FLAG elif self.flag_kwargs["n_steps"] == 0: - # step_dict = self._general_step(batch=batch, step_name="train", to_cpu=True) - step_dict = self._general_step(batch=batch, step_name="train", to_cpu=to_cpu) + # step_dict = self._general_step(batch=batch, step_name="train") + step_dict = self._general_step(batch=batch, step_name="train") + + # Update the gradients + self.model_grad.update(self.model) - # Remove the preds and targets if no logging is required - if self.skip_log_train_metrics: - step_dict.pop("preds") - step_dict.pop("targets") return step_dict # Returning the metrics_logs with the loss - def get_gradient_norm(self): - # compute the norm - total_norm = torch.tensor(0.0) - for p in self.parameters(): - if p.grad is not None: - param_norm = p.grad.detach().data.norm(2) - total_norm += param_norm.detach().cpu() ** 2 - total_norm = total_norm**0.5 - return total_norm + def validation_step(self, batch: Dict[str, Tensor]) -> Dict[str, Any]: + return self._general_step(batch=batch, step_name="val") + + def test_step(self, batch: Dict[str, Tensor]) -> Dict[str, Any]: + return self._general_step(batch=batch, step_name="test") + + def _general_epoch_start(self, step_name: Literal["train", "val", "test"]) -> None: + self.task_epoch_summary[step_name].reset() + self.epoch_start_time[step_name] = time.time() + self.mean_time_tracker.reset() + self.mean_tput_tracker.reset() + + def predict_step(self, batch: Dict[str, Tensor]) -> Dict[str, Any]: + preds = self.forward(batch) # The dictionary of predictions + targets_dict = batch.get("labels") - def validation_step(self, batch: Dict[str, Tensor], to_cpu: bool = True) -> Dict[str, Any]: - return self._general_step(batch=batch, step_name="val", to_cpu=to_cpu) + return preds, targets_dict - def test_step(self, batch: Dict[str, Tensor], to_cpu: bool = True) -> Dict[str, Any]: - return self._general_step(batch=batch, step_name="test", to_cpu=to_cpu) - def _general_epoch_end(self, outputs: Dict[str, Any], step_name: str, device: str) -> None: + def _general_epoch_end(self, step_name: Literal["train", "val", "test"]) -> Dict[str, Tensor]: r"""Common code for training_epoch_end, validation_epoch_end and testing_epoch_end""" # Transform the list of dict of dict, into a dict of list of dict - preds = {} - targets = {} - for task in self.tasks: - preds[task] = torch.cat([out["preds"][task].to(device) for out in outputs], dim=0) - targets[task] = torch.cat([out["targets"][task].to(device) for out in outputs], dim=0) - if ("weights" in outputs[0].keys()) and (outputs[0]["weights"] is not None): - weights = torch.cat([out["weights"].to(device) for out in outputs], dim=0) - else: - weights = None - - # NOTE: Computing the loss over the entire split may cause - # overflow issues when using fp16 - loss, task_losses = self.compute_loss( - preds=dict_tensor_fp16_to_fp32(preds), - targets=dict_tensor_fp16_to_fp32(targets), - weights=weights, - target_nan_mask=self.target_nan_mask, - multitask_handling=self.multitask_handling, - loss_fun=self.loss_fun, - ) - - self.task_epoch_summary.update_predictor_state( - step_name=step_name, - preds=preds, - targets=targets, - loss=loss, - task_losses=task_losses, - n_epochs=self.current_epoch, - ) - metrics_logs = self.task_epoch_summary.get_metrics_logs() - self.task_epoch_summary.set_results(task_metrics=metrics_logs) - - return metrics_logs # Consider returning concatenated dict for logging + + metric_logs = self.task_epoch_summary[step_name].compute() + self.task_epoch_summary[step_name].reset() + metric_logs_cpu = {k: v for k, v in metric_logs.items() if v.device == torch.device("cpu")} + if len(metric_logs_cpu) > 0: + self.log_dict(metric_logs_cpu, logger=True, prog_bar=True, sync_dist=False, on_epoch=True) + + metric_logs_accelerator = {k: v for k, v in metric_logs.items() if v.device != torch.device("cpu")} + if len(metric_logs_accelerator) > 0: + self.log_dict(metric_logs_accelerator, logger=True, prog_bar=True, sync_dist=True, on_epoch=True) + + # Time metrics are tracked always on CPU, without progress bar, so we log them separatly + time_metrics = {} + time_metrics[f"_global/mean_batch_time/{step_name}"] = torch.tensor(self.mean_time_tracker.mean_value) + time_metrics[f"_global/mean_tput/{step_name}"] = self.mean_tput_tracker.mean_value + time_metrics[f"_global/epoch_time/{step_name}"] = torch.tensor(time.time() - self.epoch_start_time[step_name]) + + self.log_dict(time_metrics, logger=True, prog_bar=False, sync_dist=False, on_epoch=True) + + return metric_logs def on_train_epoch_start(self) -> None: - self.epoch_start_time = time.time() + self._general_epoch_start(step_name="train") def on_train_epoch_end(self) -> None: - if self.epoch_start_time is None: - logger.warning("epoch timer not initialized") - else: - epoch_time = time.time() - self.epoch_start_time - self.epoch_start_time = None - self.log("epoch_time", torch.tensor(epoch_time), sync_dist=True) + self._general_epoch_end(step_name="train") def on_validation_epoch_start(self) -> None: - self.mean_val_time_tracker.reset() - self.mean_val_tput_tracker.reset() + self._general_epoch_start(step_name="val") return super().on_validation_epoch_start() def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: - self.validation_batch_start_time = time.time() + self.batch_start_time = time.time() return super().on_validation_batch_start(batch, batch_idx, dataloader_idx) def on_validation_batch_end( - self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, outputs, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> None: - val_batch_time = time.time() - self.validation_batch_start_time - self.validation_step_outputs.append(outputs) - self.mean_val_time_tracker.update(val_batch_time) + val_batch_time = time.time() - self.batch_start_time + self.mean_time_tracker.update(val_batch_time) num_graphs = self.get_num_graphs(batch["features"]) - self.mean_val_tput_tracker.update(num_graphs / val_batch_time) + self.mean_tput_tracker.update(num_graphs / val_batch_time) return super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx) def on_validation_epoch_end(self) -> None: - metrics_logs = self._general_epoch_end( - outputs=self.validation_step_outputs, step_name="val", device="cpu" - ) - self.validation_step_outputs.clear() - concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs) - concatenated_metrics_logs["val/mean_time"] = torch.tensor(self.mean_val_time_tracker.mean_value) - concatenated_metrics_logs["val/mean_tput"] = self.mean_val_tput_tracker.mean_value - self.log_dict(concatenated_metrics_logs, sync_dist=True) - - # Save yaml file with the per-task metrics summaries - full_dict = {} - full_dict.update(self.task_epoch_summary.get_dict_summary()) + self._general_epoch_end(step_name="val") + return super().on_validation_epoch_end() - def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: - self.test_step_outputs.append(outputs) + def on_test_epoch_start(self) -> None: + self._general_epoch_start(step_name="test") + return super().on_test_epoch_start() def on_test_epoch_end(self) -> None: - metrics_logs = self._general_epoch_end(outputs=self.test_step_outputs, step_name="test", device="cpu") - self.test_step_outputs.clear() - concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs) - - self.log_dict(concatenated_metrics_logs, sync_dist=True) - # Save yaml file with the per-task metrics summaries - full_dict = {} - full_dict.update(self.task_epoch_summary.get_dict_summary()) + self._general_epoch_end(step_name="test") + return super().on_test_epoch_end() + + def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + self.batch_start_time = time.time() + return super().on_test_batch_start(batch, batch_idx, dataloader_idx) + + def on_test_batch_end(self, outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + test_batch_time = time.time() - self.batch_start_time + self.mean_time_tracker.update(test_batch_time) + num_graphs = self.get_num_graphs(batch["features"]) + self.mean_tput_tracker.update(num_graphs / test_batch_time) + return super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx) def on_train_start(self): hparams_log = deepcopy(self.hparams) @@ -669,16 +659,15 @@ def on_train_start(self): if self.logger is not None: self.logger.log_hyperparams(hparams_log) - def get_progress_bar_dict(self) -> Dict[str, float]: - prog_dict = {} - prog_dict["loss"] = self.task_epoch_summary.weighted_loss.detach().cpu() - results_on_progress_bar = self.task_epoch_summary.get_results_on_progress_bar("val") - for task in self.tasks: - prog_dict[self.task_epoch_summary.metric_log_name(task, "loss", "val")] = ( - self.task_epoch_summary.task_summaries[task].summaries["val"].loss - ) - prog_dict.update(results_on_progress_bar) - return prog_dict + @property + def get_metrics_on_progress_bar(self) -> List[str]: + prog_list = ["_global/loss/train"] + for task_name in self.tasks: + for metric in self.metrics_on_progress_bar[task_name]: + this_summary = self.task_epoch_summary["val"][task_name] + prog_list.append(this_summary.metric_log_name(metric)) + + return prog_list def __repr__(self) -> str: r""" @@ -692,10 +681,12 @@ def __repr__(self) -> str: @staticmethod def list_pretrained_models(): """List available pretrained models.""" - return GRAPHIUM_PRETRAINED_MODELS_DICT + from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT + + return GRAPHIUM_PRETRAINED_MODELS_DICT # Avoiding circular imports with `space.py` @staticmethod - def load_pretrained_model(name_or_path: str, device: str = None): + def load_pretrained_model(name_or_path: str, device: str = None, strict: bool = True, **kwargs): """Load a pretrained model from its name. Args: @@ -703,11 +694,13 @@ def load_pretrained_model(name_or_path: str, device: str = None): from `graphium.trainer.PredictorModule.list_pretrained_models()`. """ + from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT # Avoiding circular imports with `space.py` + name = GRAPHIUM_PRETRAINED_MODELS_DICT.get(name_or_path) if name is not None: return PredictorModule.load_from_checkpoint( - GRAPHIUM_PRETRAINED_MODELS_DICT[name_or_path], map_location=device + GRAPHIUM_PRETRAINED_MODELS_DICT[name_or_path], map_location=device, strict=strict, **kwargs ) if name is None and not (fs.exists(name_or_path) and fs.get_extension(name_or_path) == "ckpt"): @@ -716,7 +709,7 @@ def load_pretrained_model(name_or_path: str, device: str = None): "or pass a valid checkpoint (.ckpt) path." ) - return PredictorModule.load_from_checkpoint(name_or_path, map_location=device) + return PredictorModule.load_from_checkpoint(name_or_path, map_location=device, strict=strict, **kwargs) def set_max_nodes_edges_per_graph(self, datamodule: BaseDataModule, stages: Optional[List[str]] = None): datamodule.setup() diff --git a/graphium/trainer/predictor_options.py b/graphium/trainer/predictor_options.py index 04a62e84b..5303976a9 100644 --- a/graphium/trainer/predictor_options.py +++ b/graphium/trainer/predictor_options.py @@ -30,10 +30,6 @@ from torch import nn -from graphium.utils.spaces import LOSS_DICT -from graphium.utils.spaces import SCHEDULER_DICT - - @dataclass class ModelOptions: r""" @@ -117,6 +113,7 @@ def set_kwargs(self): scheduler_class = torch_scheduler_kwargs.pop("module_type") if self.scheduler_class is None: if isinstance(scheduler_class, str): + from graphium.utils.spaces import SCHEDULER_DICT self.scheduler_class = SCHEDULER_DICT[scheduler_class] elif isclass(scheduler_class): self.scheduler_class = scheduler_class @@ -196,12 +193,15 @@ def parse_loss_fun(loss_fun: Union[str, Dict, Callable]) -> Callable: Function or callable to compute the loss, takes `preds` and `targets` as inputs. """ + from graphium.utils.spaces import LOSS_DICT # Avoiding circular imports with `spaces.py` + if isinstance(loss_fun, str): if loss_fun not in LOSS_DICT.keys(): raise ValueError( f"`loss_fun` expected to be one of the strings in {LOSS_DICT.keys()}. " f"Provided: {loss_fun}." ) + loss_name = loss_fun loss_fun = LOSS_DICT[loss_fun]() elif isinstance(loss_fun, dict): if loss_fun.get("name") is None: @@ -214,10 +214,12 @@ def parse_loss_fun(loss_fun: Union[str, Dict, Callable]) -> Callable: loss_fun = deepcopy(loss_fun) loss_name = loss_fun.pop("name") loss_fun = LOSS_DICT[loss_name](**loss_fun) - elif not callable(loss_fun): + elif callable(loss_fun): + loss_name = str(loss_fun) + else: raise ValueError(f"`loss_fun` must be `str`, `dict` or `callable`. Provided: {type(loss_fun)}") - return loss_fun + return loss_name, loss_fun @dataclass diff --git a/graphium/trainer/predictor_summaries.py b/graphium/trainer/predictor_summaries.py index 4cec79377..06e762dde 100644 --- a/graphium/trainer/predictor_summaries.py +++ b/graphium/trainer/predictor_summaries.py @@ -14,14 +14,18 @@ r"""Classes to store information about resulting evaluation metrics when using a Predictor Module.""" -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, Literal, Iterable, Set from loguru import logger +from copy import deepcopy +import inspect import numpy as np import torch from torch import Tensor +from torchmetrics import MeanMetric, Metric +from torchmetrics.aggregation import BaseAggregator -from graphium.utils.tensor import nan_mean, nan_std, nan_median, tensor_fp16_to_fp32 +from graphium.trainer.metrics import MetricToConcatenatedTorchMetrics class SummaryInterface(object): @@ -29,37 +33,30 @@ class SummaryInterface(object): An interface to define the functions implemented by summary classes that implement SummaryInterface. """ - def set_results(self, **kwargs): + def update(self, preds: Tensor, targets: Tensor) -> None: raise NotImplementedError() - def get_dict_summary(self): + def compute(self, **kwargs) -> Tensor: raise NotImplementedError() - - def update_predictor_state(self, **kwargs): - raise NotImplementedError() - - def get_metrics_logs(self, **kwargs): + + def reset(self) -> None: raise NotImplementedError() -class Summary(SummaryInterface): - # TODO (Gabriela): Default argument cannot be [] +class SingleTaskSummary(SummaryInterface): def __init__( self, - loss_fun: Union[str, Callable], - metrics: Dict[str, Callable], - metrics_on_training_set: List[str] = [], - metrics_on_progress_bar: List[str] = [], - monitor: str = "loss", - mode: str = "min", + metrics: Dict[str, Union[Metric, "MetricWrapper"]], + step_name: str, + metrics_on_training_set: Optional[List[str]] = None, + metrics_on_progress_bar: Optional[List[str]] = None, task_name: Optional[str] = None, + compute_mean: bool = True, + compute_std: bool = True, ): r""" A container to be used by the Predictor Module that stores the results for the given metrics on the predictions and targets provided. Parameters: - loss_fun: - Loss function used during training. Acceptable strings are 'mse', 'bce', 'mae', 'cosine'. - Otherwise, a callable object must be provided, with a method `loss_fun._get_name()`. metrics: A dictionnary of metrics to compute on the prediction, other than the loss function. @@ -67,394 +64,293 @@ def __init__( metrics_on_training_set: The metrics names from `metrics` to be computed on the training set for each iteration. - If `None`, all the metrics are computed. Using less metrics can significantly improve - performance, depending on the number of readouts. + If `None`, no metrics are computed. metrics_on_progress_bar: - The metrics names from `metrics` to display also on the progress bar of the training - - monitor: - `str` metric to track (Default=`"loss/val"`) + The metrics names from `metrics` to display also on the progress bar of the training. + If `None`, no metrics are displayed. task_name: name of the task (Default=`None`) + compute_mean: + whether to compute the mean of the predictions and targets + + compute_std: + whether to compute the standard deviation of the predictions and targets + """ - self.loss_fun = loss_fun - self.metrics = metrics - self.metrics_on_training_set = metrics_on_training_set - self.metrics_on_progress_bar = metrics_on_progress_bar - self.monitor = monitor - self.mode = mode + self.step_name = step_name + self.compute_mean = compute_mean + self.compute_std = compute_std - self.summaries = {} - self.best_summaries = {} + if not isinstance(metrics, dict): + raise ValueError(f"metrics must be a dictionary. Got {type(metrics)}") + self.metrics = deepcopy(metrics) # Current predictor state # self.predictor_outputs = None - self.step_name: str = None - self.targets: Tensor = None - self.preds: Tensor = None - self.loss = None # What type? - self.n_epochs: int = None - self.task_name = task_name - self.logged_metrics_exceptions = [] # Track which metric exceptions have been logged + self.logged_metrics_exceptions: List[str] = [] # Track which metric exceptions have been logged + self.last_metrics_exceptions: List[str] = [] # Track which metric exceptions have been logged + + # Add default metrics + if ("mean_preds" not in self.metrics) and compute_mean: + self.metrics["mean_preds"] = MeanMetric(nan_strategy="ignore") + if ("mean_target" not in self.metrics) and compute_mean: + self.metrics["mean_target"] = MeanMetric(nan_strategy="ignore") + if ("std_preds" not in self.metrics) and compute_std: + self.metrics["std_preds"] = STDMetric(nan_strategy="ignore") + if ("std_target" not in self.metrics) and compute_std: + self.metrics["std_target"] = STDMetric(nan_strategy="ignore") + + # Parse the metrics filters + self.metrics_on_training_set = self._parse_metrics_filter(metrics_on_training_set) + self.metrics_on_progress_bar = self._parse_metrics_filter(metrics_on_progress_bar) + + # Update the metrics to compute on the training set + if self.compute_mean: + self.metrics_on_training_set.update(["mean_preds", "mean_target"]) + if self.compute_std: + self.metrics_on_training_set.update(["std_preds", "std_target"]) + + self._cached_metrics: Dict[str, Tensor] = {} + self._logged_warnings: Set[str] = set() # Set to track which metrics have been logged + self._device: torch.device = None + + @property + def get_cached_metrics(self) -> Dict[str, Tensor]: + return deepcopy(self._cached_metrics) + + def _parse_metrics_filter(self, filter: Optional[Union[List[str], Dict[str, Any]]]) -> List[str]: + if filter is None: + filter = [] + elif isinstance(filter, dict): + filter = list(filter.keys()) + elif isinstance(filter, (list, tuple, set)): + filter = list(filter) + elif isinstance(filter, str): + filter = [filter] + else: + raise ValueError(f"metrics_to_use must be a list or a dictionary. Got {type(filter)}") - def update_predictor_state( - self, step_name: str, targets: Tensor, preds: Tensor, loss: Tensor, n_epochs: int - ): + # Ensure that the filter is a subset of the metrics + all_metrics = set(self.metrics.keys()) + filter = set(filter) + if not filter.issubset(all_metrics): + raise ValueError(f"metrics_to_use must be a subset of the metrics. Got {filter - all_metrics}, available {all_metrics}") + + return filter + + @property + def metrics_to_use(self) -> Dict[str, Callable]: r""" - update the state of the predictor - Parameters: - step_name: which stage you are in, e.g. "train" - targets: the targets tensor - predictions: the predictions tensor - loss: the loss tensor - n_epochs: the number of epochs + return the metrics to use by filtering the metrics dictionary if it is the training step. Otherwise, return all metrics. """ - self.step_name = step_name - self.targets = targets - self.preds = preds - self.loss = loss - self.n_epochs = n_epochs - def set_results( - self, - metrics: Dict[str, Tensor], - ): + if self.step_name == "train": + metrics_to_use = { + key: metric for key, metric in self.metrics.items() if key in self.metrics_on_training_set + } + + return metrics_to_use + return self.metrics + + @staticmethod + def _update(metric_key:str, metric_obj, preds: Tensor, targets: Tensor) -> None: r""" - set the reults from the metrics - [!] This function requires that self.update_predictor_state() be called before it. + update the state of the metrics Parameters: - metrics: a dictionary of metrics + targets: the targets tensor + predictions: the predictions tensor """ - # Include the task_name in the loss for logging, and similarly for other metrics - metrics[self.metric_log_name(self.task_name, "loss", self.step_name)] = self.loss - self.summaries[self.step_name] = Summary.Results( - targets=self.targets, - preds=self.preds, - loss=self.loss, - metrics=metrics, # Should include task name from get_metrics_logs() - monitored_metric=f"{self.monitor}/{self.step_name}", # Include task name? - n_epochs=self.n_epochs, - ) - if self.is_best_epoch(self.step_name, self.loss, metrics): - self.best_summaries[self.step_name] = self.summaries[self.step_name] + # Check the `metric_obj.update` signature to know if it takes `preds` and `targets` or only one of them + varnames = [val.name for val in inspect.signature(metric_obj.update).parameters.values()] + if ("preds" == varnames[0]) and ("target" == varnames[1]): + # The typical case of `torchmetrics` + metric_obj.update(preds, targets) + elif ("preds" == varnames[1]) and ("target" == varnames[0]): + # Unusual case where the order of the arguments is reversed + metric_obj.update(targets, preds) + elif ("value" == varnames[0]) and ("preds" in metric_key): + # The case where the metric takes only one value, and it is the prediction + metric_obj.update(preds) + elif ("value" == varnames[0]) and ("target" in metric_key): + # The case where the metric takes only one value, and it is the target + metric_obj.update(targets) + else: + raise ValueError(f"Metric {metric_key} update method signature `{varnames}` is not recognized.") - def is_best_epoch(self, step_name: str, loss: Tensor, metrics: Dict[str, Tensor]) -> bool: + + def update(self, preds: Tensor, targets: Tensor) -> None: r""" - check if the current epoch is the best epoch based on self.mode criteria + update the state of the metrics Parameters: - step_name: which stage you are in, e.g. "train" - loss: the loss tensor - metrics: a dictionary of metrics + targets: the targets tensor + predictions: the predictions tensor """ - # TODO (Gabriela): Check for bugs related to monitor_name - if not (step_name in self.best_summaries.keys()): - return True - - # Include the task_name in the loss for logging, and similarly for other metrics - metrics[self.metric_log_name(self.task_name, "loss", self.step_name)] = loss - monitor_name = f"{self.monitor}/{step_name}" # Include task_name? - if ( - not monitor_name in self.best_summaries.keys() - ): # Feels like there's a bug here. What is this trying to do??? - return True - - if self.mode == "max": - return metrics[monitor_name] > self.best_summaries[step_name].monitored - elif self.mode == "min": - return metrics[monitor_name] < self.best_summaries[step_name].monitored + self._device = preds.device + + for metric_key, metric_obj in self.metrics_to_use.items(): + metric_obj.to(self.device) + try: + self._update(metric_key, metric_obj, preds, targets) + except Exception as err: + err_msg = f"Error for metric {metric_key} on task {self.task_name} and step {self.step_name}. Exception: {err}" + # Check if the error is due to the device mismatch, cast to the device, and retry + + if err_msg not in self._logged_warnings: + logger.warning(err_msg) + self._logged_warnings.add(err_msg) + + + def _compute(self, metrics_to_use: Optional[Union[List[str], Dict[str, Any]]] = None) -> Dict[str, Tensor]: + + # Parse the metrics to use + if metrics_to_use is None: + metrics_to_use = list(self.metrics.keys()) + elif isinstance(metrics_to_use, dict): + metrics_to_use = list(metrics_to_use.keys()) else: - ValueError(f"Mode must be 'min' or 'max', provided `{self.mode}`") + raise ValueError(f"metrics_to_use must be a list or a dictionary. Got {type(metrics_to_use)}") + + self.last_metrics_exceptions = [] # Reset the exceptions for this step + + # Compute the metrics + computed_metrics = {} + for metric_key in metrics_to_use: + metric_name = self.metric_log_name(metric_key) + metric_obj = self.metrics[metric_key] + try: + computed_metrics[f"{metric_name}"] = metric_obj.compute() + except Exception as e: + # If the metric computation fails, return NaN and log a warning only once + computed_metrics[f"{metric_name}"] = torch.tensor(torch.nan, device=self.device) + # Warn only if it's the first warning for that metric + if metric_name not in self.logged_metrics_exceptions: + self.logged_metrics_exceptions.append(metric_name) + logger.warning(f"Error for metric {metric_name}. NaN is returned. Exception: {e}") + self.last_metrics_exceptions.append(metric_name) - def get_results( - self, - step_name: str, - ): + return computed_metrics + + def compute(self) -> Dict[str, Tensor]: r""" - retrieve the results for a given step - Parameters: - step_name: which stage you are in, e.g. "train" + compute the metrics Returns: - the results for the given step + the computed metrics """ - return self.summaries[step_name] + computed_metrics = self._compute(metrics_to_use=self.metrics_to_use) + self._cached_metrics = computed_metrics - def get_best_results( - self, - step_name: str, - ): + return computed_metrics + + def reset(self) -> None: r""" - retrieve the best results for a given step - Parameters: - step_name: which stage you are in, e.g. "train" - Returns: - the best results for the given step + reset the state of the metrics """ - return self.best_summaries[step_name] + for metric_key, metric in self.metrics.items(): + try: + metric.reset() + except AttributeError as e: + metric_name = self.metric_log_name(metric_key) + # Skip error if the message is `AttributeError: 'Tensor' object has no attribute 'clear'. Did you mean: 'char'?` + # This error happens when there's nothing to reset, usually because the metric failed. + if (metric_name not in self.last_metrics_exceptions) or ("'Tensor' object has no attribute 'clear'" not in str(e)): + raise e def get_results_on_progress_bar( self, - step_name: str, ) -> Dict[str, Tensor]: r""" retrieve the results to be displayed on the progress bar for a given step - Parameters: - step_name: which stage you are in, e.g. "train" - Returns: - the results to be displayed on the progress bar for the given step - """ - results = self.summaries[step_name] - results_prog = { - # f"{kk}/{step_name}": results.metrics[f"{kk}/{step_name}"] for kk in self.metrics_on_progress_bar - self.metric_log_name(self.task_name, kk, step_name): results.metrics[ - self.metric_log_name(self.task_name, kk, step_name) - ] - for kk in self.metrics_on_progress_bar - } - return results_prog - def get_dict_summary(self) -> Dict[str, Any]: - r""" - retrieve the full summary in a dictionary Returns: - the full summary in a dictionary - """ - full_dict = {} - # Get metric summaries - full_dict["metric_summaries"] = {} - for key, val in self.summaries.items(): - full_dict["metric_summaries"][key] = {k: v for k, v in val.metrics.items()} - full_dict["metric_summaries"][key]["n_epochs"] = val.n_epochs - - # Get metric summaries at best epoch - full_dict["best_epoch_metric_summaries"] = {} - for key, val in self.best_summaries.items(): - full_dict["best_epoch_metric_summaries"][key] = val.metrics - full_dict["best_epoch_metric_summaries"][key]["n_epochs"] = val.n_epochs - - return full_dict - - def get_metrics_logs(self) -> Dict[str, Any]: - r""" - Get the data about metrics to log. - Note: This function requires that self.update_predictor_state() be called before it. - Returns: - A dictionary of metrics to log. + the results to be displayed on the progress bar for the given step """ + cached_metrics = self.get_cached_metrics + if cached_metrics is None: + results_prog = self._compute(metrics_to_use=self.metrics_on_progress_bar) + else: + results_prog = {} + for metric_key in self.metrics_on_progress_bar: + metric_name = self.metric_log_name(metric_key) + results_prog[metric_name] = cached_metrics[metric_name] - targets = tensor_fp16_to_fp32(self.targets) - preds = tensor_fp16_to_fp32(self.preds) - - targets = targets.to(dtype=preds.dtype, device=preds.device) - - # Compute the metrics always used in regression tasks - metric_logs = {} - metric_logs[self.metric_log_name(self.task_name, "mean_pred", self.step_name)] = nan_mean(preds) - metric_logs[self.metric_log_name(self.task_name, "std_pred", self.step_name)] = nan_std(preds) - metric_logs[self.metric_log_name(self.task_name, "median_pred", self.step_name)] = nan_median(preds) - metric_logs[self.metric_log_name(self.task_name, "mean_target", self.step_name)] = nan_mean(targets) - metric_logs[self.metric_log_name(self.task_name, "std_target", self.step_name)] = nan_std(targets) - metric_logs[self.metric_log_name(self.task_name, "median_target", self.step_name)] = nan_median( - targets - ) - - # Specify which metrics to use - metrics_to_use = self.metrics - if self.step_name == "train": - metrics_to_use = { - key: metric for key, metric in metrics_to_use.items() if key in self.metrics_on_training_set - } - # Compute the additional metrics - for key, metric in metrics_to_use.items(): - metric_name = self.metric_log_name( - self.task_name, key, self.step_name - ) # f"{key}/{self.step_name}" - try: - metric_logs[metric_name] = metric(preds, targets) - except Exception as e: - metric_logs[metric_name] = torch.as_tensor(float("nan")) - # Warn only if it's the first warning for that metric - if metric_name not in self.logged_metrics_exceptions: - self.logged_metrics_exceptions.append(metric_name) - logger.warning(f"Error for metric {metric_name}. NaN is returned. Exception: {e}") - - # Convert all metrics to CPU, except for the loss - # metric_logs[f"{self.loss_fun._get_name()}/{self.step_name}"] = self.loss.detach().cpu() - metric_logs[ - self.metric_log_name(self.task_name, self.loss_fun._get_name(), self.step_name) - ] = self.loss.detach().cpu() - # print("Metrics logs keys: ", metric_logs.keys()) - metric_logs = {key: metric.detach().cpu() for key, metric in metric_logs.items()} - - return metric_logs + return results_prog - def metric_log_name(self, task_name, metric_name, step_name): - if task_name is None: - return f"{metric_name}/{step_name}" + def metric_log_name(self, metric_name): + if self.task_name is None: + return f"{metric_name}/{self.step_name}" else: - return f"{task_name}/{metric_name}/{step_name}" - - class Results: - def __init__( - self, - targets: Tensor = None, - preds: Tensor = None, - loss: float = None, # Is this supposed to be a Tensor or float? - metrics: dict = None, - monitored_metric: str = None, - n_epochs: int = None, - ): - r""" - This inner class is used as a container for storing the results of the summary. - Parameters: - targets: the targets - preds: the prediction tensor - loss: the loss, float or tensor - metrics: the metrics - monitored_metric: the monitored metric - n_epochs: the number of epochs - """ - self.targets = targets.detach().cpu() - self.preds = preds.detach().cpu() - self.loss = loss.item() if isinstance(loss, Tensor) else loss - self.monitored_metric = monitored_metric - if monitored_metric in metrics.keys(): - self.monitored = metrics[monitored_metric].detach().cpu() - self.metrics = { - key: value.tolist() if isinstance(value, (Tensor, np.ndarray)) else value - for key, value in metrics.items() - } - self.n_epochs = n_epochs + return f"{self.task_name}/{metric_name}/{self.step_name}" + + @property + def device(self) -> Optional[torch.device]: + return self._device -class TaskSummaries(SummaryInterface): +class MultiTaskSummary(SummaryInterface): def __init__( self, - task_loss_fun: Callable, - task_metrics: Dict[str, Callable], - task_metrics_on_training_set: List[str], - task_metrics_on_progress_bar: List[str], - monitor: str = "loss", - mode: str = "min", + task_metrics: Dict[str, Dict[str, Union[Metric, "MetricWrapper"]]], + step_name: str, + task_metrics_on_training_set: Optional[Dict[str, List[str]]] = None, + task_metrics_on_progress_bar: Optional[Dict[str, List[str]]] = None, + compute_mean: bool = True, + compute_std: bool = True, ): r""" class to store the summaries of the tasks Parameters: - task_loss_fun: the loss function for each task - task_metrics: the metrics for each task - task_metrics_on_training_set: the metrics to use on the training set - task_metrics_on_progress_bar: the metrics to use on the progress bar - monitor: the metric to monitor - mode: the mode of the metric to monitor + + + compute_mean: + whether to compute the mean of the predictions and targets + + compute_std: + whether to compute the standard deviation of the predictions and targets + """ - self.task_loss_fun = task_loss_fun self.task_metrics = task_metrics - self.task_metrics_on_progress_bar = task_metrics_on_progress_bar - self.task_metrics_on_training_set = task_metrics_on_training_set - self.monitor = monitor - self.mode = mode - - self.task_summaries: Dict[str, Summary] = {} - self.task_best_summaries: Dict[str, Summary] = {} - self.tasks = list(task_loss_fun.keys()) + self.task_metrics_on_progress_bar = task_metrics_on_progress_bar if task_metrics_on_progress_bar is not None else {} + self.task_metrics_on_training_set = task_metrics_on_training_set if task_metrics_on_training_set is not None else {} + # Initialize all the single-task summaries + self.tasks = list(task_metrics.keys()) + self.task_summaries: Dict[str, SingleTaskSummary] = {} for task in self.tasks: - self.task_summaries[task] = Summary( - self.task_loss_fun[task], - self.task_metrics[task], - self.task_metrics_on_training_set[task], - self.task_metrics_on_progress_bar[task], - self.monitor, - self.mode, - task_name=task, + self.task_summaries[task] = SingleTaskSummary( + metrics = self.task_metrics[task], + step_name = step_name, + metrics_on_training_set = self.task_metrics_on_training_set[task] if task in self.task_metrics_on_training_set else None, + metrics_on_progress_bar = self.task_metrics_on_progress_bar[task] if task in self.task_metrics_on_progress_bar else None, + task_name = task, + compute_mean = compute_mean, + compute_std = compute_std, ) - # Current predictor state - self.weighted_loss = None - self.step_name = None + def __getitem__(self, task: str) -> SingleTaskSummary: + return self.task_summaries[task] + + def keys(self) -> List[str]: + return self.tasks - def update_predictor_state( - self, - step_name: str, - targets: Dict[str, Tensor], - preds: Dict[str, Tensor], - loss: Tensor, - task_losses: Dict[str, Tensor], - n_epochs: int, - ): + def update(self, preds: Dict[str, Tensor], targets: Dict[str, Tensor]) -> None: r""" update the state for all predictors Parameters: - step_name: the name of the step targets: the target tensors preds: the prediction tensors - loss: the loss tensor - task_losses: the task losses - n_epochs: the number of epochs """ - self.weighted_loss = loss - self.step_name = step_name for task in self.tasks: - self.task_summaries[task].update_predictor_state( - step_name, - targets[task], + self.task_summaries[task].update( preds[task].detach(), - task_losses[task].detach(), - n_epochs, + targets[task], ) - def set_results(self, task_metrics: Dict[str, Dict[str, Tensor]]): - """ - set the results for all tasks - Parameters: - task_metrics: the metrics for each task - """ - for task in self.tasks: - self.task_summaries[task].set_results(task_metrics[task]) - step_name = self.task_summaries[task].step_name - loss = self.task_summaries[task].loss - if self.task_summaries[task].is_best_epoch(step_name, loss, task_metrics[task]): - self.task_summaries[task].best_summaries[step_name] = self.task_summaries[task].summaries[ - step_name - ] - - def get_results( - self, - step_name: str, - ) -> Dict[str, Dict[str, Any]]: - """ - retrieve the results - Parameters: - step_name: the name of the step, i.e. "train" - Returns: - the results - """ - results = {} - for task in self.tasks: - results[task] = self.task_summaries[task].get_results(step_name) - return results - - def get_best_results( - self, - step_name: str, - ) -> Dict[str, Dict[str, Any]]: - """ - retrieve the best results - Parameters: - step_name: the name of the step, i.e. "train" - Returns: - the best results - """ - results = {} - for task in self.tasks: - results[task] = self.task_summaries[task].get_best_results(step_name) - return results - def get_results_on_progress_bar( self, step_name: str, @@ -469,77 +365,112 @@ def get_results_on_progress_bar( """ task_results_prog = {} for task in self.tasks: - # task_results_prog[task] = self.task_summaries[task].get_results_on_progress_bar(step_name) task_results_prog.update(self.task_summaries[task].get_results_on_progress_bar(step_name)) return task_results_prog - def get_dict_summary( - self, - ) -> Dict[str, Dict[str, Any]]: + def compute(self) -> Dict[str, Tensor]: r""" - get task summaries in a dictionary + compute the metrics for all tasks Returns: - the task summaries + the computed metrics for all tasks """ - task_full_dict = {} + computed_metrics = {} for task in self.tasks: - task_full_dict[task] = self.task_summaries[task].get_dict_summary() - return task_full_dict - - def get_metrics_logs( - self, - ) -> Dict[str, Dict[str, Tensor]]: + computed_metrics.update(self.task_summaries[task].compute()) + return computed_metrics + + def reset(self) -> None: r""" - get the logs for the metrics - Returns: - the task logs for the metrics + reset the state of the metrics """ - task_metrics_logs = {} for task in self.tasks: - task_metrics_logs[task] = self.task_summaries[task].get_metrics_logs() - # average metrics - for key in task_metrics_logs[task]: - if isinstance(task_metrics_logs[task][key], torch.Tensor): - if task_metrics_logs[task][key].numel() > 1: - task_metrics_logs[task][key] = task_metrics_logs[task][key][ - task_metrics_logs[task][key] != 0 - ].mean() - - # Include global (weighted loss) - task_metrics_logs["_global"] = {} - task_metrics_logs["_global"][f"loss/{self.step_name}"] = self.weighted_loss.detach().cpu() - return task_metrics_logs - - # TODO (Gabriela): This works to fix the logging on TB, but make it more efficient - def concatenate_metrics_logs( - self, - metrics_logs: Dict[str, Dict[str, Tensor]], - ) -> Dict[str, Tensor]: - r""" - concatenate the metrics logs - Parameters: - metrics_logs: the metrics logs - Returns: - the concatenated metrics logs - """ - concatenated_metrics_logs = {} - for task in list(self.tasks) + ["_global"]: - concatenated_metrics_logs.update(metrics_logs[task]) - concatenated_metrics_logs[f"loss/{self.step_name}"] = self.weighted_loss.detach().cpu() - return concatenated_metrics_logs + self.task_summaries[task].reset() - def metric_log_name( - self, - task_name: str, - metric_name: str, - step_name: str, - ) -> str: - r""" - print the metric name, task name and step name - Returns: - the metric name, task name and step name - """ - if task_name is None: - return f"{metric_name}/{step_name}" + +class STDMetric(BaseAggregator): + """ + A metric to compute the standard deviation of the predictions or targets. + Based on `torchmetrics.Metric`, with a similar implementation to `torchmetric.MeanMetric`. + + Parameters: + correction: + The correction to apply to the standard deviation. Instead of dividing by number of samples `N`, + we divide by `N-correction`. + + nan_strategy: options: + - ``'error'``: if any `nan` values are encountered will give a RuntimeError + - ``'warn'``: if any `nan` values are encountered will give a warning and continue + - ``'ignore'``: all `nan` values are silently removed + - a float: if a float is provided will impute any `nan` values with this value + + """ + def __init__(self, nan_strategy: Union[Literal["error", "warn", "ignore"], float]="warn", correction:int=0, **kwargs): + super().__init__( + "sum", + default_value=torch.tensor(0.0, dtype=torch.get_default_dtype()), + nan_strategy=nan_strategy, + state_name="mean_value", + **kwargs, + ) + self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("sum_of_squares", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total_weight", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.correction = correction + + def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0) -> None: + if not isinstance(value, Tensor): + value = torch.as_tensor(value, dtype=torch.float32, device=value.device) + if not isinstance(weight, Tensor): + weight = torch.as_tensor(weight, dtype=torch.float32, device=value.device) + + weight = torch.broadcast_to(weight, value.shape).clone() + # Check whether `_cast_and_nan_check_input` takes in `weight` + if "weight" in inspect.signature(self._cast_and_nan_check_input).parameters: + value, weight = self._cast_and_nan_check_input(value, weight) else: - return f"{task_name}/{metric_name}/{step_name}" + weight[value.isnan()] = torch.nan + value = self._cast_and_nan_check_input(value) + weight = self._cast_and_nan_check_input(weight) + + if value.numel() == 0: + return + + self.sum += (value * weight).sum() + self.sum_of_squares += (value * value * weight).sum() + self.total_weight += weight.sum() + + def compute(self) -> Tensor: + dividor = max(0, self.total_weight - self.correction) + mean = self.sum / self.total_weight + mean_of_squares = self.sum_of_squares / self.total_weight + variance = mean_of_squares - mean ** 2 + variance_corr = variance * (self.total_weight / dividor) + return torch.sqrt(variance_corr) + +class GradientNormMetric(Metric): + """ + A metric to compute the norm of the gradient. + Based on `torchmetrics.Metric`. + + Warning: + This metric is not compatible with other metrics since it doesn't take + the predictions and targets as input. It takes the model as input. + It also doesn't work per task, but for the full model + """ + def __init__(self, dist_sync_on_step=False): + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state("gradient_norm_sq", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total_steps", default=torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, model: torch.nn.Module) -> None: + total_norm = torch.tensor(0.0, device=self.device) + for p in model.parameters(): + if p.grad is not None: + param_norm = p.grad.detach().data.norm(2) + total_norm += param_norm.detach() ** 2 + self.gradient_norm_sq += total_norm + self.total_steps += 1 + + def compute(self) -> Tensor: + return (self.gradient_norm_sq / self.total_steps).sqrt() + diff --git a/graphium/trainer/progress_bar.py b/graphium/trainer/progress_bar.py new file mode 100644 index 000000000..a5dde1293 --- /dev/null +++ b/graphium/trainer/progress_bar.py @@ -0,0 +1,27 @@ +import sys +from typing import Any, Callable, Dict, List, Optional, Union, Literal, Iterable +from lightning.pytorch.callbacks import TQDMProgressBar + + + +class ProgressBarMetrics(TQDMProgressBar): + def __init__(self, metrics_on_progress_bar: Optional[Iterable[str]] = None, loss_alias:Optional[str]="_global/loss/train") -> None: + super().__init__() + if metrics_on_progress_bar is None: + metrics_on_progress_bar = {} + self.metrics_on_progress_bar = set(metrics_on_progress_bar) + self.loss_alias = loss_alias + + def get_metrics(self, trainer, pl_module) -> Dict[str, Union[int, str, float, Dict[str, float]]]: + + metrics = super().get_metrics(trainer, pl_module) + filtered_metrics = {} + for key, metric in metrics.items(): + if key in self.metrics_on_progress_bar: + if key == self.loss_alias: + filtered_metrics["loss"] = metric + else: + filtered_metrics[key] = metric + + return filtered_metrics + diff --git a/graphium/utils/packing.py b/graphium/utils/packing.py deleted file mode 100644 index 6db6856b1..000000000 --- a/graphium/utils/packing.py +++ /dev/null @@ -1,330 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -from typing import List, Tuple, Iterable, Optional -import numpy as np -import torch - - -class MolPack: - """ - Class that keeps track of the number of atoms and indices that are added - to each pack. Useful when doing packing, or other forms of smart batching. - A pack is a batch, but with optimized memory consumption. - """ - - def __init__(self): - self.num_nodes = 0 - self.num_graphs = 0 - self.average_atom = 0 - self.indices = [] - - def add_mol(self, num_nodes: int, idx: int) -> "MolPack": - """ - Add a molecule and it's index to the batch - - Parameters: - num_nodes: Number of atoms of the new molecule - - idx: Index associated to the molecule - """ - self.num_nodes += num_nodes - self.num_graphs += 1 - self.average_atom = self.num_nodes / self.num_graphs - self.indices.append(idx) - return self - - def expected_atoms(self, remaining_mean_num_nodes: float, batch_size: int) -> float: - """ - Given a desired batch size, and given the remaining mean number of - atoms, find the expected number of atoms of the current batch when it is full - - Parameters: - remaining_mean_num_nodes: Average number of atoms per molecule - left to be sampled and distributed across tasks. - - batch_size: Desired batch size - - Returns: - expected_atoms: The expected number of atoms in this batch if we - sample randomly the remaining molecules. - """ - return self.num_nodes + ((batch_size - self.num_graphs) * remaining_mean_num_nodes) - - def __repr__(self) -> str: - """ - Print the main attributes of the current class - """ - return f"{self.__class__.__name__}(m: {self.num_graphs},\ta: {self.num_nodes},\tav: {self.average_atom:.1f})" - - -def smart_packing(num_nodes: List[int], batch_size: int) -> List[List[int]]: - """ - Simple and fast algorithm for packing graphs such that each batch has roughly the - same number of atoms. - Has for-loop scalability issues `O(num_graphs * ipu_batch_size)` = `O(num_graphs^2 / batch_size)` - - Parameters: - num_nodes: List of the number of atoms per molecule for the entire global batch. - Must be of length `batch_size * ipu_batch_size`. - - batch_size: The batch size per iteration, considering a single device and single - forward pass. - The global batch size is `batch_size * device_iterations * replication_factor * gradient_accumulation` - - Returns: - packed_indices: A list of packs, each containing a list of indices, such that - if we collect `num_nodes` from the indices, then each pack has roughly the - same total number of atoms. - """ - - # Sort the list - num_nodes = np.asarray(num_nodes) - argsort_num_nodes = np.argsort(num_nodes) - sorted_num_nodes = num_nodes[argsort_num_nodes] - ipu_batch_size = int(len(num_nodes) / batch_size) - sorted_num_nodes, initial_num_nodes = ( - sorted_num_nodes[:-ipu_batch_size], - sorted_num_nodes[-ipu_batch_size:], - ) - reverse_cumsum = np.sum(sorted_num_nodes) - np.cumsum(sorted_num_nodes) + sorted_num_nodes[-1] - - # Start with the largest element in separate packs - mol_batches = [ - MolPack().add_mol(initial_num_nodes[-ii - 1], argsort_num_nodes[-ii - 1]) - for ii in range(ipu_batch_size) - ] - - # Loop from smallest to largest molecule, and add each molecule to the pack with smallest expected sum - for ii, num_atom in enumerate(sorted_num_nodes): - remaining_mean = reverse_cumsum[ii] / (len(sorted_num_nodes) - ii) - max_expected, idx_max_expected = 0, 0 - for jj, m in enumerate(mol_batches): - if m.num_graphs >= batch_size: - continue - expected = m.num_nodes + ( - (batch_size - m.num_graphs) * remaining_mean - ) # Faster than calling m.expected_atoms - if expected > max_expected: - max_expected = expected - idx_max_expected = jj - mol_batches[idx_max_expected].add_mol(num_atom, argsort_num_nodes[ii]) - - packed_indices = [batch.indices for batch in mol_batches] - - return packed_indices - - -def fast_packing(num_nodes: List[int], batch_size: int) -> List[List[int]]: - """ - Super fast algorithm for packing graphs such that each batch has roughly the - same number of atoms. Not as good as `smart_packing` but - faster and more scalable for-loop complexity of `O(batch_size)`. - - Parameters: - num_nodes: List of the number of atoms per molecule for the entire global batch. - Must be of length `batch_size * ipu_batch_size`. - - batch_size: The batch size per iteration, considering a single device and single - forward pass. - The global batch size is `batch_size * device_iterations * replication_factor * gradient_accumulation` - - Returns: - packed_indices: A list of packs, each containing a list of indices, such that - if we collect `num_nodes` from the indices, then each pack has roughly the - same total number of atoms. - """ - num_nodes = np.asarray(num_nodes) - argsort_num_nodes = np.argsort(num_nodes) - ipu_batch_size = int(len(num_nodes) / batch_size) - - packed_indices = np.stack( - [ - np.random.permutation(argsort_num_nodes[ii * ipu_batch_size : (ii + 1) * ipu_batch_size]) - for ii in range(batch_size) - ], - axis=0, - ).T.tolist() - return packed_indices - - -def hybrid_packing(num_nodes: List[int], batch_size: int) -> List[List[int]]: - """ - Uses a combination of the `smart_packing` `O(n^2)` on the most important data points, - and the `fast_packing` `O(n)` on the average-sized data points. - - Depending on the expected complexity - - Parameters: - num_nodes: List of the number of atoms per molecule for the entire global batch. - Must be of length `batch_size * ipu_batch_size`. - - batch_size: The batch size per iteration, considering a single device and single - forward pass. - The global batch size is `batch_size * device_iterations * replication_factor * gradient_accumulation` - - Returns: - packed_indices: A list of packs, each containing a list of indices, such that - if we collect `num_nodes` from the indices, then each pack has roughly the - same total number of atoms. - """ - - # Determine the parameters based on the complexity of the smart-packing. - # The bigger the complexity, the more the `fast_packing` algorithm becomes - # statistically powerful, and the more speed benefits it provides. - smart_packing_complexity = len(num_nodes) ** 2 / batch_size - if smart_packing_complexity < 1e4: - return smart_packing(num_nodes=num_nodes, batch_size=batch_size) - elif smart_packing_complexity < 1e5: - big, small = 3, 6 - else: - return fast_packing(num_nodes=num_nodes, batch_size=batch_size) - - # Small datasets benefit from smart-packing, without compute burden - ipu_batch_size = int(len(num_nodes) / batch_size) - if len(num_nodes) < (big + small) * ipu_batch_size: - return smart_packing(num_nodes=num_nodes, batch_size=batch_size) - - # Sort the list - num_nodes = np.asarray(num_nodes) - argsort_num_nodes = np.argsort(num_nodes) - - # Smallest and biggest graphs are often outliers and will benefit from the `smart_packing` - biggest_graphs = argsort_num_nodes[-big * ipu_batch_size :] - smallest_graphs = argsort_num_nodes[: small * ipu_batch_size] - big_n_small_graphs = np.concatenate([biggest_graphs, smallest_graphs]) - big_n_small_packs = smart_packing(num_nodes[big_n_small_graphs], batch_size=big + small) - big_n_small_indices = [big_n_small_graphs[pack] for pack in big_n_small_packs] - big_n_small_nodes = [num_nodes[pack] for pack in big_n_small_indices] - - # Medium graphs will be packed faster - medium_graphs = argsort_num_nodes[small * ipu_batch_size : -big * ipu_batch_size] - medium_packs = fast_packing(num_nodes[medium_graphs], batch_size=batch_size - big - small) - medium_indices = [medium_graphs[pack] for pack in medium_packs] - medium_nodes = [num_nodes[pack] for pack in medium_indices] - - # Pack the big/small with the medium in a smart way - big_n_small_sort = np.argsort(np.sum(np.stack(big_n_small_nodes, axis=1), axis=0)) - medium_sort = np.argsort(np.sum(np.stack(medium_nodes, axis=1), axis=0)) - packed_indices = [ - np.concatenate([medium_indices[medium_sort[ii]], big_n_small_indices[big_n_small_sort[-ii]]]) - for ii in range(len(medium_indices)) - ] - - return packed_indices - - -def get_pack_sizes(packed_indices, num_nodes): - """ - Get the number of atoms of each pack - """ - pack_sums = [] - for pack in packed_indices: - pack_sum = 0 - for idx in pack: - pack_sum += num_nodes[idx] - pack_sums.append(pack_sum) - return pack_sums - - -def estimate_max_pack_node_size(num_nodes: Iterable[int], batch_size: int, combined_batch_size: int): - """ - Estimate the value of `max_num_nodes`, which represents the maximum number of nodes - needed in a batch to fit the data. - - Parameters: - num_nodes: Number of nodes for all the graphs in the dataset - batch_size: The regular batch size per IPU - combined_batch_size: batch_size * device_iterations - * replication_factor * gradient_accumulation - - """ - - # Estimate the packing size needed - rand_indices = np.arange(len(num_nodes)) - np.random.shuffle(rand_indices) - max_pack_size = 0 - for ii in range(0, len(num_nodes), combined_batch_size): - this_indices = rand_indices[ii : ii + combined_batch_size] - choice = num_nodes[this_indices] - if len(choice) == combined_batch_size: - packed_indices = hybrid_packing(choice, batch_size) - max_pack_size = max(max_pack_size, max(get_pack_sizes(packed_indices, num_nodes[this_indices]))) - max_pack_size_per_graph = max_pack_size / batch_size - - return max_pack_size, max_pack_size_per_graph - - -def node_to_pack_indices_mask( - packed_indices: Iterable[Iterable[int]], all_num_nodes: Iterable[int], max_pack_size: Optional[int] = None -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Given a list of packed indices, and the number of nodes in each graph, - return a tensor of shape (sum(all_num_nodes), 2) where the first column - is the pack index, and the second column is the node index within the pack. - - Can be used to generate a dense packing of the nodes as follows: - ``` - # node_features: A tensor of shape (num_nodes, num_node_features) - # num_packs: The number of packs desired - # max_nodes_per_pack: The maximum number of nodes per pack - # dense_pack: A tensor of shape (num_packs, max_nodes_per_pack, num_node_features) - - dense_pack = torch.zeros([num_packs, max_nodes_per_pack, num_node_features]) - dense_pack[pack_from_node_idx[:, 0], pack_from_node_idx[:, 1]] = node_features - ``` - - This is useful when using a Transformer, to avoid wasteful padding when the - the longest sequence is much longer than the average sequence length. - - Parameters: - packed_indices: A list of lists of graph indices, where each sub-list - represents a pack of graphs - all_num_nodes: The number of nodes in each graph - max_pack_size: The maximum number of nodes per pack. If None, will be - infered from the provided packs. - Useful to determine the shape of the `pack_attn_mask`. - - Returns: - pack_from_node_idx: A tensor of shape (num_nodes, 2) where the first column - is the pack index, and the second column is the node index within the pack. - - pack_attn_mask: A tensor of shape (num_packs, max_pack_size, max_pack_size), - that represents the attention masking for each pack, - such that the graphs in the pack are masked out from each other. - """ - - all_num_nodes = torch.as_tensor(all_num_nodes, dtype=torch.long) - cumsum_num_nodes = torch.cumsum(all_num_nodes, dim=0) - if max_pack_size is None: - pack_sizes = get_pack_sizes(packed_indices, all_num_nodes) - max_pack_size = max(pack_sizes) - - # Get the node indices associated to the packs, with 0 padding - pack_from_node_idx = torch.zeros(sum(all_num_nodes), 2, dtype=torch.long) - pack_attn_mask = [] # masks for the attention - for ii, pack in enumerate(packed_indices): - jj = 0 # Counter for the number of nodes in the pack - this_pack_attn_mask = torch.ones((max_pack_size, max_pack_size), dtype=torch.bool) - for graph_idx in pack: - num_nodes = all_num_nodes[graph_idx] - node_idx = torch.arange(cumsum_num_nodes[graph_idx] - num_nodes, cumsum_num_nodes[graph_idx]) - this_pack_attn_mask[jj : jj + num_nodes, jj : jj + num_nodes] = False - pack_from_node_idx[node_idx, 0] = ii - pack_from_node_idx[node_idx, 1] = jj + torch.arange(num_nodes) - jj += num_nodes - pack_attn_mask.append(this_pack_attn_mask) - pack_attn_mask = torch.stack(pack_attn_mask, dim=0) - - return pack_from_node_idx, pack_attn_mask diff --git a/graphium/utils/read_file.py b/graphium/utils/read_file.py deleted file mode 100644 index 27d2fb216..000000000 --- a/graphium/utils/read_file.py +++ /dev/null @@ -1,173 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -""" Utiles for data parsing""" -import os -import warnings -import numpy as np -import pandas as pd -import datamol as dm -from functools import partial -from copy import copy -import fsspec - -from loguru import logger -from rdkit import Chem -from rdkit.Chem.Descriptors import ExactMolWt - -from graphium.utils.tensor import parse_valid_args, arg_in_func - - -def read_file(filepath, as_ext=None, **kwargs): - r""" - Allow to read different file format and parse them into a MolecularDataFrame. - Supported formats are: - * csv (.csv, .smile, .smiles, .tsv) - * txt (.txt) - * xls (.xls, .xlsx, .xlsm, .xls*) - * sdf (.sdf) - * pkl (.pkl) - - Arguments - ----------- - - filepath: str - The full path and name of the file to read. - It also supports the s3 url path. - as_ext: str, Optional - The file extension used to read the file. If None, the extension is deduced - from the extension of the file. Otherwise, no matter the file extension, - the file will be read according to the specified ``as_ext``. - (Default=None) - **kwargs: All the optional parameters required for the desired file reader. - - TODO: unit test to make sure it works well with all extensions - - Returns - --------- - df: pandas.DataFrame - The ``pandas.DataFrame`` containing the parsed data - - """ - - # Get the file extension - if as_ext is None: - file_ext = os.path.splitext(filepath)[-1].lower()[1:] - else: - file_ext = as_ext - if not isinstance(file_ext, str): - raise TypeError("`file_type` must be a `str`. Provided: {}".format(file_ext)) - - open_mode = "r" - - # Read the file according to the right extension - if file_ext in ["csv", "smile", "smiles", "smi", "tsv"]: - file_reader = pd.read_csv - elif file_ext == "txt": - file_reader = pd.read_table - elif file_ext[0:3] == "xls": - open_mode = "rb" - file_reader = partial(pd.read_excel, engine="openpyxl") - elif file_ext == "sdf": - file_reader = parse_sdf_to_dataframe - elif file_ext == "pkl": - open_mode = "rb" - file_reader = pd.read_pickle - else: - raise 'File extension "{}" not supported'.format(file_ext) - - kwargs = parse_valid_args(fn=file_reader, param_dict=kwargs) - - if file_ext[0:3] not in ["sdf", "xls"]: - with file_opener(filepath, open_mode) as file_in: - data = file_reader(file_in, **kwargs) - else: - data = file_reader(filepath, **kwargs) - return data - - -def parse_sdf_to_dataframe(sdf_path, as_cxsmiles=True, skiprows=None): - r""" - Allows to read an SDF file containing molecular informations, convert - it to a pandas DataFrame and convert the molecules to SMILES. It also - lists a warning of all the molecules that couldn't be read. - - Arguments - ----------- - - sdf_path: str - The full path and name of the sdf file to read - as_cxsmiles: bool, optional - Whether to use the CXSMILES notation, which preserves atomic coordinates, - stereocenters, and much more. - See `https://dl.chemaxon.com/marvin-archive/latest/help/formats/cxsmiles-doc.html` - (Default = True) - skiprows: int, list - The rows to skip from dataset. The enumerate index starts from 1 insted of 0. - (Default = None) - - """ - - # read the SDF file - # locally or from s3 - data = dm.read_sdf(sdf_path) - - # For each molecule in the SDF file, read all the properties and add it to a list of dict. - # Also count the number of molecules that cannot be read. - data_list = [] - count_none = 0 - if skiprows is not None: - if isinstance(skiprows, int): - skiprows = range(0, skiprows - 1) - skiprows = np.array(skiprows) - 1 - - for idx, mol in enumerate(data): - if (skiprows is not None) and (idx in skiprows): - continue - - if (mol is not None) and (ExactMolWt(mol) > 0): - mol_dict = mol.GetPropsAsDict() - data_list.append(mol_dict) - if as_cxsmiles: - smiles = Chem.rdmolfiles.MolToCXSmiles(mol, canonical=True) - else: - smiles = dm.to_smiles(mol, canonical=True) - data_list[-1]["SMILES"] = smiles - else: - count_none += 1 - logger.info(f"Could not read molecule # {idx}") - - # Display a message or warning after the SDF is done parsing - if count_none == 0: - logger.info("Successfully read the SDF file without error: {}".format(sdf_path)) - else: - warnings.warn( - ( - 'Error reading {} molecules from the "{}" file.\ - {} molecules read successfully.' - ).format(count_none, sdf_path, len(data_list)) - ) - return pd.DataFrame(data_list) - - -def file_opener(filename, mode="r"): - """File reader stream""" - filename = str(filename) - if "w" in mode: - filename = "simplecache::" + filename - if filename.endswith(".gz"): - instream = fsspec.open(filename, mode=mode, compression="gzip") - else: - instream = fsspec.open(filename, mode=mode) - return instream diff --git a/graphium/utils/spaces.py b/graphium/utils/spaces.py index 88812c0be..5f42f843c 100644 --- a/graphium/utils/spaces.py +++ b/graphium/utils/spaces.py @@ -15,15 +15,13 @@ from copy import deepcopy import torch import torch.optim.lr_scheduler as sc -import torchmetrics.functional as TorchMetrics +import torchmetrics as TorchMetrics import graphium.nn.base_layers as BaseLayers import graphium.nn.ensemble_layers as EnsembleLayers import graphium.nn.architectures as Architectures import graphium.utils.custom_lr as CustomLR import graphium.data.datamodule as Datamodules -import graphium.ipu.ipu_losses as IPULosses -import graphium.ipu.ipu_metrics as Metrics import graphium.nn.pyg_layers as PygLayers import graphium.nn.residual_connections as Residuals import graphium.nn.encoders as Encoders @@ -79,12 +77,6 @@ "l1": torch.nn.L1Loss, "mae": torch.nn.L1Loss, "hybrid_ce": Losses.HybridCELoss, - "bce_ipu": IPULosses.BCELossIPU, - "bce_logits_ipu": IPULosses.BCEWithLogitsLossIPU, - "mse_ipu": IPULosses.MSELossIPU, - "mae_ipu": IPULosses.L1LossIPU, - "l1_ipu": IPULosses.L1LossIPU, - "hybrid_ce_ipu": IPULosses.HybridCELossIPU, } @@ -102,39 +94,27 @@ } METRICS_CLASSIFICATION = { - "accuracy": TorchMetrics.accuracy, - "averageprecision": TorchMetrics.average_precision, - "auroc": TorchMetrics.auroc, - "confusionmatrix": TorchMetrics.confusion_matrix, - "f1": TorchMetrics.f1_score, - "fbeta": TorchMetrics.fbeta_score, - "precisionrecallcurve": TorchMetrics.precision_recall_curve, - "precision": TorchMetrics.precision, - "recall": TorchMetrics.recall, - "mcc": TorchMetrics.matthews_corrcoef, - "auroc_ipu": Metrics.auroc_ipu, - "accuracy_ipu": Metrics.accuracy_ipu, - "average_precision_ipu": Metrics.average_precision_ipu, - "f1_ipu": Metrics.f1_score_ipu, - "fbeta_ipu": Metrics.fbeta_score_ipu, - "precision_ipu": Metrics.precision_ipu, - "recall_ipu": Metrics.recall_ipu, + "accuracy": TorchMetrics.Accuracy, + "averageprecision": TorchMetrics.functional.average_precision, # Not using a class to better handle concatenation of preds and targets + "auroc": TorchMetrics.functional.auroc, # Not using a class to better handle concatenation of preds and targets + "confusionmatrix": TorchMetrics.ConfusionMatrix, + "f1": TorchMetrics.F1Score, + "fbeta": TorchMetrics.FBetaScore, + "precisionrecallcurve": TorchMetrics.PrecisionRecallCurve, + "precision": TorchMetrics.Precision, + "recall": TorchMetrics.Recall, + "mcc": TorchMetrics.MatthewsCorrCoef, } METRICS_REGRESSION = { - "mae": TorchMetrics.mean_absolute_error, - "mape": TorchMetrics.mean_absolute_percentage_error, - "mse": TorchMetrics.mean_squared_error, - "msle": TorchMetrics.mean_squared_log_error, - "pearsonr": TorchMetrics.pearson_corrcoef, - "spearmanr": TorchMetrics.spearman_corrcoef, - "r2_score": TorchMetrics.r2_score, - "cosine": TorchMetrics.cosine_similarity, - "pearsonr_ipu": Metrics.pearson_ipu, - "spearmanr_ipu": Metrics.spearman_ipu, - "r2_score_ipu": Metrics.r2_score_ipu, - "mae_ipu": Metrics.mean_absolute_error_ipu, - "mse_ipu": Metrics.mean_squared_error_ipu, + "mae": TorchMetrics.MeanAbsoluteError, + "mape": TorchMetrics.MeanAbsolutePercentageError, + "mse": TorchMetrics.MeanSquaredError, + "msle": TorchMetrics.MeanSquaredLogError, + "pearsonr": TorchMetrics.PearsonCorrCoef, + "spearmanr": TorchMetrics.SpearmanCorrCoef, + "r2_score": TorchMetrics.R2Score, + "cosine": TorchMetrics.CosineSimilarity, } METRICS_DICT = deepcopy(METRICS_CLASSIFICATION) @@ -144,8 +124,7 @@ DATAMODULE_DICT = { "GraphOGBDataModule": Datamodules.GraphOGBDataModule, "MultitaskFromSmilesDataModule": Datamodules.MultitaskFromSmilesDataModule, - "ADMETBenchmarkDataModule": Datamodules.ADMETBenchmarkDataModule, - "FakeDataModule": Datamodules.FakeDataModule, + "TDCBenchmarkDataModule": Datamodules.TDCBenchmarkDataModule, } GRAPHIUM_PRETRAINED_MODELS_DICT = { diff --git a/install_ipu.sh b/install_ipu.sh deleted file mode 100755 index a21022bdb..000000000 --- a/install_ipu.sh +++ /dev/null @@ -1,112 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Graphcore Limited. -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Graphcore Limited is not liable -for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -#!/bin/bash - -# Default location for the virtual environment -default_venv_name=".graphium_ipu" - -# Allow the user to specify the location of their virtual environment -# If not specified, use the default location -venv_name=${1:-$default_venv_name} - -# Constants -sdk_compressed_file="poplar_sdk-ubuntu_20_04-3.3.0-208993bbb7.tar.gz" -sdk_wheel_file="poptorch-3.3.0+113432_960e9c294b_ubuntu_20_04-cp38-cp38-linux_x86_64.whl" -sdk_url="https://downloads.graphcore.ai/direct?package=poplar-poplar_sdk_ubuntu_20_04_3.3.0_208993bbb7-3.3.0&file=${sdk_compressed_file}" -sdk_path="${venv_name}/poplar_sdk-ubuntu_20_04-3.3.0+1403-208993bbb7" - -# Check for Python3 and pip -if ! command -v python3 &>/dev/null; then - echo "Python3 is required but it's not installed. Exiting." - exit 1 -fi - -if ! command -v pip3 &>/dev/null; then - echo "pip3 is required but it's not installed. Exiting." - exit 1 -fi - -# Remove existing venv directory if it exists -if [[ -d $venv_name ]]; then - echo "Removing existing virtual environment directory..." - rm -rf $venv_name -fi - -# Create the virtual environment -echo "Creating virtual environment..." -mkdir -p $venv_name -python3 -m venv $venv_name -source $venv_name/bin/activate - -# Update pip to the latest version -echo "Upgrading pip..." -python3 -m pip install --upgrade pip - -# Download the Poplar SDK -echo "Downloading Poplar SDK..." -wget -q -O "${venv_name}/${sdk_compressed_file}" "$sdk_url" - -# Check the wget exit status -if [ $? -ne 0 ]; then - echo "Failed to download Poplar SDK. Exiting." - exit 1 -fi - -# Unzip the SDK file -echo "Extracting Poplar SDK..." -tar -xzf "$venv_name/$sdk_compressed_file" -C $venv_name - -# Install the PopTorch wheel -echo "Installing PopTorch..." -python3 -m pip install "${sdk_path}/${sdk_wheel_file}" - -# Enable Poplar SDK (including Poplar and PopART) -echo "Enabling Poplar SDK..." -source ${sdk_path}/enable - -# Install the IPU specific and Graphium requirements -echo "Installing IPU specific and Graphium requirements..." -python3 -m pip install -r requirements_ipu.txt - -# Install Graphium in dev mode -echo "Installing Graphium in dev mode..." -python3 -m pip install --no-deps -e . - -# This is a quick test make sure poptorch is correctly installed -if python3 -c "import poptorch;print('poptorch installed correctly')" &> /dev/null; then - echo "Installation completed successfully." -else - echo "Installation was not successful. Please check the logs and try again." - exit 1 # Exit with status code 1 to indicate failure -fi - -# Download the datafiles (Total ~ 10Mb - nothing compared to the libraries) -echo "Downloading the sub-datasets consisting on the ToyMix dataset" -toymix_dir=expts/data/neurips2023/small-dataset/ -mkdir -p $toymix_dir - -base_url="https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/" -files=("ZINC12k.csv.gz" "Tox21-7k-12-labels.csv.gz" "qm9.csv.gz" "qm9_random_splits.pt" "Tox21_random_splits.pt" "ZINC12k_random_splits.pt") - -for file in "${files[@]}"; do - if [ ! -f "${toymix_dir}${file}" ]; then - echo "Downloading ${file}..." - wget -P "${toymix_dir}" "${base_url}${file}" - else - echo "${file} already exists. Skipping..." - fi -done - -echo "Data has been successfully downloaded." \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 0d9f34bfd..0fb199fa6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -23,7 +23,6 @@ nav: - graphium.data: api/graphium.data.md - graphium.utils: api/graphium.utils.md - graphium.config: api/graphium.config.md - - graphium.ipu: api/graphium.ipu.md - graphium.finetuning: api/graphium.finetuning.md - Tutorials: - feature_processing: @@ -36,7 +35,6 @@ nav: - Using GNN layers: tutorials/gnn/using_gnn_layers.ipynb - model_training: - Simple Molecular Model: tutorials/model_training/simple-molecular-model.ipynb - - Training on IPU: tutorials/model_training/running-multitask-ipu.ipynb - Design: design.md - Datasets: datasets.md - Pretrained Models: pretrained_models.md diff --git a/notebooks/finetuning-on-tdc-admet-benchmark.ipynb b/notebooks/finetuning-on-tdc-admet-benchmark.ipynb index 43eb47081..69ec83bd5 100644 --- a/notebooks/finetuning-on-tdc-admet-benchmark.ipynb +++ b/notebooks/finetuning-on-tdc-admet-benchmark.ipynb @@ -47,7 +47,7 @@ "\n", "[TDC](https://tdcommons.ai/) hosts a variety of ML-ready datasets and benchmarks for ML for drug discovery. The [TDC ADMET benchmarking group](https://tdcommons.ai/benchmark/admet_group/overview/) is a popular collection of benchmarks for evaluating new _foundation models_ (see e.g. [MolE](https://arxiv.org/abs/2211.02657)) due to the variety and relevance of the included tasks.\n", "\n", - "The ADMET benchmarking group is integrated in `graphium` through the `ADMETBenchmarkDataModule` data-module. This notebook shows how to easily fine-tune and test a model using that data-module. \n", + "The ADMET benchmarking group is integrated in `graphium` through the `TDCBenchmarkDataModule` data-module. This notebook shows how to easily fine-tune and test a model using that data-module. \n", "\n", "
\n", " NOTE: This notebook is still work in progress. While the fine-tuning logic is unfinished, the notebook does demo how one could use the data-module to easily loop over each of the datasets in the benchmarking group and get the prescribed train-test split. Once the fine-tuning logic is finalized, we will finish this notebook and officially provide it as a tutorial within Graphium. \n", @@ -59,7 +59,20 @@ "execution_count": 3, "id": "4d5af838", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: '../expts/configs/config_tdc_admet_demo.yaml'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# First, let's read the yaml configuration file\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m../expts/configs/config_tdc_admet_demo.yaml\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mr\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m file:\n\u001b[1;32m 3\u001b[0m config \u001b[38;5;241m=\u001b[39m yaml\u001b[38;5;241m.\u001b[39mload(file, Loader\u001b[38;5;241m=\u001b[39myaml\u001b[38;5;241m.\u001b[39mFullLoader)\n", + "File \u001b[0;32m~/miniconda3/envs/graphium3/lib/python3.12/site-packages/IPython/core/interactiveshell.py:324\u001b[0m, in \u001b[0;36m_modified_open\u001b[0;34m(file, *args, **kwargs)\u001b[0m\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m {\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m}:\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 319\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIPython won\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt let you open fd=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m by default \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124myou can use builtins\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m open.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 322\u001b[0m )\n\u001b[0;32m--> 324\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mio_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../expts/configs/config_tdc_admet_demo.yaml'" + ] + } + ], "source": [ "# First, let's read the yaml configuration file\n", "with open(\"../expts/configs/config_tdc_admet_demo.yaml\", \"r\") as file:\n", @@ -125,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "id": "9538abfb", "metadata": {}, "outputs": [], @@ -173,7 +186,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "1ee4586e", "metadata": {}, "outputs": [], @@ -184,8 +197,8 @@ " have settings related to a subset of the endpoints\n", " \"\"\"\n", " \n", - " if config[\"datamodule\"][\"module_type\"] != \"ADMETBenchmarkDataModule\":\n", - " raise ValueError(\"You can only use this method for the `ADMETBenchmarkDataModule`\")\n", + " if config[\"datamodule\"][\"module_type\"] != \"TDCBenchmarkDataModule\":\n", + " raise ValueError(\"You can only use this method for the `TDCBenchmarkDataModule`\")\n", " \n", " if isinstance(names, str):\n", " names = [names]\n", @@ -896,7 +909,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/profiling/configs_profiling.yaml b/profiling/configs_profiling.yaml index 0ff4f6c94..bde4bdb5f 100644 --- a/profiling/configs_profiling.yaml +++ b/profiling/configs_profiling.yaml @@ -11,8 +11,6 @@ datamodule: smiles_col: SMILES # Featurization - featurization_n_jobs: -1 - featurization_progress: True featurization: atom_property_list_onehot: [atomic-number, valence] atom_property_list_float: [mass, electronegativity] diff --git a/profiling/profile_mol_to_graph.py b/profiling/profile_mol_to_graph.py index 423f487cf..e8bf19315 100644 --- a/profiling/profile_mol_to_graph.py +++ b/profiling/profile_mol_to_graph.py @@ -16,7 +16,7 @@ import pickle from graphium.data.utils import load_micro_zinc -from graphium.features.featurizer import mol_to_pyggraph, mol_to_adj_and_features, mol_to_graph_dict +from graphium.features.featurizer import mol_to_pyggraph # Check out this profiling tool: https://kirillstrelkov.medium.com/python-profiling-with-vscode-3a17c0407833 @@ -67,10 +67,7 @@ def main(): graphs = [] for s in tqdm(smiles): - mol = dm.to_mol( - s - ) # Doesn't need `ordered=True` because this is just to test the speed of the featurizer - graphs.append(mol_to_graph_dict(mol, **featurizer)) + graphs.append(mol_to_pyggraph(s, **featurizer)) print(graphs[0]) diff --git a/pyproject.toml b/pyproject.toml index 78a5869da..364d8fd1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,7 +104,6 @@ filterwarnings = [ "ignore::DeprecationWarning:pkg_resources.*:", ] markers = [ - "ipu: marks tests that are specific to the IPU (deselect with '-m \"not ipu\"')", ] [tool.coverage.run] diff --git a/scripts/ipu_start.sh b/scripts/ipu_start.sh deleted file mode 100644 index b50ffcbd9..000000000 --- a/scripts/ipu_start.sh +++ /dev/null @@ -1,25 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -""" -Start the ipu environment and SDK -""" - -source /opt/gc/sdk-3.0.0+1128/poplar-ubuntu_20_04-3.0.0+5468-0379b9a65d/enable.sh -source /opt/gc/sdk-3.0.0+1128/popart-ubuntu_20_04-3.0.0+5468-0379b9a65d/enable.sh - -source ~/.venv/graphium_ipu/bin/activate # Change to your path - -export VISUAL=vim -export EDITOR="$VISUAL" diff --git a/scripts/ipu_venv.sh b/scripts/ipu_venv.sh deleted file mode 100644 index 826fcfa12..000000000 --- a/scripts/ipu_venv.sh +++ /dev/null @@ -1,30 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -""" -Create the pip environment for IPU -""" - -## Uncomment this to create the folder for the environment -# mkdir ~/.venv # Create the folder for the environment -# python3 -m venv ~/.venv/graphium_ipu # Create the environment -# source ~/.venv/graphium_ipu/bin/activate # Activate the environment - -# Installing the dependencies for the IPU environment -pip install torch==1.10+cpu torchvision==0.11+cpu torchaudio==0.10 -f https://download.pytorch.org/whl/torch_stable.html -pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.10.0+cpu.html -pip install dgl dglgo -f https://data.dgl.ai/wheels/repo.html -pip install /opt/gc/sdk-3.0.0+1128/poptorch-3.0.0+84519_672c9cbc7f_ubuntu_20_04-cp38-cp38-linux_x86_64.whl -pip install -r requirements.txt -pip install -e . diff --git a/tests/config_test_ipu_dataloader.yaml b/tests/config_test_dataloader.yaml similarity index 93% rename from tests/config_test_ipu_dataloader.yaml rename to tests/config_test_dataloader.yaml index f0f55d197..c1ef946d1 100644 --- a/tests/config_test_ipu_dataloader.yaml +++ b/tests/config_test_dataloader.yaml @@ -1,22 +1,14 @@ -# Testing the multitask pipeline with the QM9 dataset on IPU, by splitting it up into three tasks: homo, alpha and cv. +# Testing the multitask pipeline with the QM9 dataset, by splitting it up into three tasks: homo, alpha and cv. constants: - name: &name test_ipu #qm9_full + name: &name test_dataloader seed: &seed 42 raise_train_error: true # Whether the code should raise an error if it crashes during training accelerator: - type: ipu # cpu or ipu or gpu + type: cpu # cpu or gpu config_override: datamodule: args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 20 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 60 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 16 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 120 # Data handling-related batch_size_training: 6 batch_size_inference: 6 @@ -25,9 +17,6 @@ accelerator: precision: 16 accumulate_grad_batches: 4 - ipu_config: - - deviceIterations(2) - datamodule: module_type: "MultitaskFromSmilesDataModule" args: # Matches that in the test_multitask_datamodule.py case. @@ -269,8 +258,8 @@ predictor: homo: ["mae"] alpha: ["mae"] loss_fun: - homo: mse_ipu - alpha: mse_ipu + homo: mse + alpha: mse random_seed: *seed optim_kwargs: lr: 1.e-3 @@ -300,4 +289,4 @@ trainer: every_n_epochs: 1 trainer: max_epochs: 2 - min_epochs: 1 + min_epochs: 1 \ No newline at end of file diff --git a/tests/config_test_ipu_dataloader_multitask.yaml b/tests/config_test_ipu_dataloader_multitask.yaml deleted file mode 100644 index 8b8fbf417..000000000 --- a/tests/config_test_ipu_dataloader_multitask.yaml +++ /dev/null @@ -1,342 +0,0 @@ -# Testing the gcn model with the PCQMv2 dataset on IPU. -constants: - name: &name neurips2023_small_data_gcn - seed: &seed 42 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -accelerator: - type: ipu # cpu or ipu or gpu - config_override: - datamodule: - args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 44 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 80 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 44 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 80 - # Data handling-related - batch_size_training: 50 - batch_size_inference: 50 - predictor: - optim_kwargs: - loss_scaling: 1024 - trainer: - trainer: - precision: 16 - accumulate_grad_batches: 4 - - ipu_config: - - deviceIterations(5) # IPU would require large batches to be ready for the model. - - replicationFactor(1) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - - useIpuModel(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - -datamodule: - module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data - args: # Matches that in the test_multitask_datamodule.py case. - task_specific_args: # To be replaced by a new class "DatasetParams" - qm9: - df: null - df_path: qm9.csv.gz - # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9.csv.gz - # or set path as the URL directly - smiles_col: "smiles" - label_cols: ["A", "B", "C", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298", "h298", "g298", "cv", "u0_atom", "u298_atom", "h298_atom", "g298_atom"] - sample_size: 2000 # use sample_size for test - seed: *seed - task_level: graph - label_normalization: - normalize_val_test: True - method: "normal" - - tox21: - df: null - df_path: Tox21-7k-12-labels.csv.gz - # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz - # or set path as the URL directly - smiles_col: "smiles" - label_cols: ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"] - sample_size: 2000 # use sample_size for test - seed: *seed - task_level: graph - - zinc: - df: null - df_path: ZINC12k.csv.gz - # df_path: data/neurips2023/small-dataset/ZINC12k.csv.gz - # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz - # or set path as the URL directly - smiles_col: "smiles" - label_cols: ["SA", "logp", "score"] - sample_size: 2000 # use sample_size for test - seed: *seed - task_level: graph - label_normalization: - normalize_val_test: True - method: "normal" - - # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" - # processed_graph_data_path: "../datacache/neurips2023-small/" - featurization: - # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), - # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', - # 'num_chiral_centers (not included yet)'] - atom_property_list_onehot: [atomic-number, group, period, total-valence] - atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring] - # OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring'] - edge_property_list: [bond-type-onehot, stereo, in-ring] - add_self_loop: False - explicit_H: False # if H is included - use_bonds_weights: False - pos_encoding_as_features: # encoder dropout 0.18 - pos_types: - lap_eigvec: - pos_level: node - pos_type: laplacian_eigvec - num_pos: 8 - normalization: "none" # nomrlization already applied on the eigen vectors - disconnected_comp: True # if eigen values/vector for disconnected graph are included - lap_eigval: - pos_level: node - pos_type: laplacian_eigval - num_pos: 8 - normalization: "none" # nomrlization already applied on the eigen vectors - disconnected_comp: True # if eigen values/vector for disconnected graph are included - rw_pos: # use same name as pe_encoder - pos_level: node - pos_type: rw_return_probs - ksteps: 16 - - num_workers: -1 # -1 to use all - persistent_workers: False # if use persistent worker at the start of each epoch. - # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" - - -architecture: - model_type: FullGraphMultiTaskNetwork - mup_base_path: null - pre_nn: # Set as null to avoid a pre-nn network - out_dim: 16 - hidden_dims: 16 - depth: 1 - activation: relu - last_activation: none - dropout: &dropout 0.1 - normalization: &normalization layer_norm - last_normalization: *normalization - residual_type: none - - pre_nn_edges: null # Set as null to avoid a pre-nn network - - pe_encoders: - out_dim: 32 - pool: "sum" #"mean" "max" - last_norm: None #"batch_norm", "layer_norm" - encoders: #la_pos | rw_pos - la_pos: # Set as null to avoid a pre-nn network - encoder_type: "laplacian_pe" - input_keys: ["laplacian_eigvec", "laplacian_eigval"] - output_keys: ["feat"] - hidden_dim: 64 - out_dim: 32 - model_type: 'DeepSet' #'Transformer' or 'DeepSet' - num_layers: 2 - num_layers_post: 1 # Num. layers to apply after pooling - dropout: 0.1 - first_normalization: "none" #"batch_norm" or "layer_norm" - rw_pos: - encoder_type: "mlp" - input_keys: ["rw_return_probs"] - output_keys: ["feat"] - hidden_dim: 64 - out_dim: 32 - num_layers: 2 - dropout: 0.1 - normalization: "layer_norm" #"batch_norm" or "layer_norm" - first_normalization: "layer_norm" #"batch_norm" or "layer_norm" - - - - gnn: # Set as null to avoid a post-nn network - in_dim: 16 # or otherwise the correct value - out_dim: &gnn_dim 16 - hidden_dims: *gnn_dim - depth: 1 - activation: gelu - last_activation: none - dropout: 0.1 - normalization: "layer_norm" - last_normalization: *normalization - residual_type: simple - virtual_node: 'none' - layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps - layer_kwargs: null # Parameters for the model itself. You could define dropout_attn: 0.1 - - - graph_output_nn: - graph: - pooling: [sum] - out_dim: *gnn_dim - hidden_dims: *gnn_dim - depth: 1 - activation: relu - last_activation: none - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - - task_heads: - qm9: - task_level: graph - out_dim: 19 - hidden_dims: 16 - depth: 1 - activation: relu - last_activation: none - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - tox21: - task_level: graph - out_dim: 12 - hidden_dims: 16 - depth: 1 - activation: relu - last_activation: sigmoid - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - zinc: - task_level: graph - out_dim: 3 - hidden_dims: 16 - depth: 2 - activation: relu - last_activation: none - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - -#Task-specific -predictor: - metrics_on_progress_bar: - qm9: ["mae"] - tox21: ["auroc"] - zinc: ["mae"] - loss_fun: - qm9: mae_ipu - tox21: bce_ipu - zinc: mae_ipu - random_seed: *seed - optim_kwargs: - lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs - # weight_decay: 1.e-7 - torch_scheduler_kwargs: - module_type: WarmUpLinearLR - max_num_epochs: &max_epochs 1 - warmup_epochs: 1 - verbose: False - scheduler_kwargs: - # monitor: &monitor qm9/mae/train - # mode: min - # frequency: 1 - target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label - multitask_handling: flatten # flatten, mean-per-label - -# Task-specific -metrics: - qm9: &qm9_metrics - - name: mae - metric: mae_ipu - target_nan_mask: null - multitask_handling: flatten - threshold_kwargs: null - - name: pearsonr - metric: pearsonr_ipu - threshold_kwargs: null - target_nan_mask: null - multitask_handling: mean-per-label - - name: r2_score - metric: r2_score_ipu - target_nan_mask: null - multitask_handling: mean-per-label - threshold_kwargs: null - tox21: - - name: auroc - metric: auroc_ipu - task: binary - multitask_handling: mean-per-label - threshold_kwargs: null - - name: avpr - metric: average_precision_ipu - task: binary - multitask_handling: mean-per-label - threshold_kwargs: null - - name: f1 > 0.5 - metric: f1 - multitask_handling: mean-per-label - target_to_int: True - num_classes: 2 - average: micro - threshold_kwargs: &threshold_05 - operator: greater - threshold: 0.5 - th_on_preds: True - th_on_target: True - - name: precision > 0.5 - metric: precision - multitask_handling: mean-per-label - average: micro - threshold_kwargs: *threshold_05 - zinc: *qm9_metrics - -trainer: - seed: *seed - logger: - save_dir: logs/neurips2023-small/ - name: *name - project: *name - #early_stopping: - # monitor: *monitor - # min_delta: 0 - # patience: 10 - # mode: &mode min - model_checkpoint: - dirpath: models_checkpoints/neurips2023-small-gcn/ - filename: *name - # monitor: *monitor - # mode: *mode - # save_top_k: 1 - save_last: True - trainer: - max_epochs: *max_epochs - min_epochs: 1 - check_val_every_n_epoch: 20 diff --git a/tests/data/config_micro_ZINC.yaml b/tests/data/config_micro_ZINC.yaml index 88fc4a841..d2e94318f 100644 --- a/tests/data/config_micro_ZINC.yaml +++ b/tests/data/config_micro_ZINC.yaml @@ -11,8 +11,6 @@ datamodule: smiles_col: SMILES # Featurization - featurization_n_jobs: -1 - featurization_progress: True featurization: atom_property_list_onehot: [atomic-number, valence] atom_property_list_float: [mass, electronegativity, in-ring] diff --git a/tests/data/dummy_node_label_order_data.parquet b/tests/data/dummy_node_label_order_data.parquet new file mode 100644 index 000000000..a9a165d82 Binary files /dev/null and b/tests/data/dummy_node_label_order_data.parquet differ diff --git a/tests/dummy-pretrained-model.ckpt b/tests/dummy-pretrained-model.ckpt index b1312cffa..41e4df2b4 100644 Binary files a/tests/dummy-pretrained-model.ckpt and b/tests/dummy-pretrained-model.ckpt differ diff --git a/tests/test_attention.py b/tests/test_attention.py index 28b9cd2a1..bedea4933 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -21,8 +21,8 @@ import torch import unittest as ut from torch_geometric.data import Data, Batch +from torch_geometric.utils import to_dense_batch from copy import deepcopy -from graphium.ipu.to_dense_batch import to_dense_batch from graphium.nn.base_layers import MultiheadAttentionMup @@ -65,12 +65,11 @@ def test_attention_class(self): attention_layer_bias = MultiheadAttentionMup(biased_attention=True, **self.attn_kwargs) attention_layer_bias.eval() - h_dense, mask, _ = to_dense_batch( + h_dense, mask = to_dense_batch( bg.feat, batch=bg.batch, batch_size=None, - max_num_nodes_per_graph=None, - drop_nodes_last_graph=False, + max_num_nodes=None, ) # attn_bias [batch, num_heads, nodes, nodes] nodes = h_dense.size()[1] diff --git a/tests/test_base_layers.py b/tests/test_base_layers.py index 2093619f2..6e153d480 100644 --- a/tests/test_base_layers.py +++ b/tests/test_base_layers.py @@ -19,10 +19,10 @@ import torch import unittest as ut from torch_geometric.data import Data, Batch +from torch_geometric.utils import to_dense_batch, dense_to_sparse from copy import deepcopy from graphium.nn.base_layers import DropPath, TransformerEncoderLayerMup -from graphium.ipu.to_dense_batch import to_dense_batch, to_sparse_batch class test_Base_Layers(ut.TestCase): @@ -78,18 +78,19 @@ def test_transformer_encoder_layer_mup(self): biased_attention=False, d_model=self.in_dim, nhead=1, dim_feedforward=4 * self.in_dim ) - feat_dense, key_padding_mask, idx = to_dense_batch( + feat_dense, key_padding_mask = to_dense_batch( feat_in, batch=bg.batch, batch_size=self.batch_size, - max_num_nodes_per_graph=self.max_num_nodes_per_graph, - drop_nodes_last_graph=False, + max_num_nodes=self.max_num_nodes_per_graph, ) - attn_mask = None - key_padding_mask = ~key_padding_mask + key_padding_mask = ~key_padding_mask h_out_dense = layer.forward(feat_dense) - - h_out = to_sparse_batch(h_out_dense, mask_idx=idx) + h_out = h_out_dense[~key_padding_mask] self.assertEqual(h_out.shape, feat_in.shape) + + +if __name__ == "__main__": + ut.main() diff --git a/tests/test_collate.py b/tests/test_collate.py index 3cb453b32..6524596d6 100644 --- a/tests/test_collate.py +++ b/tests/test_collate.py @@ -28,12 +28,12 @@ class test_Collate(ut.TestCase): def test_collate_labels(self): # Create fake labels - labels_size_dict = { - "graph_label1": [1], - "graph_label2": [3], - "node_label2": [5], - "edge_label3": [5, 2], - "node_label4": [5, 1], + labels_num_cols_dict = { + "graph_label1": 1, + "graph_label2": 3, + "node_label2": 1, + "edge_label3": 2, + "node_label4": 1, } labels_dtype_dict = { "graph_label1": torch.float32, @@ -57,9 +57,16 @@ def test_collate_labels(self): pyg_labels[key] = val + 17 * 2 fake_labels.append(pyg_labels) + num_nodes = [g.num_nodes for g in fake_labels] + num_edges = [g.num_edges for g in fake_labels] + # Collate labels and check for the right shapes and dtypes collated_labels = collate_labels( - deepcopy(fake_labels), deepcopy(labels_size_dict), deepcopy(labels_dtype_dict) + deepcopy(fake_labels), + deepcopy(labels_num_cols_dict), + deepcopy(labels_dtype_dict), + num_nodes, + num_edges, ) self.assertEqual(collated_labels["graph_label1"].shape, torch.Size([num_labels, 1])) # , 1 self.assertEqual(collated_labels["graph_label2"].shape, torch.Size([num_labels, 3])) # , 1 @@ -108,15 +115,19 @@ def test_collate_labels(self): label4_true[missing_labels["node_label4"]] = float("nan") # Collate labels and check for the right shapes - labels_size_dict = { - "graph_label1": [1], - "graph_label2": [3], - "node_label2": [5], - "edge_label3": [5, 2], - "node_label4": [5, 1], + labels_num_cols_dict = { + "graph_label1": 1, + "graph_label2": 3, + "node_label2": 1, + "edge_label3": 2, + "node_label4": 1, } collated_labels = collate_labels( - deepcopy(fake_labels), deepcopy(labels_size_dict), deepcopy(labels_dtype_dict) + deepcopy(fake_labels), + deepcopy(labels_num_cols_dict), + deepcopy(labels_dtype_dict), + num_nodes, + num_edges, ) self.assertEqual(collated_labels["graph_label1"].shape, torch.Size([num_labels, 1])) # , 1 self.assertEqual(collated_labels["graph_label2"].shape, torch.Size([num_labels, 3])) # , 1 @@ -138,9 +149,14 @@ def test_collate_labels(self): collated_labels["node_label4"].numpy(), label4_true.flatten(0, 1).numpy() ) # Now test the `graphium_collate_fn` function when only labels are given - fake_labels2 = [{"labels": this_label} for this_label in fake_labels] + fake_labels2 = [ + {"labels": this_label, "num_nodes": this_label.num_nodes, "num_edges": this_label.num_edges} + for this_label in fake_labels + ] collated_labels = graphium_collate_fn( - deepcopy(fake_labels2), labels_size_dict=labels_size_dict, labels_dtype_dict=labels_dtype_dict + deepcopy(fake_labels2), + labels_num_cols_dict=labels_num_cols_dict, + labels_dtype_dict=labels_dtype_dict, )["labels"] self.assertEqual(collated_labels["graph_label1"].shape, torch.Size([num_labels, 1])) self.assertEqual(collated_labels["graph_label2"].shape, torch.Size([num_labels, 3])) diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 6b73110de..961cef5a7 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -18,7 +18,7 @@ import tempfile -class TestDataUtils(ut.TestCase): +class test_DataUtils(ut.TestCase): def test_list_datasets( self, ): diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index 824b80d50..a2d0c162c 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -22,10 +22,12 @@ from graphium.utils.fs import rm, exists, get_size from graphium.data import GraphOGBDataModule, MultitaskFromSmilesDataModule +import graphium_cpp + TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000" -class Test_DataModule(ut.TestCase): +class test_DataModule(ut.TestCase): def test_ogb_datamodule(self): # other datasets are too large to be tested dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"] @@ -45,23 +47,22 @@ def test_ogb_datamodule(self): task_specific_args = {} task_specific_args["task_1"] = {"task_level": "graph", "dataset_name": dataset_name} dm_args = {} - dm_args["processed_graph_data_path"] = None dm_args["featurization"] = featurization_args dm_args["batch_size_training"] = 16 dm_args["batch_size_inference"] = 16 dm_args["num_workers"] = 0 dm_args["pin_memory"] = True - dm_args["featurization_n_jobs"] = 0 - dm_args["featurization_progress"] = True - dm_args["featurization_backend"] = "loky" - dm_args["featurization_batch_size"] = 50 - ds = GraphOGBDataModule(task_specific_args, **dm_args) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) - ds.prepare_data(save_smiles_and_ids=False) + ds = GraphOGBDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, **dm_args) + + ds.prepare_data() # Check the keys in the dataset - ds.setup(save_smiles_and_ids=False) + ds.setup() assert set(ds.train_ds[0].keys()) == {"features", "labels"} # Delete the cache if already exist @@ -69,13 +70,13 @@ def test_ogb_datamodule(self): rm(TEMP_CACHE_DATA_PATH, recursive=True) # Reset the datamodule - ds = GraphOGBDataModule(task_specific_args, **dm_args) + ds = GraphOGBDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, **dm_args) - ds.prepare_data(save_smiles_and_ids=True) + ds.prepare_data() # Check the keys in the dataset - ds.setup(save_smiles_and_ids=True) - assert set(ds.train_ds[0].keys()) == {"smiles", "mol_ids", "features", "labels"} + ds.setup() + assert set(ds.train_ds[0].keys()) == {"features", "labels"} # test module assert ds.num_edge_feats == 5 @@ -84,100 +85,7 @@ def test_ogb_datamodule(self): # test batch loader batch = next(iter(ds.train_dataloader())) - assert len(batch["smiles"]) == 16 assert len(batch["labels"]["graph_task_1"]) == 16 - assert len(batch["mol_ids"]) == 16 - - def test_none_filtering(self): - # Create the objects to filter - list_of_num = [ii for ii in range(100)] - list_of_str = [str(ii) for ii in list_of_num] - tuple_of_num = tuple(list_of_num) - array_of_num = np.asarray(list_of_num) - array_of_str = np.asarray(list_of_str) - tensor_of_num = torch.as_tensor(array_of_num) - arrays_of_num = np.stack([list_of_num, list_of_num, list_of_num], axis=1) - arrays_of_str = np.stack([list_of_str, list_of_str, list_of_str], axis=1) - tensors_of_num = torch.as_tensor(arrays_of_num) - dic = {"str": list_of_str, "num": list_of_num} - df = pd.DataFrame(dic) - df_shuffled = df.sample(frac=1) - series_num = df["num"] - series_num_shuffled = df_shuffled["num"] - - # Create different indexes to use for filtering - all_idx_none = [[3, 17, 88], [22, 33, 44, 55, 66, 77, 88], [], np.arange(len(list_of_num))] - - # Loop all the indexes and filter the objects. - for ii, idx_none in enumerate(all_idx_none): - msg = f"Failed for ii={ii}" - - # Create the true filtered sequences - filtered_num = [ii for ii in range(100) if ii not in idx_none] - filtered_str = [str(ii) for ii in filtered_num] - assert len(filtered_num) == len(list_of_num) - len(idx_none) - assert len(filtered_str) == len(list_of_str) - len(idx_none) - - # Filter the sequences from the Datamodule function - ( - list_of_num_2, - list_of_str_2, - tuple_of_num_2, - array_of_num_2, - array_of_str_2, - tensor_of_num_2, - df_2, - df_shuffled_2, - dic_2, - arrays_of_num_2, - arrays_of_str_2, - tensors_of_num_2, - series_num_2, - series_num_shuffled_2, - ) = graphium.data.MultitaskFromSmilesDataModule._filter_none_molecules( - idx_none, - list_of_num, - list_of_str, - tuple_of_num, - array_of_num, - array_of_str, - tensor_of_num, - df, - df_shuffled, - dic, - arrays_of_num, - arrays_of_str, - tensors_of_num, - series_num, - series_num_shuffled, - ) - - df_shuffled_2 = df_shuffled_2.sort_values(by="num", axis=0) - series_num_shuffled_2 = series_num_shuffled_2.sort_values(axis=0) - - # Assert the filtering is done correctly - self.assertListEqual(list_of_num_2, filtered_num, msg=msg) - self.assertListEqual(list_of_str_2, filtered_str, msg=msg) - self.assertListEqual(list(tuple_of_num_2), filtered_num, msg=msg) - self.assertListEqual(array_of_num_2.tolist(), filtered_num, msg=msg) - self.assertListEqual(array_of_str_2.tolist(), filtered_str, msg=msg) - self.assertListEqual(tensor_of_num_2.tolist(), filtered_num, msg=msg) - for jj in range(arrays_of_num.shape[1]): - self.assertListEqual(arrays_of_num_2[:, jj].tolist(), filtered_num, msg=msg) - self.assertListEqual(arrays_of_str_2[:, jj].tolist(), filtered_str, msg=msg) - self.assertListEqual(tensors_of_num_2[:, jj].tolist(), filtered_num, msg=msg) - self.assertListEqual(dic_2["num"], filtered_num, msg=msg) - self.assertListEqual(dic_2["str"], filtered_str, msg=msg) - self.assertListEqual(df_2["num"].tolist(), filtered_num, msg=msg) - self.assertListEqual(df_2["str"].tolist(), filtered_str, msg=msg) - self.assertListEqual(series_num_2.tolist(), filtered_num, msg=msg) - - # When the dataframe is shuffled, the lists are different because the filtering - # is done on the row indexes, not the dataframe indexes. - bool_to_check = (len(idx_none) == 0) or (len(idx_none) == len(df_shuffled)) - self.assertIs(df_shuffled_2["num"].tolist() == filtered_num, bool_to_check, msg=msg) - self.assertIs(df_shuffled_2["str"].tolist() == filtered_str, bool_to_check, msg=msg) - self.assertIs(series_num_shuffled_2.tolist() == filtered_num, bool_to_check, msg=msg) def test_caching(self): # other datasets are too large to be tested @@ -201,10 +109,6 @@ def test_caching(self): dm_args["batch_size_inference"] = 16 dm_args["num_workers"] = 0 dm_args["pin_memory"] = True - dm_args["featurization_n_jobs"] = 0 - dm_args["featurization_progress"] = True - dm_args["featurization_backend"] = "loky" - dm_args["featurization_batch_size"] = 50 # Delete the cache if already exist if exists(TEMP_CACHE_DATA_PATH): @@ -214,10 +118,10 @@ def test_caching(self): assert not exists(TEMP_CACHE_DATA_PATH) ds = GraphOGBDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, **dm_args) # assert not ds.load_data_from_cache(verbose=False) - ds.prepare_data(save_smiles_and_ids=False) + ds.prepare_data() # Check the keys in the dataset - ds.setup(save_smiles_and_ids=False) + ds.setup() assert set(ds.train_ds[0].keys()) == {"features", "labels"} # ds_batch = next(iter(ds.train_dataloader())) @@ -227,23 +131,9 @@ def test_caching(self): # Test loading cached data assert exists(TEMP_CACHE_DATA_PATH) - cached_ds_from_ram = GraphOGBDataModule( - task_specific_args, - processed_graph_data_path=TEMP_CACHE_DATA_PATH, - dataloading_from="ram", - **dm_args, - ) - cached_ds_from_ram.prepare_data() - cached_ds_from_ram.setup() - cached_train_loader_from_ram = cached_ds_from_ram.get_dataloader( - cached_ds_from_ram.train_ds, shuffle=False, stage="train" - ) - batch_from_ram = next(iter(cached_train_loader_from_ram)) - cached_ds_from_disk = GraphOGBDataModule( task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, - dataloading_from="disk", **dm_args, ) cached_ds_from_disk.prepare_data() @@ -254,59 +144,31 @@ def test_caching(self): batch_from_disk = next(iter(cached_train_loader_from_disk)) # Features are the same - np.testing.assert_array_almost_equal( - batch["features"].edge_index, batch_from_ram["features"].edge_index - ) np.testing.assert_array_almost_equal( batch["features"].edge_index, batch_from_disk["features"].edge_index ) - assert batch["features"].num_nodes == batch_from_ram["features"].num_nodes assert batch["features"].num_nodes == batch_from_disk["features"].num_nodes - np.testing.assert_array_almost_equal( - batch["features"].edge_weight, batch_from_ram["features"].edge_weight - ) np.testing.assert_array_almost_equal( batch["features"].edge_weight, batch_from_disk["features"].edge_weight ) - np.testing.assert_array_almost_equal(batch["features"].feat, batch_from_ram["features"].feat) np.testing.assert_array_almost_equal(batch["features"].feat, batch_from_disk["features"].feat) - np.testing.assert_array_almost_equal( - batch["features"].edge_feat, batch_from_ram["features"].edge_feat - ) np.testing.assert_array_almost_equal( batch["features"].edge_feat, batch_from_disk["features"].edge_feat ) - np.testing.assert_array_almost_equal(batch["features"].batch, batch_from_ram["features"].batch) np.testing.assert_array_almost_equal(batch["features"].batch, batch_from_disk["features"].batch) - np.testing.assert_array_almost_equal(batch["features"].ptr, batch_from_ram["features"].ptr) np.testing.assert_array_almost_equal(batch["features"].ptr, batch_from_disk["features"].ptr) # Labels are the same - np.testing.assert_array_almost_equal( - batch["labels"].graph_task_1, batch_from_ram["labels"].graph_task_1 - ) np.testing.assert_array_almost_equal( batch["labels"].graph_task_1, batch_from_disk["labels"].graph_task_1 ) - np.testing.assert_array_almost_equal(batch["labels"].x, batch_from_ram["labels"].x) - np.testing.assert_array_almost_equal(batch["labels"].x, batch_from_disk["labels"].x) - - np.testing.assert_array_almost_equal(batch["labels"].edge_index, batch_from_ram["labels"].edge_index) - np.testing.assert_array_almost_equal(batch["labels"].edge_index, batch_from_disk["labels"].edge_index) - - np.testing.assert_array_almost_equal(batch["labels"].batch, batch_from_ram["labels"].batch) - np.testing.assert_array_almost_equal(batch["labels"].batch, batch_from_disk["labels"].batch) - - np.testing.assert_array_almost_equal(batch["labels"].ptr, batch_from_ram["labels"].ptr) - np.testing.assert_array_almost_equal(batch["labels"].ptr, batch_from_disk["labels"].ptr) - # Delete the cache if already exist if exists(TEMP_CACHE_DATA_PATH): rm(TEMP_CACHE_DATA_PATH, recursive=True) @@ -314,10 +176,10 @@ def test_caching(self): # Reset the datamodule ds = GraphOGBDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, **dm_args) - ds.prepare_data(save_smiles_and_ids=True) + ds.prepare_data() - ds.setup(save_smiles_and_ids=True) - assert set(ds.train_ds[0].keys()) == {"smiles", "mol_ids", "features", "labels"} + ds.setup() + assert set(ds.train_ds[0].keys()) == {"features", "labels"} # test module assert ds.num_edge_feats == 5 @@ -326,9 +188,7 @@ def test_caching(self): # test batch loader batch = next(iter(ds.train_dataloader())) - assert len(batch["smiles"]) == 16 assert len(batch["labels"]["graph_task_1"]) == 16 - assert len(batch["mol_ids"]) == 16 # Delete the cache if already exist if exists(TEMP_CACHE_DATA_PATH): @@ -369,15 +229,18 @@ def test_datamodule_with_none_molecules(self): bad_smiles = (df["SMILES1"] == "XXX") & (df["SMILES2"] == "XXX") & (df["SMILES3"] == "XXX") num_bad_smiles = sum(bad_smiles) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + # Test the datamodule datamodule = MultitaskFromSmilesDataModule( task_specific_args=task_specific_args, + processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization_args=featurization_args, - featurization_n_jobs=0, - featurization_batch_size=1, ) datamodule.prepare_data() - datamodule.setup(save_smiles_and_ids=True) + datamodule.setup() # Check that the number of molecules is correct smiles = df["SMILES1"].tolist() + df["SMILES2"].tolist() + df["SMILES3"].tolist() @@ -400,33 +263,36 @@ def test_datamodule_with_none_molecules(self): df = df.set_index("idx_smiles") # Convert the smilies from the train_ds to a list, and check the content - train_smiles = [d["smiles"] for d in datamodule.train_ds] + train_smiles = [ + graphium_cpp.extract_string( + datamodule.train_ds.smiles_tensor, datamodule.train_ds.smiles_offsets_tensor, idx + ) + for idx in range(len(datamodule.train_ds)) + ] # Check that the set of smiles are the same - train_smiles_flat = list(set([item for sublist in train_smiles for item in sublist])) + train_smiles_flat = list(set(train_smiles)) train_smiles_flat.sort() index_smiles_filt = list(set([smiles for smiles in index_smiles if smiles != "XXX"])) index_smiles_filt.sort() self.assertListEqual(train_smiles_flat, index_smiles_filt) - # Check that the smiles are correct for each datapoint in the dataset + # Check that the smiles is correct for each datapoint in the dataset for smiles in train_smiles: - self.assertEqual(len(set(smiles)), 1) # Check that all smiles are the same - this_smiles = smiles[0] - true_smiles = df.loc[this_smiles][["SMILES1", "SMILES2", "SMILES3"]] - num_true_smiles = sum(true_smiles != "XXX") - self.assertEqual(len(smiles), num_true_smiles) # Check that the number of smiles is correct + assert isinstance(smiles, str) + true_smiles = df.loc[smiles][["SMILES1", "SMILES2", "SMILES3"]] self.assertEqual( - this_smiles, true_smiles[true_smiles != "XXX"].values[0] - ) # Check that the smiles are correct + smiles, true_smiles[true_smiles != "XXX"].values[0] + ) # Check that the smiles is correct # Convert the labels from the train_ds to a dataframe - train_labels = [{task: val[0] for task, val in d["labels"].items()} for d in datamodule.train_ds] + train_labels = [datamodule.train_ds[idx]["labels"] for idx in range(len(datamodule.train_ds))] + train_labels = [{k: v[0].item() for k, v in label} for label in train_labels] train_labels_df = pd.DataFrame(train_labels) train_labels_df = train_labels_df.rename( columns={"graph_task_1": "graph_SA", "graph_task_2": "graph_logp", "graph_task_3": "graph_score"} ) - train_labels_df["smiles"] = [s[0] for s in datamodule.train_ds.smiles] + train_labels_df["smiles"] = train_smiles train_labels_df = train_labels_df.set_index("smiles") train_labels_df = train_labels_df.sort_index() @@ -450,7 +316,11 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ds = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) ds.prepare_data() ds.setup() @@ -463,7 +333,11 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ds = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) ds.prepare_data() ds.setup() @@ -476,7 +350,11 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ds = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) ds.prepare_data() ds.setup() @@ -489,7 +367,11 @@ def test_datamodule_multiple_data_files(self): "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs} } - ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ds = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) ds.prepare_data() ds.setup() @@ -526,9 +408,13 @@ def test_splits_file(self): } } - ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) - ds.prepare_data(save_smiles_and_ids=True) - ds.setup(save_smiles_and_ids=True) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ds = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) + ds.prepare_data() + ds.setup() self.assertEqual(len(ds.train_ds), len(split_train)) self.assertEqual(len(ds.val_ds), len(split_val)) @@ -555,19 +441,30 @@ def test_splits_file(self): } } - ds2 = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) - ds2.prepare_data(save_smiles_and_ids=True) - ds2.setup(save_smiles_and_ids=True) + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ds2 = MultitaskFromSmilesDataModule( + task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH + ) + ds2.prepare_data() + ds2.setup() self.assertEqual(len(ds2.train_ds), len(split_train)) self.assertEqual(len(ds2.val_ds), len(split_val)) self.assertEqual(len(ds2.test_ds), len(split_test)) # Check that the splits are the same - self.assertEqual(len(ds.train_ds.smiles), len(split_train)) - np.testing.assert_array_equal(ds.train_ds.smiles, ds2.train_ds.smiles) - np.testing.assert_array_equal(ds.val_ds.smiles, ds2.val_ds.smiles) - np.testing.assert_array_equal(ds.test_ds.smiles, ds2.test_ds.smiles) + self.assertEqual(len(ds.train_ds.smiles_offsets_tensor), len(split_train) + 1) + np.testing.assert_array_equal(ds.train_ds.smiles_tensor, ds2.train_ds.smiles_tensor) + np.testing.assert_array_equal(ds.val_ds.smiles_tensor, ds2.val_ds.smiles_tensor) + np.testing.assert_array_equal(ds.test_ds.smiles_tensor, ds2.test_ds.smiles_tensor) + np.testing.assert_array_equal( + ds.train_ds.smiles_offsets_tensor, ds2.train_ds.smiles_offsets_tensor + ) + np.testing.assert_array_equal(ds.val_ds.smiles_offsets_tensor, ds2.val_ds.smiles_offsets_tensor) + np.testing.assert_array_equal(ds.test_ds.smiles_offsets_tensor, ds2.test_ds.smiles_offsets_tensor) if __name__ == "__main__": diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 4a7173244..b0ed43178 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -11,16 +11,101 @@ -------------------------------------------------------------------------------- """ - import unittest as ut from graphium.data import load_micro_zinc -from graphium.data.dataset import SingleTaskDataset, MultitaskDataset +from graphium.data.datamodule import MultitaskFromSmilesDataModule +from graphium.data.dataset import MultitaskDataset +from graphium.features import mol_to_pyggraph from graphium.data.smiles_transform import smiles_to_unique_mol_ids from graphium.data.utils import get_keys - -class Test_Multitask_Dataset(ut.TestCase): +import graphium_cpp + +import numpy as np +import os.path as osp + +TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000" + + +def dataframes_to_dataset(dataframes_dict, case_num): + task_names = [key for key in dataframes_dict.keys()] + + task_dataset_args = {} + task_train_indices = {} + task_val_indices = {} + task_test_indices = {} + for task in task_names: + ( + smiles, + labels, + label_offsets, + sample_idx, + extras, + ) = MultitaskFromSmilesDataModule._extract_smiles_labels( + df=dataframes_dict[task], + task_level="graph", + smiles_col="SMILES", + label_cols=task, + idx_col=None, + weights_col=None, + weights_type=None, + ) + num_molecules = len(smiles) + task_dataset_args[task] = { + "smiles": smiles, + "labels": labels, + "label_offsets": label_offsets, + "extras": extras, + } + + task_train_indices[task] = np.arange(num_molecules).tolist() + task_val_indices[task] = [] + task_test_indices[task] = [] + + fake_data_hash = "a1b2c3testdataset" + str(case_num) + + # The rest of the data preparation and caching is done in graphium_cpp.prepare_and_save_data + normalizations = {task: {} for task in task_names} # No normalization + stage_data, all_stats, label_num_cols, label_dtypes = graphium_cpp.prepare_and_save_data( + task_names, + task_dataset_args, + normalizations, + TEMP_CACHE_DATA_PATH, + fake_data_hash, + task_train_indices, + task_val_indices, + task_test_indices, + False, # add_self_loop + False, # explicit_H + 0, # preprocessing_n_jobs + True, # merge_equivalent_mols + ) + + stage_data = stage_data["train"] + + data_offsets = None + if MultitaskFromSmilesDataModule.data_offsets_tensor_index() < len(stage_data): + data_offsets = stage_data[MultitaskFromSmilesDataModule.data_offsets_tensor_index()] + + multitask_dataset = MultitaskDataset( + about="test_dataset case" + str(case_num), + data_path=osp.join(TEMP_CACHE_DATA_PATH, "train_" + fake_data_hash), + featurize_smiles=mol_to_pyggraph, + task_names=task_names, + label_num_cols=label_num_cols, + label_dtypes=label_dtypes, + mol_file_data_offsets=data_offsets, + concat_smiles_tensor=stage_data[MultitaskFromSmilesDataModule.concat_smiles_tensor_index()], + smiles_offsets_tensor=stage_data[MultitaskFromSmilesDataModule.smiles_offsets_tensor_index()], + num_nodes_tensor=stage_data[MultitaskFromSmilesDataModule.num_nodes_tensor_index()], + num_edges_tensor=stage_data[MultitaskFromSmilesDataModule.num_edges_tensor_index()], + ) + + return multitask_dataset + + +class test_Multitask_Dataset(ut.TestCase): # Then we can choose different rows and columns for the tests as we see fit. # Remember tests are supposed to be FAST, and reading from the file system multiple times slows things down. @@ -42,50 +127,44 @@ def test_multitask_dataset_case_1(self): df_micro_zinc_logp = df[["SMILES", "logp"]] df_micro_zinc_score = df[["SMILES", "score"]] - # We need to turn these dataframes into single-task datasets. + # We need to prepare the data for these dataframes. # We don't need to do featurization yet. - ds_micro_zinc_SA = SingleTaskDataset( - smiles=df_micro_zinc_SA.loc[:, "SMILES"].tolist(), labels=df_micro_zinc_SA.loc[:, "SA"].tolist() - ) - - ds_micro_zinc_logp = SingleTaskDataset( - smiles=df_micro_zinc_logp.loc[:, "SMILES"].tolist(), - labels=df_micro_zinc_logp.loc[:, "logp"].tolist(), - ) - ds_micro_zinc_score = SingleTaskDataset( - smiles=df_micro_zinc_score.loc[:, "SMILES"].tolist(), - labels=df_micro_zinc_score.loc[:, "score"].tolist(), - ) - - # Create the multitask dataset - datasets_dict = {"SA": ds_micro_zinc_SA, "logp": ds_micro_zinc_logp, "score": ds_micro_zinc_score} - multitask_microzinc = MultitaskDataset( - datasets_dict, save_smiles_and_ids=True - ) # Can optionally have features + dataframes = { + "SA": df_micro_zinc_SA, + "logp": df_micro_zinc_logp, + "score": df_micro_zinc_score, + } + multitask_dataset = dataframes_to_dataset(dataframes, 1) # Check: The number of unique molecules equals the number of datapoints in the multitask dataset. - self.assertEqual(num_unique_mols, multitask_microzinc.__len__()) + self.assertEqual(num_unique_mols, multitask_dataset.__len__()) # Check that for each task, you have the same label values as the initial DF. - for idx in range(multitask_microzinc.__len__()): + for idx in range(multitask_dataset.__len__()): smiles = df[["SMILES"]].iloc[idx].values[0] - # label = df[['SA']].iloc[idx] - label_SA = ds_micro_zinc_SA.labels[idx] - label_logp = ds_micro_zinc_logp.labels[idx] - label_score = ds_micro_zinc_score.labels[idx] - - # Search for the mol id in the multitask dataset - mol_ids = smiles_to_unique_mol_ids([smiles]) - mol_id = mol_ids[0] + + label_SA = df_micro_zinc_SA["SA"][idx] + label_logp = df_micro_zinc_logp["logp"][idx] + label_score = df_micro_zinc_score["score"][idx] + + # Search for the smiles string in the multitask dataset found_idx = -1 - for i, id in enumerate(multitask_microzinc.mol_ids): - if mol_id == id: + for i in range(multitask_dataset.__len__()): + if ( + graphium_cpp.extract_string( + multitask_dataset.smiles_tensor, multitask_dataset.smiles_offsets_tensor, i + ) + == smiles + ): found_idx = i + break + + item = multitask_dataset[found_idx]["labels"] # Compare labels - self.assertEqual(label_SA, multitask_microzinc.labels[found_idx]["SA"]) - self.assertEqual(label_logp, multitask_microzinc.labels[found_idx]["logp"]) - self.assertEqual(label_score, multitask_microzinc.labels[found_idx]["score"]) + self.assertEqual(label_SA, item["SA"]) + self.assertEqual(label_logp, item["logp"]) + self.assertEqual(label_score, item["score"]) def test_multitask_dataset_case_2(self): """Case: Different tasks, but with no intersection in the smiles (each task has a unique set of smiles) @@ -100,36 +179,18 @@ def test_multitask_dataset_case_2(self): df_rows_score = df.iloc[400:750] # 350 data points total_data_points = 750 - # Here we split the data according to the task we care about. - df_micro_zinc_SA = df_rows_SA[["SMILES", "SA"]] - df_micro_zinc_logp = df_rows_logp[["SMILES", "logp"]] - df_micro_zinc_score = df_rows_score[["SMILES", "score"]] - - # We need to turn these dataframes into single-task datasets. - # We don't need to do featurization yet. - ds_micro_zinc_SA = SingleTaskDataset( - smiles=df_micro_zinc_SA.loc[:, "SMILES"].tolist(), labels=df_micro_zinc_SA.loc[:, "SA"].tolist() - ) - ds_micro_zinc_logp = SingleTaskDataset( - smiles=df_micro_zinc_logp.loc[:, "SMILES"].tolist(), - labels=df_micro_zinc_logp.loc[:, "logp"].tolist(), - ) - ds_micro_zinc_score = SingleTaskDataset( - smiles=df_micro_zinc_score.loc[:, "SMILES"].tolist(), - labels=df_micro_zinc_score.loc[:, "score"].tolist(), - ) - - # Create the multitask dataset - datasets_dict = {"SA": ds_micro_zinc_SA, "logp": ds_micro_zinc_logp, "score": ds_micro_zinc_score} - multitask_microzinc = MultitaskDataset( - datasets_dict, save_smiles_and_ids=True - ) # Can optionally have features + dataframes = { + "SA": df_rows_SA, + "logp": df_rows_logp, + "score": df_rows_score, + } + multitask_microzinc = dataframes_to_dataset(dataframes, 2) # The total dataset has as many molecules as there are smiles in all tasks put together self.assertEqual(total_data_points, multitask_microzinc.__len__()) # For each task, only the smiles related to that task have values, and the value is what's expected from the initial DF. - for idx in range(len(ds_micro_zinc_SA)): + for idx in range(len(multitask_microzinc)): smiles = df[["SMILES"]].iloc[idx].values[0] task = "task" @@ -141,28 +202,33 @@ def test_multitask_dataset_case_2(self): task = "score" # Labels of that molecule - label_SA = df[["SA"]].iloc[idx].values[0] - label_logp = df[["logp"]].iloc[idx].values[0] - label_score = df[["score"]].iloc[idx].values[0] + label_df = df[[task]].iloc[idx].values[0] - # Search for that molecule in the multitask dataset - mol_ids = smiles_to_unique_mol_ids([smiles]) - mol_id = mol_ids[0] + # Search for the smiles string in the multitask dataset found_idx = -1 - for i, id in enumerate(multitask_microzinc.mol_ids): - if mol_id == id: + for i in range(multitask_microzinc.__len__()): + if ( + graphium_cpp.extract_string( + multitask_microzinc.smiles_tensor, multitask_microzinc.smiles_offsets_tensor, i + ) + == smiles + ): found_idx = i - multitask_microzinc_labels = get_keys(multitask_microzinc.labels[found_idx]) + break + + item = multitask_microzinc[found_idx]["labels"] + multitask_microzinc_labels = item.keys() + + assert task in multitask_microzinc_labels + self.assertEqual(label_df, item[task]) + if task == "SA": - self.assertEqual(label_SA, multitask_microzinc.labels[found_idx]["SA"]) self.assertFalse("score" in multitask_microzinc_labels) self.assertFalse("logp" in multitask_microzinc_labels) elif task == "logp": - self.assertEqual(label_logp, multitask_microzinc.labels[found_idx]["logp"]) self.assertFalse("score" in multitask_microzinc_labels) self.assertFalse("SA" in multitask_microzinc_labels) elif task == "score": - self.assertEqual(label_score, multitask_microzinc.labels[found_idx]["score"]) self.assertFalse("SA" in multitask_microzinc_labels) self.assertFalse("logp" in multitask_microzinc_labels) @@ -180,30 +246,12 @@ def test_multitask_dataset_case_3(self): df_rows_score = df.iloc[3:5] total_data_points = 5 - # Here we split the data according to the task we care about. - df_micro_zinc_SA = df_rows_SA[["SMILES", "SA"]] - df_micro_zinc_logp = df_rows_logp[["SMILES", "logp"]] - df_micro_zinc_score = df_rows_score[["SMILES", "score"]] - - # We need to turn these dataframes into single-task datasets. - # We don't need to do featurization yet. - ds_micro_zinc_SA = SingleTaskDataset( - smiles=df_micro_zinc_SA.loc[:, "SMILES"].tolist(), labels=df_micro_zinc_SA.loc[:, "SA"].tolist() - ) - ds_micro_zinc_logp = SingleTaskDataset( - smiles=df_micro_zinc_logp.loc[:, "SMILES"].tolist(), - labels=df_micro_zinc_logp.loc[:, "logp"].tolist(), - ) - ds_micro_zinc_score = SingleTaskDataset( - smiles=df_micro_zinc_score.loc[:, "SMILES"].tolist(), - labels=df_micro_zinc_score.loc[:, "score"].tolist(), - ) - - # Create the multitask dataset - datasets_dict = {"SA": ds_micro_zinc_SA, "logp": ds_micro_zinc_logp, "score": ds_micro_zinc_score} - multitask_microzinc = MultitaskDataset( - datasets_dict, save_smiles_and_ids=True - ) # Can optionally have features + dataframes = { + "SA": df_rows_SA, + "logp": df_rows_logp, + "score": df_rows_score, + } + multitask_microzinc = dataframes_to_dataset(dataframes, 3) # The multitask dataset has as many molecules as there are unique smiles across the single task datasets. self.assertEqual(total_data_points, multitask_microzinc.__len__()) diff --git a/tests/test_featurizer.py b/tests/test_featurizer.py index e8f666365..3336feae3 100644 --- a/tests/test_featurizer.py +++ b/tests/test_featurizer.py @@ -22,13 +22,9 @@ from rdkit import Chem import datamol as dm -from graphium.features.featurizer import ( - get_mol_atomic_features_onehot, - get_mol_atomic_features_float, - get_mol_edge_features, - mol_to_adj_and_features, - mol_to_pyggraph, -) +from graphium.features.featurizer import mol_to_pyggraph + +import graphium_cpp class test_featurizer(ut.TestCase): @@ -99,155 +95,120 @@ class test_featurizer(ut.TestCase): def test_get_mol_atomic_features_onehot(self): props = deepcopy(self.atomic_onehot_props) - bad_props = ["bob"] + # bad_props = ["bob"] all_smiles = self.smiles + self.smiles_noble - for s in all_smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) + for mol in all_smiles: + err_msg = f"\n\tError for params:\n\t\tSMILES: {mol}" + + rdmol = dm.to_mol(mol) for ii in range(len(props)): this_props = props[:ii] err_msg2 = err_msg + f"\n\t\tprops: {this_props}" - prop_dict = get_mol_atomic_features_onehot(mol, property_list=this_props) - self.assertListEqual(list(prop_dict.keys()), this_props, msg=err_msg) - for key, val in prop_dict.items(): - err_msg3 = err_msg2 + f"\n\t\tkey: {key}" - self.assertEqual(val.shape[0], mol.GetNumAtoms(), msg=err_msg3) - self.assertGreater(val.shape[1], 1, msg=err_msg3) - self.assertTrue(np.all((val == 0) | (val == 1)), msg=err_msg3) + this_props_encoded = graphium_cpp.atom_onehot_feature_names_to_tensor(this_props) + features = mol_to_pyggraph(mol, atom_property_list_onehot=this_props_encoded, mask_nan=None) + val = features["feat"] + self.assertEqual(val.size(0), rdmol.GetNumAtoms(), msg=err_msg2) + self.assertGreaterEqual(val.size(1), 2 * len(this_props), msg=err_msg2) + self.assertTrue(((val == 0) | (val == 1)).numpy().all(), msg=err_msg2) - with self.assertRaises(ValueError, msg=err_msg): - get_mol_atomic_features_onehot(mol, property_list=bad_props) + # with self.assertRaises(ValueError, msg=err_msg): + # get_mol_atomic_features_onehot(mol, property_list=bad_props) def test_get_mol_atomic_features_float(self): props = deepcopy(self.atomic_float_props) - bad_props = ["bob"] + # bad_props = ["bob"] all_smiles = self.smiles + self.smiles_noble - for s in all_smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) + for mol in all_smiles: + err_msg = f"\n\tError for params:\n\t\tSMILES: {mol}" + rdmol = dm.to_mol(mol) for ii in range(len(props)): this_props = props[:ii] err_msg2 = err_msg + f"\n\t\tprops: {this_props}" - prop_dict = get_mol_atomic_features_float(mol, property_list=this_props, mask_nan=None) - self.assertListEqual(list(prop_dict.keys()), this_props, msg=err_msg) - for key, val in prop_dict.items(): - err_msg3 = err_msg2 + f"\n\t\tkey: {key}" - self.assertListEqual(list(val.shape), [mol.GetNumAtoms()], msg=err_msg3) + this_props_encoded = graphium_cpp.atom_float_feature_names_to_tensor(this_props) + features = mol_to_pyggraph(mol, atom_property_list_float=this_props_encoded, mask_nan=None) + val = features["feat"] + self.assertEqual(val.size(0), rdmol.GetNumAtoms(), msg=err_msg2) + self.assertEqual(val.size(1), len(this_props), msg=err_msg2) - with self.assertRaises(ValueError, msg=err_msg): - get_mol_atomic_features_float(mol, property_list=bad_props) + # with self.assertRaises(ValueError, msg=err_msg): + # get_mol_atomic_features_float(mol, property_list=bad_props) def test_get_mol_atomic_features_float_nan_mask(self): - for s in self.smiles_noble: - mol = dm.to_mol(s) - + props_encoded = graphium_cpp.atom_float_feature_names_to_tensor(self.atomic_float_props) + for mol in self.smiles_noble: # Nothing happens when `mask_nan = None`, nans are still in the property array - prop_dict = get_mol_atomic_features_float( - mol, property_list=self.atomic_float_props, mask_nan=None + features = mol_to_pyggraph( + mol, atom_property_list_float=props_encoded, mask_nan=None, on_error="raise" ) - prop_array = np.concatenate(list(prop_dict.values()), axis=0) + prop_array = features["feat"] nans = np.isnan(prop_array) # Capture a raised error when `mask_nan = "raise"` with self.assertRaises(ValueError): - prop_dict = get_mol_atomic_features_float( - mol, property_list=self.atomic_float_props, mask_nan="raise" + features = mol_to_pyggraph( + mol, atom_property_list_float=props_encoded, mask_nan="raise", on_error="raise" ) + print(f"Failed to raise error for nans on {mol}") # Not sure how to Capture a logged warning when `mask_nan = "warn"` # Here, I'm testing a behaviour similar to `mask_nan = None` - prop_dict = get_mol_atomic_features_float( - mol, property_list=self.atomic_float_props, mask_nan="warn" + features = mol_to_pyggraph( + mol, atom_property_list_float=props_encoded, mask_nan="warn", on_error="raise" ) - prop_array = np.concatenate(list(prop_dict.values()), axis=0) - self.assertEqual(len(self.atomic_float_props), len(prop_dict)) - self.assertTrue(any(np.isnan(prop_array))) + prop_array = features["feat"] + self.assertEqual(len(self.atomic_float_props), prop_array.size(1)) + self.assertTrue(np.isnan(prop_array.numpy()).any()) # NaNs are replaced by `42` when `mask_nan=42` - prop_dict = get_mol_atomic_features_float(mol, property_list=self.atomic_float_props, mask_nan=42) - prop_array = np.concatenate(list(prop_dict.values()), axis=0) - self.assertEqual(len(self.atomic_float_props), len(prop_dict)) - self.assertFalse(any(np.isnan(prop_array))) - self.assertTrue(all(prop_array[nans] == 42)) + features = mol_to_pyggraph( + mol, atom_property_list_float=props_encoded, mask_nan=42, on_error="raise" + ) + prop_array = features["feat"] + self.assertEqual(len(self.atomic_float_props), prop_array.size(1)) + self.assertFalse(np.isnan(prop_array.numpy()).any()) + self.assertTrue((prop_array[nans] == 42).all()) def test_get_mol_edge_features(self): props = deepcopy(self.edge_props) - bad_props = ["bob"] + # bad_props = ["bob"] all_smiles = self.smiles + self.smiles_noble - for s in all_smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) + for mol in all_smiles: + err_msg = f"\n\tError for params:\n\t\tSMILES: {mol}" + rdmol = dm.to_mol(mol) for ii in range(len(props)): this_props = props[: ii + 1] err_msg2 = err_msg + f"\n\t\tprops: {this_props}" - prop_dict = get_mol_edge_features(mol, property_list=this_props) - self.assertListEqual(list(prop_dict.keys()), this_props, msg=err_msg) - for key, val in prop_dict.items(): - err_msg3 = err_msg2 + f"\n\t\tkey: {key}" - self.assertEqual(val.shape[0], mol.GetNumBonds(), msg=err_msg3) - - if mol.GetNumBonds() > 0: - with self.assertRaises(ValueError, msg=err_msg): - get_mol_edge_features(mol, property_list=bad_props) - - def test_mol_to_adj_and_features(self): - np.random.seed(42) + this_props_encoded = graphium_cpp.bond_feature_names_to_tensor(this_props) + features = mol_to_pyggraph(mol, edge_property_list=this_props_encoded, mask_nan=None) + val = features["edge_feat"] + self.assertEqual(val.shape[0], 2 * rdmol.GetNumBonds(), msg=err_msg2) + if rdmol.GetNumBonds() > 0: + self.assertGreaterEqual(val.shape[1], len(this_props), msg=err_msg2) - for s in self.smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) - mol_Hs = Chem.AddHs(mol) # type: ignore - mol_No_Hs = Chem.RemoveHs(mol) # type: ignore - - for explicit_H in [True, False]: - this_mol = mol_Hs if explicit_H else mol_No_Hs - for ii in np.arange(0, 5, 0.2): - num_props = int(round(ii)) - err_msg2 = err_msg + f"\n\t\texplicit_H: {explicit_H}\n\t\tii: {ii}" - - adj, ndata, edata, _, _ = mol_to_adj_and_features( - mol=mol, - atom_property_list_onehot=np.random.choice( - self.atomic_onehot_props, size=num_props, replace=False - ), - atom_property_list_float=np.random.choice( - self.atomic_float_props, size=num_props, replace=False - ), - edge_property_list=np.random.choice(self.edge_props, size=num_props, replace=False), - add_self_loop=False, - explicit_H=explicit_H, - use_bonds_weights=False, - ) - - self.assertEqual(adj.shape[0], this_mol.GetNumAtoms(), msg=err_msg2) - if num_props > 0: - self.assertEqual(ndata.shape[0], this_mol.GetNumAtoms(), msg=err_msg2) - if this_mol.GetNumBonds() > 0: - self.assertEqual(edata.shape[0], this_mol.GetNumBonds(), msg=err_msg2) - self.assertGreaterEqual(edata.shape[1], num_props, msg=err_msg2) - self.assertGreaterEqual(ndata.shape[1], num_props, msg=err_msg2) + # if mol.GetNumBonds() > 0: + # with self.assertRaises(ValueError, msg=err_msg): + # get_mol_edge_features(mol, property_list=bad_props) def test_mol_to_pyggraph(self): np.random.seed(42) + single_atom_prop_encoded = graphium_cpp.atom_float_feature_names_to_tensor(["atomic-number"]) + single_bond_prop_encoded = graphium_cpp.bond_feature_names_to_tensor(["bond-type-float"]) - for s in self.smiles: - err_msg = f"\n\tError for params:\n\t\tSMILES: {s}" - mol = dm.to_mol(s) - mol_Hs = Chem.AddHs(mol) # type: ignore - mol_No_Hs = Chem.RemoveHs(mol) # type: ignore + for mol in self.smiles: + err_msg = f"\n\tError for params:\n\t\tSMILES: {mol}" + rdmol = dm.to_mol(mol) graph = mol_to_pyggraph( mol=mol, - atom_property_list_onehot=[], - atom_property_list_float=["atomic-number"], - edge_property_list=["bond-type-float"], + atom_property_list_float=single_atom_prop_encoded, + edge_property_list=single_bond_prop_encoded, add_self_loop=False, explicit_H=False, use_bonds_weights=False, @@ -255,29 +216,32 @@ def test_mol_to_pyggraph(self): ) # Check the number of nodes and edges - self.assertListEqual(list(graph["feat"].shape), [mol.GetNumAtoms(), 1], msg=err_msg) - self.assertListEqual(list(graph["edge_feat"].shape), [2 * mol.GetNumBonds(), 1], msg=err_msg) + self.assertListEqual(list(graph["feat"].shape), [rdmol.GetNumAtoms(), 1], msg=err_msg) + self.assertListEqual(list(graph["edge_feat"].shape), [2 * rdmol.GetNumBonds(), 1], msg=err_msg) # Check the node features feat = graph["feat"].to_dense().numpy() * 5 + 6 # Undo the scaling - atom_nums = np.asarray([atom.GetAtomicNum() for atom in mol.GetAtoms()]) + atom_nums = np.asarray([atom.GetAtomicNum() for atom in rdmol.GetAtoms()]) np.testing.assert_array_almost_equal(feat[:, 0], atom_nums, decimal=5, err_msg=err_msg) # Check the edge features edge_feat = graph["edge_feat"].to_dense().numpy() - bond_types = np.asarray([bond.GetBondTypeAsDouble() for bond in mol.GetBonds()]).repeat(2) + bond_types = np.asarray([bond.GetBondTypeAsDouble() for bond in rdmol.GetBonds()]).repeat(2) np.testing.assert_array_almost_equal(edge_feat[:, 0], bond_types, decimal=5, err_msg=err_msg) # Check the edge indices - if mol.GetNumBonds() > 0: + if rdmol.GetNumBonds() > 0: edge_index = graph["edge_index"].to_dense().numpy() true_edge_index = [] - for bond in mol.GetBonds(): + for bond in rdmol.GetBonds(): true_edge_index.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) true_edge_index.append([bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]) true_edge_index = np.asarray(true_edge_index).T np.testing.assert_array_equal(edge_index, true_edge_index, err_msg=err_msg) + mol_Hs = Chem.AddHs(rdmol) # type: ignore + mol_No_Hs = Chem.RemoveHs(rdmol) # type: ignore + # Loop over many possible combinations of properties for explicit_H in [True, False]: this_mol = mol_Hs if explicit_H else mol_No_Hs @@ -287,13 +251,15 @@ def test_mol_to_pyggraph(self): graph = mol_to_pyggraph( mol=mol, - atom_property_list_onehot=np.random.choice( - self.atomic_onehot_props, size=num_props, replace=False + atom_property_list_onehot=graphium_cpp.atom_onehot_feature_names_to_tensor( + np.random.choice(self.atomic_onehot_props, size=num_props, replace=False) + ), + atom_property_list_float=graphium_cpp.atom_float_feature_names_to_tensor( + np.random.choice(self.atomic_float_props, size=num_props, replace=False) ), - atom_property_list_float=np.random.choice( - self.atomic_float_props, size=num_props, replace=False + edge_property_list=graphium_cpp.bond_feature_names_to_tensor( + np.random.choice(self.edge_props, size=num_props, replace=False) ), - edge_property_list=np.random.choice(self.edge_props, size=num_props, replace=False), add_self_loop=False, explicit_H=explicit_H, use_bonds_weights=False, diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py index 52484c4c9..db71ac81b 100644 --- a/tests/test_finetuning.py +++ b/tests/test_finetuning.py @@ -16,6 +16,7 @@ import unittest as ut from copy import deepcopy from os.path import abspath, dirname +import shutil import torch from lightning.pytorch.callbacks import Callback @@ -29,7 +30,6 @@ load_metrics, load_predictor, load_trainer, - save_params_to_wandb, ) from graphium.finetuning import GraphFinetuning, modify_cfg_for_finetuning from graphium.trainer import PredictorModule @@ -40,7 +40,7 @@ os.chdir(MAIN_DIR) -class Test_Finetuning(ut.TestCase): +class test_Finetuning(ut.TestCase): def test_finetuning_from_task_head(self): # Skip test if PyTDC package not installed try: @@ -60,9 +60,14 @@ def test_finetuning_from_task_head(self): # Initialize the accelerator cfg, accelerator_type = load_accelerator(cfg) + # If the data_cache directory exists, delete it for the purpose of the test + data_cache = cfg["datamodule"]["args"]["processed_graph_data_path"] + if os.path.exists(data_cache): + shutil.rmtree(data_cache) + # Load and initialize the dataset datamodule = load_datamodule(cfg, accelerator_type) - datamodule.task_specific_args["lipophilicity_astrazeneca"].sample_size = 100 + datamodule.task_specific_args["lipophilicity_astrazeneca"].sample_size = 300 # Initialize the network model_class, model_kwargs = load_architecture( @@ -149,7 +154,7 @@ def test_finetuning_from_task_head(self): ################################################# # Define test callback that checks for correct (un)freezing - class TestCallback(Callback): + class CallbackTesting(Callback): def __init__(self, cfg): super().__init__() @@ -217,18 +222,20 @@ def on_train_epoch_start(self, trainer, pl_module): assert not False in unfrozen_parameters - trainer = load_trainer(cfg, accelerator_type) + metrics_on_progress_bar = predictor.get_metrics_on_progress_bar + trainer = load_trainer(cfg, accelerator_type, metrics_on_progress_bar=metrics_on_progress_bar) finetuning_training_kwargs = cfg["finetuning"]["training_kwargs"] trainer.callbacks.append(GraphFinetuning(**finetuning_training_kwargs)) # Add test callback to trainer - trainer.callbacks.append(TestCallback(cfg)) + trainer.callbacks.append(CallbackTesting(cfg)) predictor.set_max_nodes_edges_per_graph(datamodule, stages=["train", "val"]) # Run the model training trainer.fit(model=predictor, datamodule=datamodule) + trainer.test(model=predictor, datamodule=datamodule) def test_finetuning_from_gnn(self): # Skip test if PyTDC package not installed @@ -249,9 +256,14 @@ def test_finetuning_from_gnn(self): # Initialize the accelerator cfg, accelerator_type = load_accelerator(cfg) + # If the data_cache directory exists, delete it for the purpose of the test + data_cache = cfg["datamodule"]["args"]["processed_graph_data_path"] + if os.path.exists(data_cache): + shutil.rmtree(data_cache) + # Load and initialize the dataset datamodule = load_datamodule(cfg, accelerator_type) - datamodule.task_specific_args["lipophilicity_astrazeneca"].sample_size = 100 + datamodule.task_specific_args["lipophilicity_astrazeneca"].sample_size = 300 # Initialize the network model_class, model_kwargs = load_architecture( @@ -335,7 +347,7 @@ def test_finetuning_from_gnn(self): ################################################# # Define test callback that checks for correct (un)freezing - class TestCallback(Callback): + class CallbackTesting(Callback): def __init__(self, cfg): super().__init__() @@ -392,18 +404,20 @@ def on_train_epoch_start(self, trainer, pl_module): assert not False in unfrozen_parameters - trainer = load_trainer(cfg, accelerator_type) + metrics_on_progress_bar = predictor.get_metrics_on_progress_bar + trainer = load_trainer(cfg, accelerator_type, metrics_on_progress_bar=metrics_on_progress_bar) finetuning_training_kwargs = cfg["finetuning"]["training_kwargs"] trainer.callbacks.append(GraphFinetuning(**finetuning_training_kwargs)) # Add test callback to trainer - trainer.callbacks.append(TestCallback(cfg)) + trainer.callbacks.append(CallbackTesting(cfg)) predictor.set_max_nodes_edges_per_graph(datamodule, stages=["train", "val"]) # Run the model training trainer.fit(model=predictor, datamodule=datamodule) + trainer.test(model=predictor, datamodule=datamodule) if __name__ == "__main__": diff --git a/tests/test_ipu_dataloader.py b/tests/test_ipu_dataloader.py deleted file mode 100644 index 436d609d4..000000000 --- a/tests/test_ipu_dataloader.py +++ /dev/null @@ -1,255 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -# General imports -import yaml -import unittest as ut -import numpy as np -from copy import deepcopy -from warnings import warn -from unittest.mock import patch -from lightning import Trainer, LightningModule -from functools import partial -import pytest -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union - -import torch -from torch.utils.data.dataloader import default_collate -from lightning_graphcore import IPUStrategy - - -def random_packing(num_nodes, batch_size): - ipu_batch_size = int(len(num_nodes) / batch_size) - indices = np.arange(len(num_nodes)) - np.random.shuffle(indices) - indices = np.reshape(indices, (ipu_batch_size, batch_size)).tolist() - return indices - - -def global_batch_collator(batch_size, batches): - packs = [] - for pack_idx in range(0, len(batches), batch_size): - packs.append(default_collate(batches[pack_idx : pack_idx + batch_size])) - global_batch = default_collate(packs) - global_batch = (global_batch[0], tuple(global_batch[1])) - return global_batch - - -@pytest.mark.ipu -class test_DataLoading(ut.TestCase): - class TestSimpleLightning(LightningModule): - # Create a basic Ligthning for testing the batch sizes - def __init__(self, batch_size, node_feat_size, edge_feat_size, num_batch) -> None: - super().__init__() - self.batch_size = batch_size - self.node_feat_size = node_feat_size - self.edge_feat_size = edge_feat_size - self.layer = torch.nn.Linear(node_feat_size, 1) - self.loss_fn = torch.nn.L1Loss() - self.num_batch = num_batch - - def validation_step(self, batch, batch_idx): - self.assert_shapes(batch, batch_idx, "val") - loss = self.forward(batch) - return loss - - def training_step(self, batch, batch_idx): - self.assert_shapes(batch, batch_idx, "train") - loss = self.forward(batch) - return loss - - def forward(self, batch): - out = self.layer(batch[1][0]).squeeze(-1) - loss = self.loss_fn(out, batch[0]) - return loss - - def assert_shapes(self, batch, batch_idx, step): - # Test the shape of the labels - this_shape = list(batch[0].shape) - true_shape = [1, self.batch_size] - assert ( - this_shape == true_shape - ), f"Shape of the labels is `{this_shape}` but should be {true_shape}" - - # Test the shape of the first feature - this_shape = list(batch[1][0].shape) - true_shape = [1, self.batch_size, self.node_feat_size] - assert ( - this_shape == true_shape - ), f"Shape of the feature 0 is `{this_shape}` but should be {true_shape}" - - # Test the shape of the second feature - this_shape = list(batch[1][1].shape) - true_shape = [1, self.batch_size, self.edge_feat_size] - assert ( - this_shape == true_shape - ), f"Shape of the feature 0 is `{this_shape}` but should be {true_shape}" - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=1e-3) - - class TestDataset(torch.utils.data.Dataset): - # Create a simple dataset for testing the Lightning integration - def __init__(self, labels, node_features, edge_features): - self.labels = labels - self.node_features = node_features - self.edge_features = edge_features - - def __len__(self): - return len(self.labels) - - def __getitem__(self, idx): - # [label, [feat1, feat2]] - return [self.labels[idx], [self.node_features[idx], self.edge_features[idx]]] - - # @pytest.mark.skip - def test_poptorch_simple_deviceiterations_gradient_accumulation(self): - """ - Test a simple version of the device-iterations and gradient accumulation - to make sure that the dataloader and models handle them correcly. - """ - - with patch("poptorch.ipuHardwareIsAvailable", return_value=True): - with patch("lightning_graphcore.accelerator._IPU_AVAILABLE", new=True): - import poptorch - - assert poptorch.ipuHardwareIsAvailable() - from lightning_graphcore.accelerator import _IPU_AVAILABLE - - assert _IPU_AVAILABLE is True - - # Initialize constants - gradient_accumulation = 2 - device_iterations = 3 - batch_size = 5 - num_replicate = 7 - node_feat_size = 11 - edge_feat_size = 13 - - # Initialize the batch info and poptorch options - opts = poptorch.Options() - opts.useIpuModel(True) - opts.deviceIterations(device_iterations) - training_opts = deepcopy(opts) - training_opts.Training.gradientAccumulation(gradient_accumulation) - inference_opts = deepcopy(opts) - - # Initialize the dataset - num_batch = device_iterations * gradient_accumulation * num_replicate - data_size = num_batch * batch_size - dataset = self.TestDataset( - labels=np.random.rand(data_size).astype(np.float32), - node_features=[ - np.random.rand(node_feat_size).astype(np.float32) for ii in range(data_size) - ], - edge_features=[ - np.random.rand(edge_feat_size).astype(np.float32) for ii in range(data_size) - ], - ) - - # Initialize the dataloader - train_dataloader = poptorch.DataLoader( - options=training_opts, - dataset=deepcopy(dataset), - batch_size=batch_size, - collate_fn=partial(global_batch_collator, batch_size), - ) - - val_dataloader = poptorch.DataLoader( - options=inference_opts, - dataset=deepcopy(dataset), - batch_size=batch_size, - collate_fn=partial(global_batch_collator, batch_size), - ) - - # Build the model, and run it on "IPU" - model = self.TestSimpleLightning(batch_size, node_feat_size, edge_feat_size, num_batch) - - strategy = IPUStrategy( - training_opts=training_opts, inference_opts=inference_opts, autoreport=True - ) - trainer = Trainer( - logger=True, - enable_checkpointing=False, - max_epochs=2, - strategy=strategy, - num_sanity_val_steps=0, - accelerator="ipu", - devices=1, - ) - trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) - - @pytest.mark.skip - def test_poptorch_graphium_deviceiterations_gradient_accumulation_full(self): - """ - Test the device-iterations and gradient accumulation in a way - that is very similar to the Graphium code - to make sure that the dataloader and models handle them correcly. - """ - with patch("poptorch.ipuHardwareIsAvailable", return_value=True): - with patch("lightning_graphcore.accelerator._IPU_AVAILABLE", new=True): - try: - import poptorch - except Exception as e: - warn(f"Skipping this test because poptorch is not available.\n{e}") - return - - from lightning_graphcore import IPUStrategy - import lightning_graphcore - - # Current library imports - from graphium.config._loader import ( - load_datamodule, - load_metrics, - load_architecture, - load_accelerator, - load_predictor, - load_trainer, - ) - from graphium.utils.safe_run import SafeRun - - # Simplified testing config - reflecting the toymix requirements - CONFIG_FILE = "tests/config_test_ipu_dataloader_multitask.yaml" - with open(CONFIG_FILE, "r") as f: - cfg = yaml.safe_load(f) - - cfg, accelerator = load_accelerator(cfg) - - # Load the datamodule, and prepare the data - datamodule = load_datamodule(cfg, accelerator_type=accelerator) - datamodule.prepare_data() - metrics = load_metrics(cfg) - model_class, model_kwargs = load_architecture(cfg, in_dims=datamodule.in_dims) - # datamodule.setup() - predictor = load_predictor( - cfg, - model_class, - model_kwargs, - metrics, - datamodule.get_task_levels(), - accelerator, - datamodule.featurization, - datamodule.task_norms, - ) - assert poptorch.ipuHardwareIsAvailable() - trainer = load_trainer(cfg, "test", accelerator, "date_time_suffix") - # Run the model training - with SafeRun( - name="TRAINING", raise_error=cfg["constants"]["raise_train_error"], verbose=True - ): - trainer.fit(model=predictor, datamodule=datamodule) - - -if __name__ == "__main__": - ut.main() diff --git a/tests/test_ipu_losses.py b/tests/test_ipu_losses.py deleted file mode 100644 index cb18eee47..000000000 --- a/tests/test_ipu_losses.py +++ /dev/null @@ -1,172 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -import unittest as ut -import torch -from torch.nn import BCELoss, MSELoss, L1Loss, BCEWithLogitsLoss -from copy import deepcopy -import pytest - -from graphium.ipu.ipu_losses import BCELossIPU, MSELossIPU, L1LossIPU, BCEWithLogitsLossIPU, HybridCELossIPU -from graphium.trainer.losses import HybridCELoss - - -@pytest.mark.ipu -class test_Losses(ut.TestCase): - torch.manual_seed(42) - preds = torch.rand((100, 10), dtype=torch.float32) - target = torch.rand((100, 10), dtype=torch.float32) - - th = 0.7 - nan_th = 0.2 - preds_greater = preds > th - target_greater = (target > th).to(torch.float32) - target_greater_nan = deepcopy(target_greater) - is_nan = target < nan_th - target_greater_nan[target < nan_th] = torch.nan - target_nan = deepcopy(target) - target_nan[target < nan_th] = torch.nan - - def test_bce(self): - preds = deepcopy(self.preds) - target = deepcopy(self.target_greater) - target_nan = deepcopy(self.target_greater_nan) - - # Regular loss - loss_true = BCELoss()(preds, target) - loss_ipu = BCELossIPU()(preds, target) - self.assertFalse(loss_true.isnan(), "Regular BCELoss is NaN") - self.assertAlmostEqual( - loss_true.item(), loss_ipu.item(), places=6, msg="Regular BCELoss is different" - ) - - # Weighted loss - weight = torch.rand(preds.shape[1], dtype=torch.float32) - loss_true = BCELoss(weight=weight)(preds, target) - loss_ipu = BCELossIPU(weight=weight)(preds, target) - self.assertFalse(loss_true.isnan(), "Regular BCELoss is NaN") - self.assertAlmostEqual(loss_true.item(), loss_ipu.item(), msg="Weighted BCELoss is different") - - # Regular loss with NaNs in target - not_nan = ~target_nan.isnan() - loss_true = BCELoss()(preds[not_nan], target[not_nan]) - loss_ipu = BCELossIPU()(preds, target_nan) - self.assertFalse(loss_true.isnan(), "Regular BCELoss with target_nan is NaN") - self.assertFalse(loss_ipu.isnan(), "Regular BCELossIPU with target_nan is NaN") - self.assertAlmostEqual( - loss_true.item(), loss_ipu.item(), places=6, msg="Regular BCELoss with NaN is different" - ) - - # Weighted loss with NaNs in target - not_nan = ~target_nan.isnan() - weight = torch.rand(preds.shape, dtype=torch.float32) - loss_true = BCELoss(weight=weight[not_nan])(preds[not_nan], target_nan[not_nan]) - loss_ipu = BCELossIPU(weight=weight)(preds, target_nan) - self.assertFalse(loss_true.isnan(), "Weighted BCELoss with target_nan is NaN") - self.assertFalse(loss_ipu.isnan(), "Weighted BCELossIPU with target_nan is NaN") - self.assertAlmostEqual( - loss_true.item(), loss_ipu.item(), places=6, msg="Weighted BCELoss with NaN is different" - ) - - def test_mse(self): - preds = deepcopy(self.preds) - target = deepcopy(self.target) - target_nan = deepcopy(self.target_nan) - - # Regular loss - loss_true = MSELoss()(preds, target) - loss_ipu = MSELossIPU()(preds, target) - self.assertFalse(loss_true.isnan(), "Regular MSELoss is NaN") - self.assertAlmostEqual( - loss_true.item(), loss_ipu.item(), places=6, msg="Regular MSELoss is different" - ) - - # Regular loss with NaNs in target - not_nan = ~target_nan.isnan() - loss_true = MSELoss()(preds[not_nan], target[not_nan]) - loss_ipu = MSELossIPU()(preds, target_nan) - self.assertFalse(loss_true.isnan(), "Regular MSELoss with target_nan is NaN") - self.assertFalse(loss_ipu.isnan(), "Regular MSELossIPU with target_nan is NaN") - self.assertAlmostEqual( - loss_true.item(), loss_ipu.item(), places=6, msg="Regular MSELoss with NaN is different" - ) - - def test_l1(self): - preds = deepcopy(self.preds) - target = deepcopy(self.target) - target_nan = deepcopy(self.target_nan) - - # Regular loss - loss_true = L1Loss()(preds, target) - loss_ipu = L1LossIPU()(preds, target) - self.assertFalse(loss_true.isnan(), "Regular MAELoss is NaN") - self.assertAlmostEqual( - loss_true.item(), loss_ipu.item(), places=6, msg="Regular MAELoss is different" - ) - - # Regular loss with NaNs in target - not_nan = ~target_nan.isnan() - loss_true = L1Loss()(preds[not_nan], target[not_nan]) - loss_ipu = L1LossIPU()(preds, target_nan) - self.assertFalse(loss_true.isnan(), "Regular MAELoss with target_nan is NaN") - self.assertFalse(loss_ipu.isnan(), "Regular MAELossIPU with target_nan is NaN") - self.assertAlmostEqual( - loss_true.item(), loss_ipu.item(), places=6, msg="Regular MAELoss with NaN is different" - ) - - def test_bce_logits(self): - preds = deepcopy(self.preds) - target = deepcopy(self.target_greater) - target_nan = deepcopy(self.target_greater_nan) - - # Regular loss - loss_true = BCEWithLogitsLoss()(preds, target) - loss_ipu = BCEWithLogitsLossIPU()(preds, target) - self.assertFalse(loss_true.isnan(), "Regular BCEWithLogitsLoss is NaN") - self.assertAlmostEqual( - loss_true.item(), loss_ipu.item(), places=6, msg="Regular BCEWithLogitsLoss is different" - ) - - # Weighted loss - weight = torch.rand(preds.shape[1], dtype=torch.float32) - loss_true = BCEWithLogitsLoss(weight=weight)(preds, target) - loss_ipu = BCEWithLogitsLossIPU(weight=weight)(preds, target) - self.assertFalse(loss_true.isnan(), "Regular BCEWithLogitsLoss is NaN") - self.assertAlmostEqual( - loss_true.item(), loss_ipu.item(), msg="Weighted BCEWithLogitsLoss is different" - ) - - # Regular loss with NaNs in target - not_nan = ~target_nan.isnan() - loss_true = BCEWithLogitsLoss()(preds[not_nan], target[not_nan]) - loss_ipu = BCEWithLogitsLossIPU()(preds, target_nan) - self.assertFalse(loss_true.isnan(), "Regular test_bce_logits with target_nan is NaN") - self.assertFalse(loss_ipu.isnan(), "Regular test_bce_logits with target_nan is NaN") - self.assertAlmostEqual( - loss_true.item(), loss_ipu.item(), places=6, msg="Regular BCELoss with NaN is different" - ) - - # Weighted loss with NaNs in target - not_nan = ~target_nan.isnan() - weight = torch.rand(preds.shape, dtype=torch.float32) - loss_true = BCEWithLogitsLoss(weight=weight[not_nan])(preds[not_nan], target_nan[not_nan]) - loss_ipu = BCEWithLogitsLossIPU(weight=weight)(preds, target_nan) - self.assertFalse(loss_true.isnan(), "Weighted test_bce_logits with target_nan is NaN") - self.assertFalse(loss_ipu.isnan(), "Weighted test_bce_logits with target_nan is NaN") - self.assertAlmostEqual( - loss_true.item(), - loss_ipu.item(), - places=6, - msg="Weighted BCEWithLogitsLoss with NaN is different", - ) diff --git a/tests/test_ipu_metrics.py b/tests/test_ipu_metrics.py deleted file mode 100644 index ee4801e7b..000000000 --- a/tests/test_ipu_metrics.py +++ /dev/null @@ -1,774 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -import unittest as ut -import torch -from torchmetrics.functional import ( - auroc, - average_precision, - precision, - accuracy, - recall, - pearson_corrcoef, - spearman_corrcoef, - r2_score, - f1_score, - fbeta_score, - mean_squared_error, - mean_absolute_error, -) -from copy import deepcopy -import pytest - -from graphium.ipu.ipu_metrics import ( - auroc_ipu, - average_precision_ipu, - precision_ipu, - accuracy_ipu, - recall_ipu, - pearson_ipu, - spearman_ipu, - r2_score_ipu, - f1_score_ipu, - fbeta_score_ipu, - mean_squared_error_ipu, - mean_absolute_error_ipu, -) - - -@pytest.mark.ipu -class test_Metrics(ut.TestCase): - torch.manual_seed(42) - preds = torch.rand((100, 10), dtype=torch.float32) - target = torch.rand((100, 10), dtype=torch.float32) - - th = 0.7 - nan_th = 0.2 - preds_greater = preds > th - target_greater = (target > th).to(torch.float32) - target_greater_nan = deepcopy(target_greater) - is_nan = target < nan_th - target_greater_nan[target < nan_th] = torch.nan - target_nan = deepcopy(target) - target_nan[target < nan_th] = torch.nan - - def test_auroc(self): - preds = deepcopy(self.preds)[:, 0] - target = deepcopy(self.target)[:, 0] - target_nan = deepcopy(self.target_nan)[:, 0] - - target[target < 0.5] = 0 - target[target >= 0.5] = 1 - - target_nan[target_nan < 0.5] = 0 - target_nan[target_nan >= 0.5] = 1 - - # Regular loss - score_true = auroc(preds, target.to(int)) - score_ipu = auroc_ipu(preds, target) - self.assertFalse(score_true.isnan(), "Regular AUROC score is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Regular AUROC score is different" - ) - - # Weighted loss (As in BCE) - sample_weights = torch.rand(preds.shape[0], dtype=torch.float32) - score_true = auroc(preds, target.to(int), sample_weights=sample_weights) - score_ipu = auroc_ipu(preds, target, sample_weights=sample_weights) - self.assertFalse(score_true.isnan(), "Regular AUROC score is NaN") - self.assertAlmostEqual(score_true.item(), score_ipu.item(), msg="Weighted AUROC score is different") - - # Regular loss with NaNs in target - not_nan = ~target_nan.isnan() - score_true = auroc(preds[not_nan], target[not_nan].to(int)) - score_ipu = auroc_ipu(preds, target_nan) - self.assertFalse(score_true.isnan(), "Regular AUROC score with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Regular AUROCIPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Regular AUROC score with NaN is different" - ) - - # Weighted loss with NaNs in target (As in BCE) - not_nan = ~target_nan.isnan() - sample_weights = torch.rand(preds.shape, dtype=torch.float32) - loss_true = auroc(preds[not_nan], target_nan[not_nan].to(int), sample_weights=sample_weights[not_nan]) - loss_ipu = auroc_ipu(preds, target_nan, sample_weights=sample_weights) - self.assertFalse(loss_true.isnan(), "Weighted AUROC score with target_nan is NaN") - self.assertFalse(loss_ipu.isnan(), "Weighted AUROC IPU score with target_nan is NaN") - self.assertAlmostEqual( - # AssertionError: 0.6603766679763794 != 0.6234951615333557 within 2 places - loss_true.item(), - loss_ipu.item(), - places=6, - msg="Weighted AUROC with NaN is different", - ) - - def test_average_precision(self): # TODO: Make work with multi-class - preds = deepcopy(self.preds)[:, 0] - target = deepcopy(self.target)[:, 0] - target_nan = deepcopy(self.target_nan)[:, 0] - - target[target < 0.5] = 0 - target[target >= 0.5] = 1 - - target_nan[target_nan < 0.5] = 0 - target_nan[target_nan >= 0.5] = 1 - - # Regular loss - score_true = average_precision(preds, target.to(int), task="binary") - score_ipu = average_precision_ipu(preds, target.to(int), task="binary") - self.assertFalse(score_true.isnan(), "Regular Average Precision is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Regular Average Precision is different" - ) - - # Regular average precision with NaNs in target - not_nan = ~target_nan.isnan() - score_true = average_precision(preds[not_nan], target[not_nan].to(int), task="binary") - score_ipu = average_precision_ipu(preds, target_nan, task="binary") - self.assertFalse(score_true.isnan(), "Regular Average Precision with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Regular Average Precision IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), - score_ipu.item(), - places=6, - msg="Regular Average Precision with NaN is different", - ) - - def test_precision(self): - preds = deepcopy(self.preds)[:, :4] - target = deepcopy(self.target)[:, 0] - t = deepcopy(target) - - target[t < 0.4] = 0 - target[(t >= 0.4) & (t < 0.6)] = 1 - target[(t >= 0.6) & (t < 0.8)] = 2 - target[(t >= 0.8)] = 3 - - target_nan = deepcopy(target) - target_nan[self.is_nan[:, 0]] = float("nan") - target_nan_bin = deepcopy(target_nan) - target_nan_bin[target_nan > 0] = 1 - - # Micro precision binary - score_true = precision(preds[:, 0], target.to(int) > 0, average="micro") - score_ipu = precision_ipu(preds[:, 0], target > 0, average="micro") - self.assertFalse(score_true.isnan(), "Micro Precision binary is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro Precision binary is different" - ) - - # Micro precision binary with NaNs in target - not_nan = ~target_nan.isnan() - score_true = precision(preds[:, 0][not_nan], target_nan_bin[not_nan].to(int), average="micro") - score_ipu = precision_ipu(preds[:, 0], target_nan_bin, average="micro") - self.assertFalse(score_true.isnan(), "Micro Precision binary with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Micro Precision binary IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro Precision with NaN is different" - ) - - # Micro precision - score_true = precision(preds, target.to(int), average="micro") - score_ipu = precision_ipu(preds, target, average="micro") - self.assertFalse(score_true.isnan(), "Micro Precision is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro Precision is different" - ) - - # Micro precision with NaNs in target - not_nan = ~target_nan.isnan() - score_true = precision(preds[not_nan], target[not_nan].to(int), average="micro") - score_ipu = precision_ipu(preds, target_nan, average="micro") - self.assertFalse(score_true.isnan(), "Micro Precision with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Micro Precision IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro Precision with NaN is different" - ) - - # Macro precision - score_true = precision(preds, target.to(int), average="macro", num_classes=4) - score_ipu = precision_ipu(preds, target, average="macro", num_classes=4) - self.assertFalse(score_true.isnan(), "Macro Precision is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Macro Precision is different" - ) - - # Macro precision multiclass with NaNs in target - not_nan = ~target_nan.isnan() - score_true = precision(preds[not_nan], target[not_nan].to(int), average="macro", num_classes=4) - score_ipu = precision_ipu(preds, target_nan, average="macro", num_classes=4) - self.assertFalse(score_true.isnan(), "Macro Precision multiclass with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Macro Precision multiclass IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), - score_ipu.item(), - places=6, - msg="Macro Precision multiclass with NaN is different", - ) - - # Macro precision multiclass with NaNs in target - not_nan = ~target_nan.isnan() - score_true = precision(preds[not_nan], target[not_nan].to(int), average="macro", num_classes=4) - score_ipu = precision_ipu(preds, target_nan, average="macro", num_classes=4) - self.assertFalse(score_true.isnan(), "Macro Precision multiclass with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Macro Precision multiclass IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), - score_ipu.item(), - places=6, - msg="Macro Precision multiclass with NaN is different", - ) - - # Weighted precision multiclass - score_true = precision(preds, target.to(int), average="weighted", num_classes=4) - score_ipu = precision_ipu(preds, target, average="weighted", num_classes=4) - self.assertFalse(score_true.isnan(), "Weighted Precision multiclass is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Weighted Precision multiclass is different" - ) - - # Weighted precision multiclass with NaNs in target - not_nan = ~target_nan.isnan() - score_true = precision(preds[not_nan], target[not_nan].to(int), average="weighted", num_classes=4) - score_ipu = precision_ipu(preds, target_nan, average="weighted", num_classes=4) - self.assertFalse(score_true.isnan(), "Weighted Precision multiclass with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Weighted Precision multiclass IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), - score_ipu.item(), - places=6, - msg="Regular Average Precision multiclass with NaN is different", - ) - - def test_accuracy(self): - preds = deepcopy(self.preds)[:, :4] - target = deepcopy(self.target)[:, 0] - t = deepcopy(target) - - target[t < 0.4] = 0 - target[(t >= 0.4) & (t < 0.6)] = 1 - target[(t >= 0.6) & (t < 0.8)] = 2 - target[(t >= 0.8)] = 3 - - target_nan = deepcopy(target) - target_nan[self.is_nan[:, 0]] = float("nan") - target_nan_bin = deepcopy(target_nan) - target_nan_bin[target_nan > 0] = 1 - - # Micro accuracy binary - score_true = accuracy(preds[:, 0], target.to(int) > 0, average="micro") - score_ipu = accuracy_ipu(preds[:, 0], target > 0, average="micro") - self.assertFalse(score_true.isnan(), "Micro Accuracy binary is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro Accuracy binary is different" - ) - - # Micro accuracy binary with NaNs in target - not_nan = ~target_nan.isnan() - score_true = accuracy(preds[:, 0][not_nan], target_nan_bin[not_nan].to(int), average="micro") - score_ipu = accuracy_ipu(preds[:, 0], target_nan_bin, average="micro") - self.assertFalse(score_true.isnan(), "Micro Accuracy binary with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Micro Accuracy binary IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro Accuracy with NaN is different" - ) - - # Micro accuracy - score_true = accuracy(preds, target.to(int), average="micro") - score_ipu = accuracy_ipu(preds, target, average="micro") - self.assertFalse(score_true.isnan(), "Micro Accuracy is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro Accuracy is different" - ) - - # Micro accuracy with NaNs in target - not_nan = ~target_nan.isnan() - score_true = accuracy(preds[not_nan], target[not_nan].to(int), average="micro") - score_ipu = accuracy_ipu(preds, target_nan, average="micro") - self.assertFalse(score_true.isnan(), "Micro Accuracy with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Micro Accuracy IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro Accuracy with NaN is different" - ) - - # Macro accuracy - score_true = accuracy(preds, target.to(int), average="macro", num_classes=4) - score_ipu = accuracy_ipu(preds, target, average="macro", num_classes=4) - self.assertFalse(score_true.isnan(), "Macro Accuracy is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Macro Accuracy is different" - ) - - # Macro accuracy with NaNs in target - not_nan = ~target_nan.isnan() - score_true = accuracy(preds[not_nan], target[not_nan].to(int), average="macro", num_classes=4) - score_ipu = accuracy_ipu(preds, target_nan, average="macro", num_classes=4) - self.assertFalse(score_true.isnan(), "Macro Accuracy with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Macro Accuracy IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Macro Accuracy with NaN is different" - ) - - # Weighted accuracy - score_true = accuracy(preds, target.to(int), average="weighted", num_classes=4) - score_ipu = accuracy_ipu(preds, target, average="weighted", num_classes=4) - self.assertFalse(score_true.isnan(), "Weighted Accuracy is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Weighted Accuracy is different" - ) - - # Weighted accuracy with NaNs in target - not_nan = ~target_nan.isnan() - score_true = accuracy(preds[not_nan], target[not_nan].to(int), average="weighted", num_classes=4) - score_ipu = accuracy_ipu(preds, target_nan, average="weighted", num_classes=4) - self.assertFalse(score_true.isnan(), "Weighted Accuracy with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Weighted Accuracy IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Regular Accuracy with NaN is different" - ) - - def test_recall(self): - preds = deepcopy(self.preds)[:, :4] - target = deepcopy(self.target)[:, 0] - t = deepcopy(target) - - target[t < 0.4] = 0 - target[(t >= 0.4) & (t < 0.6)] = 1 - target[(t >= 0.6) & (t < 0.8)] = 2 - target[(t >= 0.8)] = 3 - - target_nan = deepcopy(target) - target_nan[self.is_nan[:, 0]] = float("nan") - target_nan_bin = deepcopy(target_nan) - target_nan_bin[target_nan > 0] = 1 - - # Micro recall binary - score_true = recall(preds[:, 0], target.to(int) > 0, average="micro") - score_ipu = recall_ipu(preds[:, 0], target > 0, average="micro") - self.assertFalse(score_true.isnan(), "Micro Recall binary is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro Recall binary is different" - ) - - # Micro recall binary with NaNs in target - not_nan = ~target_nan.isnan() - score_true = recall(preds[:, 0][not_nan], target_nan_bin[not_nan].to(int), average="micro") - score_ipu = recall_ipu(preds[:, 0], target_nan_bin, average="micro") - self.assertFalse(score_true.isnan(), "Micro Recall binary with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Micro Recall binary IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro Recall binary with NaN is different" - ) - - # Micro recall - score_true = recall(preds, target.to(int), average="micro") - score_ipu = recall_ipu(preds, target, average="micro") - self.assertFalse(score_true.isnan(), "Micro Recall is NaN") - self.assertAlmostEqual(score_true.item(), score_ipu.item(), places=6, msg="Micro Recall is different") - - # Micro recall with NaNs in target - not_nan = ~target_nan.isnan() - score_true = recall(preds[not_nan], target[not_nan].to(int), average="micro") - score_ipu = recall_ipu(preds, target_nan, average="micro") - self.assertFalse(score_true.isnan(), "Micro Recall with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Micro Recall IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro Recall with NaN is different" - ) - - # Macro recall multiclass - score_true = recall(preds, target.to(int), average="macro", num_classes=4) - score_ipu = recall_ipu(preds, target, average="macro", num_classes=4) - self.assertFalse(score_true.isnan(), "Macro Recall is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Macro Recall multiclass is different" - ) - - # Macro recall multiclass with NaNs in target - not_nan = ~target_nan.isnan() - score_true = recall(preds[not_nan], target[not_nan].to(int), average="macro", num_classes=4) - score_ipu = recall_ipu(preds, target_nan, average="macro", num_classes=4) - self.assertFalse(score_true.isnan(), "Macro Recall multiclass with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Macro Recall multiclass IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Macro Recall multiclass with NaN is different" - ) - - # Weighted recallmulticlass - score_true = recall(preds, target.to(int), average="weighted", num_classes=4) - score_ipu = recall_ipu(preds, target, average="weighted", num_classes=4) - self.assertFalse(score_true.isnan(), "Weighted Recall is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Weighted Recall is different" - ) - - # Weighted recall multiclass with NaNs in target - not_nan = ~target_nan.isnan() - score_true = recall(preds[not_nan], target[not_nan].to(int), average="weighted", num_classes=4) - score_ipu = recall_ipu(preds, target_nan, average="weighted", num_classes=4) - self.assertFalse(score_true.isnan(), "Weighted Recall multiclass with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Weighted Recall multiclass IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), - score_ipu.item(), - places=6, - msg="Regular Recall multiclass with NaN is different", - ) - - def test_pearsonr(self): - preds = deepcopy(self.preds)[:, 0] - target = deepcopy(self.target)[:, 0] + preds - target_nan = deepcopy(target) - target_nan[self.is_nan[:, 0]] = float("nan") - - # Regular loss - score_true = pearson_corrcoef(preds, target) - score_ipu = pearson_ipu(preds, target) - self.assertFalse(score_true.isnan(), "Pearson is NaN") - self.assertAlmostEqual(score_true.item(), score_ipu.item(), places=4, msg="Pearson is different") - - # Regular loss with NaNs in target - not_nan = ~target_nan.isnan() - score_true = pearson_corrcoef(preds[not_nan], target[not_nan]) - score_ipu = pearson_ipu(preds, target_nan) - self.assertFalse(score_true.isnan(), "Regular PearsonR with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "IPU PearsonR score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=4, msg="Pearson with NaN is different" - ) - - def test_spearmanr(self): - preds = deepcopy(self.preds)[:, 0] - target = deepcopy(self.target)[:, 0] + preds - target_nan = deepcopy(target) - target_nan[self.is_nan[:, 0]] = float("nan") - - # Regular loss - score_true = spearman_corrcoef(preds, target) - score_ipu = spearman_ipu(preds, target) - self.assertFalse(score_true.isnan(), "Spearman is NaN") - self.assertAlmostEqual(score_true.item(), score_ipu.item(), places=4, msg="Spearman is different") - - # Regular loss with NaNs in target - not_nan = ~target_nan.isnan() - score_true = spearman_corrcoef(preds[not_nan], target[not_nan]) - score_ipu = spearman_ipu(preds, target_nan) - self.assertFalse(score_true.isnan(), "Regular Spearman with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "IPU Spearman score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=4, msg="Spearman with NaN is different" - ) - - def test_r2_score(self): - preds = deepcopy(self.preds) - target = deepcopy(self.target) + preds - target_nan = deepcopy(target) - target_nan[self.is_nan] = float("nan") - - # Regular loss - score_true = r2_score(preds, target) - score_ipu = r2_score_ipu(preds, target) - self.assertFalse(score_true.isnan(), "r2_score is NaN") - self.assertAlmostEqual(score_true.item(), score_ipu.item(), places=4, msg="r2_score is different") - - # Regular loss with NaNs in target - not_nan = ~target_nan.isnan() - score_ipu = r2_score_ipu(preds, target_nan, multioutput="raw_values") - for ii in range(preds.shape[1]): - score_true = r2_score( - preds[:, ii][not_nan[:, ii]], target_nan[:, ii][not_nan[:, ii]], multioutput="raw_values" - ) - self.assertFalse(score_true.isnan().any(), f"{ii}: r2_score with target_nan is NaN") - self.assertFalse(score_ipu[ii].isnan().any(), f"{ii}: IPU r2_score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu[ii].item(), places=4, msg=f"{ii}: r2_score with NaN is different" - ) - - def test_fbeta_score(self): - preds = deepcopy(self.preds)[:, :4] - target = deepcopy(self.target)[:, 0] - t = deepcopy(target) - - target[t < 0.4] = 0 - target[(t >= 0.4) & (t < 0.6)] = 1 - target[(t >= 0.6) & (t < 0.8)] = 2 - target[(t >= 0.8)] = 3 - - target_nan = deepcopy(target) - target_nan[self.is_nan[:, 0]] = float("nan") - target_nan_bin = deepcopy(target_nan) - target_nan_bin[target_nan > 0] = 1 - - # Micro fbeta_score binary - score_true = fbeta_score(preds[:, 0], target.to(int) > 0, average="micro", beta=0.5) - score_ipu = fbeta_score_ipu(preds[:, 0], target > 0, average="micro", beta=0.5) - self.assertFalse(score_true.isnan(), "Micro FBETA_score binary is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro FBETA_score binary is different" - ) - - # Micro fbeta_score binary with NaNs in target - not_nan = ~target_nan.isnan() - score_true = fbeta_score( - preds[:, 0][not_nan], target_nan_bin[not_nan].to(int), average="micro", beta=0.5 - ) - score_ipu = fbeta_score_ipu(preds[:, 0], target_nan_bin, average="micro", beta=0.5) - self.assertFalse(score_true.isnan(), "Micro FBETA_score binary with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Micro FBETA_score binary IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), - score_ipu.item(), - places=6, - msg="Micro FBETA_score binary with NaN is different", - ) - - # Micro fbeta_score - score_true = fbeta_score(preds, target.to(int), average="micro", beta=0.5) - score_ipu = fbeta_score_ipu(preds, target, average="micro", beta=0.5) - self.assertFalse(score_true.isnan(), "Micro FBETA_score is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro FBETA_score is different" - ) - - # Micro fbeta_score with NaNs in target - not_nan = ~target_nan.isnan() - score_true = fbeta_score(preds[not_nan], target[not_nan].to(int), average="micro", beta=0.5) - score_ipu = fbeta_score_ipu(preds, target_nan, average="micro", beta=0.5) - self.assertFalse(score_true.isnan(), "Micro FBETA_score with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Micro FBETA_score IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro FBETA_score with NaN is different" - ) - - # Macro fbeta_score multiclass - score_true = fbeta_score(preds, target.to(int), average="macro", num_classes=4, beta=0.5) - score_ipu = fbeta_score_ipu(preds, target, average="macro", num_classes=4, beta=0.5) - self.assertFalse(score_true.isnan(), "Macro FBETA_score is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Macro FBETA_score multiclass is different" - ) - - # Macro fbeta_score multiclass with NaNs in target - not_nan = ~target_nan.isnan() - score_true = fbeta_score( - preds[not_nan], target[not_nan].to(int), average="macro", num_classes=4, beta=0.5 - ) - score_ipu = fbeta_score_ipu(preds, target_nan, average="macro", num_classes=4, beta=0.5) - self.assertFalse(score_true.isnan(), "Macro FBETA_score multiclass with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Macro FBETA_score multiclass IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), - score_ipu.item(), - places=6, - msg="Macro FBETA_score multiclass with NaN is different", - ) - - # Weighted fbeta_scoremulticlass - score_true = fbeta_score(preds, target.to(int), average="weighted", num_classes=4, beta=0.5) - score_ipu = fbeta_score_ipu(preds, target, average="weighted", num_classes=4, beta=0.5) - self.assertFalse(score_true.isnan(), "Weighted FBETA_score is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Weighted FBETA_score is different" - ) - - # Weighted fbeta_score multiclass with NaNs in target - not_nan = ~target_nan.isnan() - score_true = fbeta_score( - preds[not_nan], target[not_nan].to(int), average="weighted", num_classes=4, beta=0.5 - ) - score_ipu = fbeta_score_ipu(preds, target_nan, average="weighted", num_classes=4, beta=0.5) - self.assertFalse(score_true.isnan(), "Weighted FBETA_score multiclass with target_nan is NaN") - self.assertFalse( - score_ipu.isnan(), "Weighted FBETA_score multiclass IPU score with target_nan is NaN" - ) - self.assertAlmostEqual( - score_true.item(), - score_ipu.item(), - places=6, - msg="Regular FBETA_score multiclass with NaN is different", - ) - - def test_f1_score(self): - preds = deepcopy(self.preds)[:, :4] - target = deepcopy(self.target)[:, 0] - t = deepcopy(target) - - target[t < 0.4] = 0 - target[(t >= 0.4) & (t < 0.6)] = 1 - target[(t >= 0.6) & (t < 0.8)] = 2 - target[(t >= 0.8)] = 3 - - target_nan = deepcopy(target) - target_nan[self.is_nan[:, 0]] = float("nan") - target_nan_bin = deepcopy(target_nan) - target_nan_bin[target_nan > 0] = 1 - - # Micro f1_score binary - score_true = f1_score(preds[:, 0], target.to(int) > 0, average="micro") - score_ipu = f1_score_ipu(preds[:, 0], target > 0, average="micro") - self.assertFalse(score_true.isnan(), "Micro F1_score binary is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro F1_score binary is different" - ) - - # Micro f1_score binary with NaNs in target - not_nan = ~target_nan.isnan() - score_true = f1_score(preds[:, 0][not_nan], target_nan_bin[not_nan].to(int), average="micro") - score_ipu = f1_score_ipu(preds[:, 0], target_nan_bin, average="micro") - self.assertFalse(score_true.isnan(), "Micro F1_score binary with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Micro F1_score binary IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro F1_score binary with NaN is different" - ) - - # Micro f1_score - score_true = f1_score(preds, target.to(int), average="micro") - score_ipu = f1_score_ipu(preds, target, average="micro") - self.assertFalse(score_true.isnan(), "Micro F1_score is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro F1_score is different" - ) - - # Micro f1_score with NaNs in target - not_nan = ~target_nan.isnan() - score_true = f1_score(preds[not_nan], target[not_nan].to(int), average="micro") - score_ipu = f1_score_ipu(preds, target_nan, average="micro") - self.assertFalse(score_true.isnan(), "Micro F1_score with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Micro F1_score IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Micro F1_score with NaN is different" - ) - - # Macro f1_score multiclass - score_true = f1_score(preds, target.to(int), average="macro", num_classes=4) - score_ipu = f1_score_ipu(preds, target, average="macro", num_classes=4) - self.assertFalse(score_true.isnan(), "Macro F1_score is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Macro F1_score multiclass is different" - ) - - # Macro f1_score multiclass with NaNs in target - not_nan = ~target_nan.isnan() - score_true = f1_score(preds[not_nan], target[not_nan].to(int), average="macro", num_classes=4) - score_ipu = f1_score_ipu(preds, target_nan, average="macro", num_classes=4) - self.assertFalse(score_true.isnan(), "Macro F1_score multiclass with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Macro F1_score multiclass IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), - score_ipu.item(), - places=6, - msg="Macro F1_score multiclass with NaN is different", - ) - - # Weighted f1_scoremulticlass - score_true = f1_score(preds, target.to(int), average="weighted", num_classes=4) - score_ipu = f1_score_ipu(preds, target, average="weighted", num_classes=4) - self.assertFalse(score_true.isnan(), "Weighted F1_score is NaN") - self.assertAlmostEqual( - score_true.item(), score_ipu.item(), places=6, msg="Weighted F1_score is different" - ) - - # Weighted f1_score multiclass with NaNs in target - not_nan = ~target_nan.isnan() - score_true = f1_score(preds[not_nan], target[not_nan].to(int), average="weighted", num_classes=4) - score_ipu = f1_score_ipu(preds, target_nan, average="weighted", num_classes=4) - self.assertFalse(score_true.isnan(), "Weighted F1_score multiclass with target_nan is NaN") - self.assertFalse(score_ipu.isnan(), "Weighted F1_score multiclass IPU score with target_nan is NaN") - self.assertAlmostEqual( - score_true.item(), - score_ipu.item(), - places=6, - msg="Regular F1_score multiclass with NaN is different", - ) - - def test_mse(self): - preds = deepcopy(self.preds) - target = deepcopy(self.target) - target_nan = deepcopy(self.target_nan) - squared = True - - # Regular loss - loss_true = mean_squared_error(preds, target, squared) - loss_ipu = mean_squared_error_ipu(preds=preds, target=target, squared=squared) - self.assertFalse(loss_true.isnan(), "Regular Mean Squared Error is NaN") - self.assertAlmostEqual( - loss_true.item(), loss_ipu.item(), places=6, msg="Regular Mean Squared Error is different" - ) - - # Regular loss with NaNs in target - not_nan = ~target_nan.isnan() - loss_true = mean_squared_error(preds[not_nan], target[not_nan], squared) - loss_ipu = mean_squared_error_ipu(preds=preds, target=target_nan, squared=squared) - self.assertFalse(loss_true.isnan(), "Regular Mean Squared Error with target_nan is NaN") - self.assertFalse(loss_ipu.isnan(), "Regular Mean Squared Error IPU with target_nan is NaN") - self.assertAlmostEqual( - loss_true.item(), - loss_ipu.item(), - places=6, - msg="Regular Mean Squared Error with NaN is different", - ) - - squared = False - - # Regular loss - loss_true = mean_squared_error(preds, target, squared) - loss_ipu = mean_squared_error_ipu(preds=preds, target=target, squared=squared) - self.assertFalse(loss_true.isnan(), "Regular Mean Squared Error is NaN") - self.assertAlmostEqual( - loss_true.item(), loss_ipu.item(), places=6, msg="Regular Mean Squared Error is different" - ) - - # Regular loss with NaNs in target - not_nan = ~target_nan.isnan() - loss_true = mean_squared_error(preds[not_nan], target[not_nan], squared) - loss_ipu = mean_squared_error_ipu(preds=preds, target=target_nan, squared=squared) - self.assertFalse(loss_true.isnan(), "Regular Mean Squared Error with target_nan is NaN") - self.assertFalse(loss_ipu.isnan(), "Regular Mean Squared Error IPU with target_nan is NaN") - self.assertAlmostEqual( - loss_true.item(), - loss_ipu.item(), - places=6, - msg="Regular Mean Squared Error with NaN is different", - ) - - def test_mae(self): - preds = deepcopy(self.preds) - target = deepcopy(self.target) - target_nan = deepcopy(self.target_nan) - - # Regular loss - loss_true = mean_absolute_error(preds, target) - loss_ipu = mean_absolute_error_ipu(preds=preds, target=target) - self.assertFalse(loss_true.isnan(), "Regular Mean Absolute Error is NaN") - self.assertAlmostEqual( - loss_true.item(), loss_ipu.item(), places=6, msg="Regular Mean Absolute Error is different" - ) - - # Regular loss with NaNs in target - not_nan = ~target_nan.isnan() - loss_true = mean_absolute_error(preds[not_nan], target[not_nan]) - loss_ipu = mean_absolute_error_ipu(preds=preds, target=target_nan) - self.assertFalse(loss_true.isnan(), "Regular Mean Absolute Error with target_nan is NaN") - self.assertFalse(loss_ipu.isnan(), "Regular Mean Absolute Error IPU with target_nan is NaN") - self.assertAlmostEqual( - loss_true.item(), - loss_ipu.item(), - places=6, - msg="Regular Mean Absolute Error with NaN is different", - ) - - -if __name__ == "__main__": - ut.main() diff --git a/tests/test_ipu_options.py b/tests/test_ipu_options.py deleted file mode 100644 index c3cc9aa3e..000000000 --- a/tests/test_ipu_options.py +++ /dev/null @@ -1,149 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -import pytest -from graphium.config._loader import _get_ipu_opts, load_ipu_options -from graphium.ipu.ipu_utils import ipu_options_list_to_file - -import tempfile -from typing import Optional, List -import os - -CONFIG_EXTRACT = { - "trainer": {"trainer": {"accumulate_grad_batches": 10}}, - "accelerator": { - "type": "ipu", - "config_override": { - "datamodule": { - "args": { - "ipu_dataloader_training_opts": { - "mode": "async", - "max_num_nodes_per_graph": 44, - "max_num_edges_per_graph": 80, - }, - "ipu_dataloader_inference_opts": { - "mode": "async", - "max_num_nodes_per_graph": 44, - "max_num_edges_per_graph": 80, - }, - "batch_size_training": 50, - "batch_size_inference": 50, - } - }, - "predictor": {"optim_kwargs": {"loss_scaling": 1024}}, - "trainer": {"trainer": {"precision": 16, "accumulate_grad_batches": 4}}, - }, - "ipu_config": [ - "deviceIterations(5)", - "replicationFactor(16)", - "TensorLocations.numIOTiles(128)", - '_Popart.set("defaultBufferingDepth", 128)', - "Precision.enableStochasticRounding(True)", - ], - "ipu_inference_config": [ - "deviceIterations(1)", - "replicationFactor(4)", - "TensorLocations.numIOTiles(32)", - '_Popart.set("defaultBufferingDepth", 16)', - "Precision.enableStochasticRounding(True)", - ], - }, -} - - -@pytest.mark.ipu -def test_ipu_options(): - try: - import poptorch - - ipu_opts, ipu_inference_opts = _get_ipu_opts(CONFIG_EXTRACT) - - # Define the expected IPU options for comparison - expected_ipu_opts = [ - "deviceIterations(5)", - "replicationFactor(16)", - "TensorLocations.numIOTiles(128)", - '_Popart.set("defaultBufferingDepth", 128)', - "Precision.enableStochasticRounding(True)", - ] - expected_ipu_inference_opts = [ - "deviceIterations(1)", - "replicationFactor(4)", - "TensorLocations.numIOTiles(32)", - '_Popart.set("defaultBufferingDepth", 16)', - "Precision.enableStochasticRounding(True)", - ] - - # Test the _get_ipu_opts method - ipu_opts, ipu_inference_opts = _get_ipu_opts(CONFIG_EXTRACT) - assert ipu_opts == expected_ipu_opts, f"Expected {expected_ipu_opts}, but got {ipu_opts}" - assert ( - ipu_inference_opts == expected_ipu_inference_opts - ), f"Expected {expected_ipu_inference_opts}, but got {ipu_inference_opts}" - - # Test the load_ipu_options method - ipu_training_opts, ipu_inference_opts = load_ipu_options( - ipu_opts=ipu_opts, - seed=42, - model_name="test_model", - gradient_accumulation=CONFIG_EXTRACT["trainer"]["trainer"].get("accumulate_grad_batches", None), - ipu_inference_opts=ipu_inference_opts, - ) - - # Ensure that the options objects are not None - assert ipu_training_opts is not None, "Expected ipu_training_opts not to be None" - assert ipu_inference_opts is not None, "Expected ipu_inference_opts not to be None" - - # Test the properties of the options objects - assert ( - ipu_training_opts.replication_factor == 16 - ), "Expected replication_factor of ipu_training_opts to be 16" - assert ( - ipu_inference_opts.replication_factor == 4 - ), "Expected replication_factor of ipu_inference_opts to be 4" - assert ipu_training_opts._popart, "Expected _popart of ipu_training_opts to be True" - assert ipu_inference_opts._popart, "Expected _popart of ipu_inference_opts to be True" - - except ImportError: - pytest.skip("Skipping this test because poptorch is not available") - - -@pytest.mark.ipu -def test_ipu_options_list_to_file(): - # Define a list of IPU options - ipu_options = [ - "deviceIterations(5)", - "replicationFactor(16)", - "TensorLocations.numIOTiles(128)", - '_Popart.set("defaultBufferingDepth", 128)', - "Precision.enableStochasticRounding(True)", - ] - - # Call the function with the list of IPU options - tmp_file = ipu_options_list_to_file(ipu_options) - - # Check that the function returns a temporary file object - assert isinstance(tmp_file, tempfile._TemporaryFileWrapper) - - # Check that the temporary file exists - assert os.path.exists(tmp_file.name) - - # Check the contents of the temporary file - with open(tmp_file.name, "r") as f: - contents = f.read().splitlines() - assert contents == ipu_options - - # Check the behavior when the input is None - tmp_file = ipu_options_list_to_file(None) - assert tmp_file is None diff --git a/tests/test_ipu_poptorch.py b/tests/test_ipu_poptorch.py deleted file mode 100644 index 4f951d504..000000000 --- a/tests/test_ipu_poptorch.py +++ /dev/null @@ -1,29 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -import pytest - - -@pytest.mark.ipu -def test_poptorch(): - # Run this test only if poptorch is available - # Primarily to test the install and SDK is correctly activated - try: - import poptorch - - opts = poptorch.Options() - - except ImportError: - raise ImportError - assert True diff --git a/tests/test_ipu_to_dense_batch.py b/tests/test_ipu_to_dense_batch.py deleted file mode 100644 index 55c6e3372..000000000 --- a/tests/test_ipu_to_dense_batch.py +++ /dev/null @@ -1,146 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -import pytest -import torch -from torch_geometric.data import Data, Batch -from graphium.ipu.to_dense_batch import to_dense_batch -from warnings import warn - - -# General imports -import yaml -import unittest as ut -import numpy as np -from copy import deepcopy -from warnings import warn -from lightning import Trainer, LightningModule -from lightning_graphcore import IPUStrategy -from functools import partial - -import torch -from torch.utils.data.dataloader import default_collate - -# Current library imports -from graphium.config._loader import load_datamodule, load_metrics, load_architecture, load_accelerator - - -@pytest.mark.ipu -class TestIPUBatch: - @pytest.fixture(autouse=True) - def setup_class(self): - self.in_dim = 12 - self.out_dim = 12 - self.in_dim_edges = 10 - self.out_dim_edges = 10 - self.edge_idx1 = torch.stack( - [torch.tensor([0, 1, 2, 3, 2], dtype=torch.int), torch.tensor([1, 2, 3, 0, 0], dtype=torch.int)] - ) - self.edge_idx2 = torch.stack( - [torch.tensor([0, 0, 0, 1], dtype=torch.int), torch.tensor([0, 1, 2, 0], dtype=torch.int)] - ) - self.x1 = torch.randn(self.edge_idx1.max().item() + 1, self.in_dim, dtype=torch.float32) - self.e1 = torch.randn(self.edge_idx1.shape[-1], self.in_dim_edges, dtype=torch.float32) - self.x2 = torch.randn(self.edge_idx2.max().item() + 1, self.in_dim, dtype=torch.float32) - self.e2 = torch.randn(self.edge_idx2.shape[-1], self.in_dim_edges, dtype=torch.float32) - self.g1 = Data(feat=self.x1, edge_index=self.edge_idx1, edge_feat=self.e1) - self.g2 = Data(feat=self.x2, edge_index=self.edge_idx2, edge_feat=self.e2) - self.bg = Batch.from_data_list([self.g1, self.g2]) - self.attn_kwargs = {"embed_dim": self.in_dim, "num_heads": 2, "batch_first": True} - - # @pytest.mark.skip - @pytest.mark.parametrize("max_num_nodes_per_graph, batch_size", [(10, 5), (20, 10), (30, 15)]) - def test_ipu_to_dense_batch(self, max_num_nodes_per_graph, batch_size): - # Run this test only if poptorch is available - try: - import poptorch - - opts = poptorch.Options() - opts.useIpuModel(True) - - class MyModel(torch.nn.Module): - def __init__(self): - super(MyModel, self).__init__() - - def forward(self, x, batch): - return to_dense_batch( - x, - batch=batch, - batch_size=batch_size, - max_num_nodes_per_graph=max_num_nodes_per_graph, - drop_nodes_last_graph=False, - ) - - model = MyModel() - model = model.eval() - poptorch_model_inf = poptorch.inferenceModel(model, options=opts) - # for data in train_dataloader: - out, mask, idx = poptorch_model_inf(self.bg.feat, self.bg.batch) - # Check the output sizes - assert out.size() == torch.Size([batch_size, max_num_nodes_per_graph, 12]) - # Check the mask for true / false values - assert mask.size() == torch.Size([batch_size, max_num_nodes_per_graph]) - assert torch.sum(mask) == 7 - assert (mask[0][:4] == True).all() - assert (mask[0][4:] == False).all() - assert (mask[1][:3] == True).all() - assert (mask[1][3:] == False).all() - assert (mask[2:] == False).all() - - # Check the idx are all the true values in the mask - assert (mask.flatten()[idx] == True).all() - poptorch_model_inf.detachFromDevice() - except ImportError: - pytest.skip("Skipping this test because poptorch is not available") - - def test_ipu_to_dense_batch_no_batch_no_max_nodes(self): - h_dense, mask = to_dense_batch( - self.bg.feat, - batch=None, - batch_size=None, - max_num_nodes_per_graph=None, - drop_nodes_last_graph=False, - ) - # Add assertions to check the output as needed - assert torch.allclose(h_dense, self.bg.feat.unsqueeze(0), atol=1e-5), "Tensors are not equal" - assert mask.size(1) == h_dense.size(1) - assert mask.all().item(), "Not all values in the tensor are True" - - def test_ipu_to_dense_batch_no_batch(self): - max_nodes_per_graph = 10 - h_dense, mask, id = to_dense_batch( - self.bg.feat, - batch=None, - batch_size=None, - max_num_nodes_per_graph=max_nodes_per_graph, - drop_nodes_last_graph=False, - ) - assert mask.size() == (1, max_nodes_per_graph) - assert torch.sum(mask) == 7 - assert torch.equal(id, torch.arange(7)) - assert h_dense.size() == (1, max_nodes_per_graph, self.bg.feat.size(-1)) - - def test_ipu_to_dense_batch_drop_last(self): - out, mask, idx = to_dense_batch( - self.bg.feat, - batch=None, - batch_size=None, - max_num_nodes_per_graph=3, - drop_nodes_last_graph=True, - ) - # Add assertions to check the output as needed - assert mask.size(1) == out.size(1) - # Check the mask and output have been clipped - assert mask.size() == torch.Size([1, 3]) - assert mask.all().item(), "Not all values in the tensor are True" diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 22611f32f..d4bfbc993 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -17,7 +17,7 @@ import unittest as ut -class TestLoader(ut.TestCase): +class test_Loader(ut.TestCase): def test_merge_dicts(self): dict_a = {"a": {"b": {"c": 1, "d": 2}, "e": 3}, "f": 4} diff --git a/tests/test_losses.py b/tests/test_losses.py index b2f343bf9..82e8090ea 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -26,7 +26,8 @@ def _parse(loss_fun): eval_options = EvalOptions(loss_fun=loss_fun, metrics_on_progress_bar=None) - return eval_options.parse_loss_fun(loss_fun) + loss_name, loss_fun = eval_options.parse_loss_fun(loss_fun) + return loss_fun class test_HybridCELoss(ut.TestCase): diff --git a/tests/test_metrics.py b/tests/test_metrics.py index dc5bc01b2..ba36e8d06 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -16,6 +16,8 @@ Unit tests for the metrics and wrappers of graphium/trainer/metrics/... """ +import pytest + import torch import unittest as ut import tempfile @@ -27,7 +29,8 @@ Thresholder, ) -from torchmetrics.functional import mean_squared_error +from torchmetrics.functional import mean_squared_error, pearson_corrcoef, auroc +from torchmetrics import MeanSquaredError class test_Metrics(ut.TestCase): @@ -142,12 +145,10 @@ def test_target_nan_mask(self): def test_pickling(self): pickle_file = os.path.join(tempfile.gettempdir(), "test_metric_pickled.pkl") - metrics = ["mae", "mse", mean_squared_error] + metrics = ["mae", "mse", MeanSquaredError] target_nan_masks = [None, 2, "ignore"] multitask_handlings = [None, "flatten", "mean-per-label"] - squeeze_targets = [True, False] target_to_ints = [True, False] - other_kwargs = [{}, {"squared": False}] thresholds = [ None, {"threshold": 0.2, "operator": "greater"}, @@ -159,73 +160,136 @@ def test_pickling(self): for metric in metrics: for target_nan_mask in target_nan_masks: - for kwargs in other_kwargs: - for threshold_kwargs in thresholds: - for multitask_handling in multitask_handlings: - for squeeze_target in squeeze_targets: - for target_to_int in target_to_ints: - err_msg = f"{metric} - {target_nan_mask} - {kwargs} - {threshold_kwargs}" - - if (multitask_handling is None) and (target_nan_mask == "ignore"): - # Raise with incompatible options - with self.assertRaises(ValueError): - MetricWrapper( - metric=metric, - threshold_kwargs=threshold_kwargs, - target_nan_mask=target_nan_mask, - multitask_handling=multitask_handling, - squeeze_target=squeeze_target, - target_to_int=target_to_int, - **kwargs, - ) - - else: - metric_wrapper = MetricWrapper( - metric=metric, - threshold_kwargs=threshold_kwargs, - target_nan_mask=target_nan_mask, - multitask_handling=multitask_handling, - squeeze_target=squeeze_target, - target_to_int=target_to_int, - **kwargs, - ) - - # Check that the metric can be saved and re-loaded without error - torch.save(metric_wrapper, pickle_file) - metric_wrapper2 = torch.load(pickle_file) - self.assertTrue(metric_wrapper == metric_wrapper2, msg=err_msg) - - # Check that the metric only contains primitive types - state = metric_wrapper.__getstate__() - if state["threshold_kwargs"] is not None: - self.assertIsInstance( - state["threshold_kwargs"], dict, msg=err_msg - ) - if isinstance(metric, str): - self.assertIsInstance(state["metric"], str, msg=err_msg) - - def test_classifigression_target_squeezing(self): - preds = torch.Tensor([[0.1, 0.1, 0.3, 0.5, 0.0, 0.1, 0.0, 0.7, 0.2, 0.0]]) - target = torch.Tensor([3, 0]) - expected_scores = [0.5, 0.75] - n_brackets = 5 - metrics = ["accuracy", "averageprecision"] - other_kwargs = [ - {"task": "multiclass", "num_classes": n_brackets, "top_k": 1}, - {"task": "multiclass", "num_classes": n_brackets}, - ] + for threshold_kwargs in thresholds: + for multitask_handling in multitask_handlings: + for target_to_int in target_to_ints: + err_msg = f"{metric} - {target_nan_mask} - {threshold_kwargs}" + + if (multitask_handling is None) and (target_nan_mask == "ignore"): + # Raise with incompatible options + with self.assertRaises(ValueError): + MetricWrapper( + metric=metric, + threshold_kwargs=threshold_kwargs, + target_nan_mask=target_nan_mask, + multitask_handling=multitask_handling, + target_to_int=target_to_int, + ) + + else: + metric_wrapper = MetricWrapper( + metric=metric, + threshold_kwargs=threshold_kwargs, + target_nan_mask=target_nan_mask, + multitask_handling=multitask_handling, + target_to_int=target_to_int, + ) + + # Check that the metric can be saved and re-loaded without error + torch.save(metric_wrapper, pickle_file) + metric_wrapper2 = torch.load(pickle_file) + self.assertTrue(metric_wrapper == metric_wrapper2, msg=err_msg) + + # Check that the metric only contains primitive types + state = metric_wrapper.__getstate__() + if state["threshold_kwargs"] is not None: + self.assertIsInstance( + state["threshold_kwargs"], dict, msg=err_msg + ) + if isinstance(metric, str): + self.assertIsInstance(state["metric"], str, msg=err_msg) + + + def test_update_compute_reset(self): + pytest.skip("Will be obsolete once torchmetrics are updated") + + torch.manual_seed(42) + th = 0.7 + + # ---------- ACCURACY ---------- + metric = MetricWrapper( + metric="accuracy", threshold_kwargs={"threshold": th, "operator": "greater"}, task="binary", + ) + for batch_size in [1, 5, 25, 100]: + # Generate random predictions and targets, and compute the true accuracy + preds = torch.rand(100, dtype=torch.float32) + target = torch.rand(100, dtype=torch.float32) + preds_greater = preds > th + target_greater = target > th + true_accuracy = (preds_greater == target_greater).float().mean() + + # Test the reset, update and compute + metric.reset() + for ii in range(0, 100, batch_size): + preds_batch = preds[ii : ii + batch_size] + target_batch = target_greater[ii : ii + batch_size] + metric.update(preds_batch, target_batch) + + self.assertAlmostEqual(metric.compute(), true_accuracy, places=5, msg=f"Error for batch_size={batch_size}") + + # ---------- PEARSONR ---------- + metric = MetricWrapper( + metric="pearsonr", + ) + for batch_size in [1, 5, 25, 100]: + # Generate random predictions and targets, and compute the true pearsonr + preds = torch.rand(100, dtype=torch.float32) + target = torch.rand(100, dtype=torch.float32) + true_pearson = pearson_corrcoef(preds, target) - for metric, kwargs, expected_score in zip(metrics, other_kwargs, expected_scores): - metric_wrapper = MetricWrapper( - metric=metric, + # Test the reset, update and compute with pearsonr + metric.reset() + for ii in range(0, 100, batch_size): + preds_batch = preds[ii : ii + batch_size] + target_batch = target[ii : ii + batch_size] + metric.update(preds_batch, target_batch) + + self.assertAlmostEqual(metric.compute().numpy(), true_pearson.numpy(), places=5, msg=f"Error for batch_size={batch_size}") + + + # ---------- PEARSONR with mean-per-label ---------- + + metric = MetricWrapper( + metric="pearsonr", multitask_handling="mean-per-label", - squeeze_targets=True, + ) + for batch_size in [1, 5, 25, 100]: + # Generate random predictions and targets, and compute the true pearsonr + preds = torch.rand(100, 10, dtype=torch.float32) + target = torch.rand(100, 10, dtype=torch.float32) + true_pearson = pearson_corrcoef(preds, target).mean().numpy() + + # Test the pearson reset, update and compute with mean-per-label + metric.reset() + for ii in range(0, 100, batch_size): + preds_batch = preds[ii : ii + batch_size] + target_batch = target[ii : ii + batch_size] + metric.update(preds_batch, target_batch) + + self.assertAlmostEqual(metric.compute().numpy(), true_pearson, places=5, msg=f"Error for batch_size={batch_size}") + + # ---------- AUROC with mean-per-label ---------- + metric = MetricWrapper( + metric="auroc", target_to_int=True, - **kwargs, + multitask_handling="mean-per-label", + task="binary", ) - score = metric_wrapper(preds, target) - - assert score == expected_score + for batch_size in [1, 5, 25, 100]: + # Generate random predictions and targets, and compute the true auroc + preds = torch.rand(100, 10, dtype=torch.float32) + target = (0.5*preds + 0.5*torch.rand(100, 10, dtype=torch.float32)) > th + true_auroc = torch.stack([auroc(preds[:, ii], target[:, ii], task="binary") for ii in range(preds.shape[1])]).mean().numpy() + + # Test the auroc reset, update and compute with mean-per-label + metric.reset() + for ii in range(0, 100, batch_size): + preds_batch = preds[ii : ii + batch_size] + target_batch = target[ii : ii + batch_size] + metric.update(preds_batch, target_batch) + + self.assertAlmostEqual(metric.compute().numpy(), true_auroc, places=5, msg=f"Error for batch_size={batch_size}") + if __name__ == "__main__": diff --git a/tests/test_multitask_datamodule.py b/tests/test_multitask_datamodule.py index b8d2119e1..664334561 100644 --- a/tests/test_multitask_datamodule.py +++ b/tests/test_multitask_datamodule.py @@ -22,8 +22,10 @@ import numpy as np import graphium +TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000" -class Test_Multitask_DataModule(ut.TestCase): + +class test_Multitask_DataModule(ut.TestCase): def setUp(self): # Create a temporary directory self.tmp_test_dir = tempfile.mkdtemp() @@ -109,12 +111,9 @@ def test_multitask_fromsmiles_dm( # Task-independent arguments dm_args["featurization"] = featurization_args - dm_args["featurization_n_jobs"] = 16 - dm_args["featurization_progress"] = True - dm_args["featurization_backend"] = "loky" dm_args["num_workers"] = 0 dm_args["pin_memory"] = True - dm_args["processed_graph_data_path"] = None + dm_args["processed_graph_data_path"] = TEMP_CACHE_DATA_PATH dm_args["batch_size_training"] = 16 dm_args["batch_size_inference"] = 16 @@ -175,6 +174,8 @@ def test_multitask_fromsmiles_from_config(self): dm_args["task_specific_args"]["logp"]["df_path"] = None dm_args["task_specific_args"]["score"]["df_path"] = None + dm_args["processed_graph_data_path"] = TEMP_CACHE_DATA_PATH + dm = graphium.data.MultitaskFromSmilesDataModule(**dm_args) # assert dm.num_node_feats == 50 @@ -205,6 +206,7 @@ def test_multitask_fromsmiles_from_config_csv(self): config = graphium.load_config(name="zinc_default_multitask_pyg") dm_args = OmegaConf.to_container(config.datamodule.args, resolve=True) + dm_args["processed_graph_data_path"] = TEMP_CACHE_DATA_PATH dm = graphium.data.MultitaskFromSmilesDataModule(**dm_args) dm.prepare_data() @@ -232,6 +234,7 @@ def test_multitask_fromsmiles_from_config_parquet(self): config = graphium.load_config(name="fake_multilevel_multitask_pyg") dm_args = OmegaConf.to_container(config.datamodule.args, resolve=True) + dm_args["processed_graph_data_path"] = TEMP_CACHE_DATA_PATH dm = graphium.data.MultitaskFromSmilesDataModule(**dm_args) dm.prepare_data() @@ -260,6 +263,7 @@ def test_multitask_with_missing_fromsmiles_from_config_parquet(self): config = graphium.load_config(name="fake_and_missing_multilevel_multitask_pyg") dm_args = OmegaConf.to_container(config.datamodule.args, resolve=True) + dm_args["processed_graph_data_path"] = TEMP_CACHE_DATA_PATH dm = graphium.data.MultitaskFromSmilesDataModule(**dm_args) dm.prepare_data() @@ -288,23 +292,25 @@ def test_extract_graph_level_singletask(self): df = pd.read_parquet(f"tests/converted_fake_multilevel_data.parquet") num_graphs = len(df) label_cols = ["graph_label"] - output = graphium.data.datamodule.extract_labels(df, "graph", label_cols) + output, output_offsets = graphium.data.datamodule.extract_labels(df, "graph", label_cols) assert isinstance(output, np.ndarray) assert len(output.shape) == 2 assert output.shape[0] == num_graphs assert output.shape[1] == 1 + assert output_offsets is None def test_extract_graph_level_multitask(self): df = pd.read_parquet(f"tests/converted_fake_multilevel_data.parquet") num_graphs = len(df) label_cols = ["graph_label", "graph_label"] - output = graphium.data.datamodule.extract_labels(df, "graph", label_cols) + output, output_offsets = graphium.data.datamodule.extract_labels(df, "graph", label_cols) assert isinstance(output, np.ndarray) assert len(output.shape) == 2 assert output.shape[0] == num_graphs assert output.shape[1] == len(label_cols) + assert output_offsets is None def test_extract_graph_level_multitask_missing_cols(self): df = pd.read_parquet(f"tests/converted_fake_multilevel_data.parquet") @@ -316,7 +322,7 @@ def test_extract_graph_level_multitask_missing_cols(self): for missing_col in label_cols[:replace]: df[missing_col].iloc[drop_index] = None - output = graphium.data.datamodule.extract_labels(df, "graph", label_cols) + output, output_offsets = graphium.data.datamodule.extract_labels(df, "graph", label_cols) assert isinstance(output, np.ndarray) assert len(output.shape) == 2 @@ -325,17 +331,24 @@ def test_extract_graph_level_multitask_missing_cols(self): def test_non_graph_level_extract_labels(self): df = pd.read_parquet(f"tests/converted_fake_multilevel_data.parquet") + num_graphs = len(df) for level in ["node", "edge", "nodepair"]: label_cols = [f"{level}_label_{suffix}" for suffix in ["list", "np"]] - output = graphium.data.datamodule.extract_labels(df, level, label_cols) + output, output_offsets = graphium.data.datamodule.extract_labels(df, level, label_cols) - assert isinstance(output, list) - assert len(output[0].shape) == 2 - assert output[0].shape[1] == len(label_cols) + assert isinstance(output, np.ndarray) + assert len(output.shape) == 2 + assert output.shape[1] == len(label_cols) + assert output_offsets is not None + assert isinstance(output_offsets, np.ndarray) + assert len(output_offsets.shape) == 1 + assert output_offsets.shape[0] == (num_graphs + 1) + assert output.shape[0] == output_offsets[-1] def test_non_graph_level_extract_labels_missing_cols(self): df = pd.read_parquet(f"tests/converted_fake_multilevel_data.parquet") + num_graphs = len(df) for level in ["node", "edge", "nodepair"]: label_cols = [f"{level}_label_{suffix}" for suffix in ["list", "np"]] @@ -344,16 +357,28 @@ def test_non_graph_level_extract_labels_missing_cols(self): for missing_col in label_cols[:replace]: df.loc[drop_index, missing_col] = None - output = graphium.data.datamodule.extract_labels(df, level, label_cols) + output, output_offsets = graphium.data.datamodule.extract_labels(df, level, label_cols) + + assert isinstance(output, np.ndarray) + assert len(output.shape) == 2 + assert output.shape[1] == len(label_cols) + assert output_offsets is not None + assert isinstance(output_offsets, np.ndarray) + assert len(output_offsets.shape) == 1 + assert output_offsets.shape[0] == (num_graphs + 1) + assert output.shape[0] == output_offsets[-1] for idx in drop_index: - assert len(output[idx].shape) == 2 - assert output[idx].shape[1] == len(label_cols) + begin_idx = output_offsets[idx] + end_idx = output_offsets[idx + 1] + values = output[begin_idx:end_idx] + assert len(values.shape) == 2 + assert values.shape[1] == len(label_cols) - # Check that number of labels is adjusted correctly - if replace == 1: - non_missing_col = label_cols[1] - assert output[idx].shape[0] == len(df[non_missing_col][idx]) + # All removed entries must be nan + assert np.all(np.isnan(values[:, :replace])) + # All kept entries should be non-nan in this case + assert not np.any(np.isnan(values[:, replace:])) def test_tdc_admet_benchmark_data_module(self): """ @@ -369,7 +394,7 @@ def test_tdc_admet_benchmark_data_module(self): raise # Make sure we can initialize the module and run the main endpoints - data_module = graphium.data.ADMETBenchmarkDataModule() + data_module = graphium.data.TDCBenchmarkDataModule() data_module.prepare_data() data_module.setup() diff --git a/tests/test_mup.py b/tests/test_mup.py index b60e0ccf3..f5fd48987 100644 --- a/tests/test_mup.py +++ b/tests/test_mup.py @@ -151,7 +151,7 @@ def test_feedforwardgraph_mup(self): def test_fullgraphmultitasknetwork(self): # Load the configuration file for the model - CONFIG_FILE = "tests/config_test_ipu_dataloader.yaml" + CONFIG_FILE = "tests/config_test_dataloader.yaml" with open(CONFIG_FILE, "r") as f: cfg = yaml.safe_load(f) diff --git a/tests/test_node_label_order.py b/tests/test_node_label_order.py new file mode 100644 index 000000000..4ef099332 --- /dev/null +++ b/tests/test_node_label_order.py @@ -0,0 +1,305 @@ +""" +-------------------------------------------------------------------------------- +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. + +Use of this software is subject to the terms and conditions outlined in the LICENSE file. +Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without +warranties of any kind. + +Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Refer to the LICENSE file for the full terms and conditions. +-------------------------------------------------------------------------------- +""" + + +import unittest as ut + +from graphium.utils.fs import rm, exists +from graphium.data import MultitaskFromSmilesDataModule + +import torch +import pandas as pd +import numpy as np + +from torch_geometric.utils import unbatch + +TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000" + + +class Test_NodeLabelOrdering(ut.TestCase): + def test_node_label_ordering(self): + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ################################################################################################################### + ### Test I: Test if atom labels are ordered correctly for a single dataset that contains only a single molecule ### + ################################################################################################################### + + # Import node labels from parquet file + df = pd.DataFrame( + { + "ordered_smiles": ["[C:0][C:1][O:2]"], + "node_labels": [[0., 0., 2.]], + } + ) + + task_kwargs = {"df": df, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + atom_types = batch["labels"].node_task.squeeze() + atom_types_from_features = batch["features"].feat.argmax(1) + + np.testing.assert_array_equal(atom_types, atom_types_from_features) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ################################################################################### + ### Test II: Two ordered SMILES representing the same molecule in same dataset ### + ################################################################################### + + # Create input data + df = pd.DataFrame( + { + "ordered_smiles": ["[C:0][C:1][O:2]", "[O:0][C:1][C:2]"], + "node_labels": [[0., 0., 2.], [2., 0., 0.]], + } + ) + + task_kwargs = {"df": df, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + atom_types = batch["labels"].node_task.squeeze() + atom_types_from_features = batch["features"].feat.argmax(1) + + np.testing.assert_array_equal(atom_types_from_features, atom_types) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ############################################################################################# + ### Test III: Merging two node-level tasks each with different ordering of ordered SMILES ### + ############################################################################################# + + # Create input data + df1 = pd.DataFrame( + { + "ordered_smiles": ["[C:0][C:1][O:2]"], + "node_labels": [[0., 0., 2.]], + } + ) + + df2 = pd.DataFrame( + { + "ordered_smiles": ["[O:0][C:1][C:2]"], + "node_labels": [[2., 0., 0.]], + } + ) + + task1_kwargs = {"df": df1, "split_val": 0.0, "split_test": 0.0} + task2_kwargs = {"df": df2, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task1": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task1_kwargs}, + "task2": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task2_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + unbatched_node_labels1 = unbatch(batch["labels"].node_task1, batch["labels"].batch) + unbatched_node_labels2 = unbatch(batch["labels"].node_task2, batch["labels"].batch) + unbatched_node_features = unbatch(batch["features"].feat, batch["features"].batch) + + atom_types1 = unbatched_node_labels1[0].squeeze() + atom_types2 = unbatched_node_labels2[0].squeeze() + atom_types_from_features = unbatched_node_features[0].argmax(1) + + np.testing.assert_array_equal(atom_types_from_features, atom_types1) + np.testing.assert_array_equal(atom_types_from_features, atom_types2) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ############################################################################### + ### Test IV: Merging node-level task on graph-level task with no node order ### + ### NOTE: Works as rdkit does not merge ordered_smiles vs. unordered smiles ### + ############################################################################### + + # Create input data + df1 = pd.DataFrame( + { + "ordered_smiles": ["CCO"], + "graph_labels": [1.], + } + ) + + df2 = pd.DataFrame( + { + "ordered_smiles": ["[O:0][C:1][C:2]"], + "node_labels": [[2., 0., 0.]], + } + ) + + task1_kwargs = {"df": df1, "split_val": 0.0, "split_test": 0.0} + task2_kwargs = {"df": df2, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task1": {"task_level": "graph", "label_cols": ["graph_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task1_kwargs}, + "task2": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task2_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + atom_types = batch["labels"].node_task2.squeeze() + atom_types_from_features = batch["features"].feat.argmax(1) + + # Ignore NaNs + nan_indices = atom_types.isnan() + atom_types_from_features[nan_indices] = 333 + atom_types[nan_indices] = 333 + + np.testing.assert_array_equal(atom_types, atom_types_from_features) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ##################################################################################### + ### Test V: Merging node-level task on graph-level task with different node order ### + ##################################################################################### + + # Create input data + df1 = pd.DataFrame( + { + "ordered_smiles": ["[C:0][C:1][O:2]"], + "graph_labels": [1.], + } + ) + + df2 = pd.DataFrame( + { + "ordered_smiles": ["[O:0][C:1][C:2]"], + "node_labels": [[2., 0., 0.]], + } + ) + + task1_kwargs = {"df": df1, "split_val": 0.0, "split_test": 0.0} + task2_kwargs = {"df": df2, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task1": {"task_level": "graph", "label_cols": ["graph_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task1_kwargs}, + "task2": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task2_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + atom_types = batch["labels"].node_task2.squeeze() + atom_types_from_features = batch["features"].feat.argmax(1) + + np.testing.assert_array_equal(atom_types, atom_types_from_features) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ############################ + ### Test VI: ... ### + ### TODO: To be finished ### + ############################ + + # Create input data + df = pd.DataFrame( + { + "smiles": ["CCO", "OCC", "COC", "[C:0][C:1][O:2]", "[O:0][C:1][C:2]", "[C:0][O:1][C:2]"], + "graph_labels": [0., 0., 1., 0., 0., 1.], + } + ) + + task_kwargs = {"df": df, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task": {"task_level": "graph", "label_cols": ["graph_labels"], "smiles_col": "smiles", "seed": 42, **task_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + +if __name__ == "__main__": + ut.main() + + # Delete the cache + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) diff --git a/tests/test_packing.py b/tests/test_packing.py deleted file mode 100644 index 3b378214b..000000000 --- a/tests/test_packing.py +++ /dev/null @@ -1,234 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -# General imports -import unittest as ut -import numpy as np - -import torch -from torch_geometric.data import Data, Batch - -# Current library imports -from graphium.utils.packing import ( - smart_packing, - get_pack_sizes, - fast_packing, - hybrid_packing, - node_to_pack_indices_mask, -) - - -def random_packing(num_nodes, batch_size): - ipu_batch_size = int(len(num_nodes) / batch_size) - indices = np.arange(len(num_nodes)) - np.random.shuffle(indices) - indices = np.reshape(indices, (ipu_batch_size, batch_size)).tolist() - return indices - - -class test_Packing(ut.TestCase): - def test_smart_packing(self): - np.random.seed(42) - - batch_sizes = [2, 4, 8, 16, 32, 64] - ipu_batch_sizes = [2, 3, 4, 8, 16, 32, 64] - - for batch_size in batch_sizes: - for ipu_batch_size in ipu_batch_sizes: - err_msg = f"bz={batch_size}, ipu_bz={ipu_batch_size}" - - # Generate random batch size - global_batch = batch_size * ipu_batch_size - num_nodes = np.abs(np.random.gamma(2, 20, size=global_batch)).astype(int) - - # Use the smart packing - packed_indices = smart_packing(num_nodes=num_nodes, batch_size=batch_size) - pack_num_nodes = get_pack_sizes(packed_indices, num_nodes) - - # Use the random packing - rand_packed_indices = random_packing(num_nodes=num_nodes, batch_size=batch_size) - rand_pack_num_nodes = get_pack_sizes(rand_packed_indices, num_nodes) - - # Assert that the smart packing is better than the random packing - self.assertLessEqual(max(pack_num_nodes), max(rand_pack_num_nodes), msg=err_msg) - self.assertGreaterEqual(min(pack_num_nodes), min(rand_pack_num_nodes), msg=err_msg) - - # Assert that the total number of atoms is right - self.assertEqual(sum(pack_num_nodes), sum(num_nodes), msg=err_msg) - self.assertEqual(sum(rand_pack_num_nodes), sum(num_nodes), msg=err_msg) - - # Assert that all index are there - self.assertListEqual( - np.sort(np.asarray(packed_indices).flatten()).tolist(), np.arange(len(num_nodes)).tolist() - ) - self.assertListEqual( - np.sort(np.asarray(rand_packed_indices).flatten()).tolist(), - np.arange(len(num_nodes)).tolist(), - ) - - def test_fast_packing(self): - np.random.seed(42) - - # Start at 4 for fast_packing for better statistical significance - batch_sizes = [4, 8, 16, 32, 64] - ipu_batch_sizes = [4, 8, 16, 32, 64] - - for batch_size in batch_sizes: - for ipu_batch_size in ipu_batch_sizes: - err_msg = f"bz={batch_size}, ipu_bz={ipu_batch_size}" - - # Generate random batch size - global_batch = batch_size * ipu_batch_size - num_nodes = np.abs(np.random.gamma(2, 20, size=global_batch)).astype(int) - - # Use the smart packing - packed_indices = fast_packing(num_nodes=num_nodes, batch_size=batch_size) - pack_num_nodes = get_pack_sizes(packed_indices, num_nodes) - - # Use the random packing - rand_packed_indices = random_packing(num_nodes=num_nodes, batch_size=batch_size) - rand_pack_num_nodes = get_pack_sizes(rand_packed_indices, num_nodes) - - # Assert that the smart packing is better than the random packing - self.assertLessEqual(max(pack_num_nodes), max(rand_pack_num_nodes), msg=err_msg) - self.assertGreaterEqual(min(pack_num_nodes), min(rand_pack_num_nodes), msg=err_msg) - - # Assert that the total number of atoms is right - self.assertEqual(sum(pack_num_nodes), sum(num_nodes), msg=err_msg) - self.assertEqual(sum(rand_pack_num_nodes), sum(num_nodes), msg=err_msg) - - # Assert that all index are there - self.assertListEqual( - np.sort(np.asarray(packed_indices).flatten()).tolist(), np.arange(len(num_nodes)).tolist() - ) - self.assertListEqual( - np.sort(np.asarray(rand_packed_indices).flatten()).tolist(), - np.arange(len(num_nodes)).tolist(), - ) - - def test_hybrid_packing(self): - np.random.seed(42) - - batch_sizes = [2, 4, 8, 16, 32, 64] - ipu_batch_sizes = [2, 3, 4, 8, 16, 32, 64] - - for batch_size in batch_sizes: - for ipu_batch_size in ipu_batch_sizes: - err_msg = f"bz={batch_size}, ipu_bz={ipu_batch_size}" - - # Generate random batch size - global_batch = batch_size * ipu_batch_size - num_nodes = np.abs(np.random.gamma(2, 20, size=global_batch)).astype(int) - - # Use the smart packing - packed_indices = hybrid_packing(num_nodes=num_nodes, batch_size=batch_size) - pack_num_nodes = get_pack_sizes(packed_indices, num_nodes) - - # Use the random packing - rand_packed_indices = random_packing(num_nodes=num_nodes, batch_size=batch_size) - rand_pack_num_nodes = get_pack_sizes(rand_packed_indices, num_nodes) - - # Assert that the smart packing is better than the random packing - self.assertLessEqual(max(pack_num_nodes), max(rand_pack_num_nodes), msg=err_msg) - self.assertGreaterEqual(min(pack_num_nodes), min(rand_pack_num_nodes), msg=err_msg) - - # Assert that the total number of atoms is right - self.assertEqual(sum(pack_num_nodes), sum(num_nodes), msg=err_msg) - self.assertEqual(sum(rand_pack_num_nodes), sum(num_nodes), msg=err_msg) - - # Assert that all index are there - self.assertListEqual( - np.sort(np.asarray(packed_indices).flatten()).tolist(), np.arange(len(num_nodes)).tolist() - ) - self.assertListEqual( - np.sort(np.asarray(rand_packed_indices).flatten()).tolist(), - np.arange(len(num_nodes)).tolist(), - ) - - def test_node_to_pack_indices_mask(self): - # Create a dummy batch - in_dim = 7 - in_dim_edges = 11 - max_num_nodes_per_graph = 20 - batch_size_per_pack = 5 - - torch.manual_seed(42) - - # Create a dummy batch of graphs - batch, all_num_nodes = [], [] - for ii in range(100): - num_nodes = torch.randint(1, max_num_nodes_per_graph, (1,)).item() - all_num_nodes.append(num_nodes) - num_edges = abs(round(2.2 * num_nodes) + torch.randint(-2, 2, (1,)).item()) + 1 - x = torch.randn(num_nodes, in_dim, dtype=torch.float32) - edge_idx = torch.randint(0, num_nodes, (2, num_edges)) - e = torch.randn(edge_idx.shape[-1], in_dim_edges, dtype=torch.float32) - g = Data(h=x, edge_index=edge_idx, edge_attr=e) - batch.append(g) - batch = Batch.from_data_list(batch) - - # Get the packing - packed_graph_idx = fast_packing(all_num_nodes, batch_size_per_pack) - pack_sizes = get_pack_sizes(packed_graph_idx, all_num_nodes) - max_pack_size = max(pack_sizes) - num_packs = len(pack_sizes) - - # Get the node to pack indices and the mask - pack_from_node_idx, pack_attn_mask = node_to_pack_indices_mask(packed_graph_idx, all_num_nodes) - - # Assert that the nodes to pack indices are correct - h = torch.arange(batch.num_nodes, dtype=torch.float32) - packed_shape = [num_packs, max_pack_size] - h_packed = torch.zeros(packed_shape) - h_packed[pack_from_node_idx[:, 0], pack_from_node_idx[:, 1]] = h - h_packed_unique = torch.sort(torch.unique(h_packed))[0] - np.testing.assert_array_equal(h_packed_unique, torch.arange(batch.num_nodes)) - self.assertEqual(h_packed.sum(), h.sum()) - - # Test again with additional h dimension - h = batch.h - packed_shape = [num_packs, max_pack_size] + list(h.shape[1:]) - h_packed = torch.zeros(packed_shape) - h_packed[pack_from_node_idx[:, 0], pack_from_node_idx[:, 1]] = h - h_packed_unique = torch.sort(torch.unique(h_packed))[0] - h_packed_unique = h_packed_unique[h_packed_unique != 0] - np.testing.assert_array_almost_equal(h_packed_unique, torch.unique(h)) - self.assertAlmostEqual(h_packed.sum().item(), h.sum().item(), places=3) - - # Assert that the mask is correct by counting the number of False values (the sum of squared number of nodes per pack) - num_false = (~pack_attn_mask).sum([1, 2]) - num_expected = torch.as_tensor( - [sum([all_num_nodes[graph_idx] ** 2 for graph_idx in pack]) for pack in packed_graph_idx] - ) - np.testing.assert_array_equal(num_false, num_expected) - - # Assert that the mask is correct by counting the number of elements in each row and column - num_expected = [] - for pack in packed_graph_idx: - pack_num_expected = [] - for graph_idx in pack: - num_nodes = all_num_nodes[graph_idx] - for ii in range(num_nodes): - pack_num_expected.append(num_nodes) - pack_num_expected.extend([0] * (max_pack_size - len(pack_num_expected))) - num_expected.append(pack_num_expected) - num_expected = torch.as_tensor(num_expected) - num_false_row = (~pack_attn_mask).sum([2]) - num_false_col = (~pack_attn_mask).sum([1]) - np.testing.assert_array_equal(num_false_row, num_expected) - np.testing.assert_array_equal(num_false_col, num_expected) - - -if __name__ == "__main__": - ut.main() diff --git a/tests/test_pe_nodepair.py b/tests/test_pe_nodepair.py index f90ce728b..b849b28b3 100644 --- a/tests/test_pe_nodepair.py +++ b/tests/test_pe_nodepair.py @@ -1,88 +1,113 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ """ -Unit tests for the positional encodings in graphium/features/* +Unit tests for the positional encodings in graphium/graphium_cpp/*.cpp """ import numpy as np -import networkx as nx +import torch import unittest as ut -from graphium.features.electrostatic import compute_electrostatic_interactions -from graphium.features.commute import compute_commute_distances -from graphium.features.graphormer import compute_graphormer_distances +import graphium +import graphium_cpp class test_positional_encodings(ut.TestCase): # Test graphs - adj_dict = {} + smiles_dict = {} + shape_dict = {} max_dict = {} # 6-ring - adj = np.asarray( - [ - [0, 1, 0, 0, 0, 1], - [1, 0, 1, 0, 0, 0], - [0, 1, 0, 1, 0, 0], - [0, 0, 1, 0, 1, 0], - [0, 0, 0, 1, 0, 1], - [1, 0, 0, 0, 1, 0], - ] - ) - adj_dict["6-ring"] = adj + smiles = "C1CCCCC1" + smiles_dict["6-ring"] = smiles + shape_dict["6-ring"] = [6, 6] max_dict["6-ring"] = 3 # 5-path - G = nx.path_graph(5) - adj = nx.to_numpy_array(G) - adj_dict["5-path"] = adj + smiles = "CCCCC" + smiles_dict["5-path"] = smiles + shape_dict["5-path"] = [5, 5] max_dict["5-path"] = 4 # 4-clique - adj = 1 - np.eye(4) - adj_dict["4-clique"] = adj + smiles = "C12C3C1C23" + smiles_dict["4-clique"] = smiles + shape_dict["4-clique"] = [4, 4] max_dict["4-clique"] = 1 # 4-barbell - H = nx.barbell_graph(4, 0) - adj = nx.to_numpy_array(H) - adj_dict["4-barbell"] = adj + smiles = "C12C3C1C23C12C3C1C23" + smiles_dict["4-barbell"] = smiles + shape_dict["4-barbell"] = [8, 8] max_dict["4-barbell"] = 3 + features = { + "electrostatic": {"pos_level": "nodepair", "pos_type": "electrostatic", "normalization": "none"}, + "graphormer": {"pos_level": "nodepair", "pos_type": "graphormer", "normalization": "none"}, + "commute": {"pos_level": "nodepair", "pos_type": "commute", "normalization": "none"}, + } + (pos_encoding_names, pos_encoding_tensor) = graphium_cpp.positional_feature_options_to_tensor(features) + + def get_tensors(self, smiles): + tensors, _, _ = graphium_cpp.featurize_smiles( + smiles, + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_onehot + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_float + False, # has_conformer + torch.tensor(data=[], dtype=torch.int64), # edge_property_list + self.pos_encoding_tensor, + True, # duplicate_edges + False, # add_self_loop + False, # explicit_H=False + False, # use_bonds_weights + True, # offset_carbon + 7, # torch float64 + 0, # mask_nan_style_int + 0, # mask_nan_value + ) + return tensors + def test_dimensions(self): - for _, adj in self.adj_dict.items(): - pe, _, _ = compute_electrostatic_interactions(adj, cache={}) - self.assertEqual(pe.shape, adj.shape) + for key, smiles in self.smiles_dict.items(): + tensors = self.get_tensors(smiles) + + pe = tensors[4] # electrostatic + self.assertEqual(list(pe.shape), self.shape_dict[key]) - pe, _, _ = compute_graphormer_distances(adj, adj.shape[0], cache={}) - self.assertEqual(pe.shape, adj.shape) + pe = tensors[5] # graphormer + self.assertEqual(list(pe.shape), self.shape_dict[key]) - pe, _, _ = compute_commute_distances(adj, adj.shape[0], cache={}) - self.assertEqual(pe.shape, adj.shape) + pe = tensors[6] # commute + self.assertEqual(list(pe.shape), self.shape_dict[key]) def test_symmetry(self): - for _, adj in self.adj_dict.items(): - pe, _, _ = compute_graphormer_distances(adj, adj.shape[0], cache={}) + for _, smiles in self.smiles_dict.items(): + tensors = self.get_tensors(smiles) + + pe = tensors[5] # graphormer np.testing.assert_array_almost_equal(pe, pe.T) - pe, _, _ = compute_commute_distances(adj, adj.shape[0], cache={}) + pe = tensors[6] # commute np.testing.assert_array_almost_equal(pe, pe.T) def test_max_dist(self): - for key, adj in self.adj_dict.items(): - pe, _, _ = compute_graphormer_distances(adj, adj.shape[0], cache={}) + for key, smiles in self.smiles_dict.items(): + tensors = self.get_tensors(smiles) + + pe = tensors[5] # graphormer np.testing.assert_array_almost_equal(pe.max(), self.max_dict[key]) diff --git a/tests/test_pe_rw.py b/tests/test_pe_rw.py index 938df28da..aebd6a577 100644 --- a/tests/test_pe_rw.py +++ b/tests/test_pe_rw.py @@ -1,53 +1,86 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ """ -Unit tests for the positional encodings in graphium/features/* +Unit tests for the positional encodings in graphium/features/random_walk.cpp """ import numpy as np -import networkx as nx +import torch import unittest as ut -from graphium.features.rw import compute_rwse +import graphium +import graphium_cpp class test_pe_spectral(ut.TestCase): - def test_caching_and_outputs(self): + def test_outputs(self): # 4-barbell - G = nx.barbell_graph(4, 0) - adj = nx.to_numpy_array(G) - num_nodes = adj.shape[0] - cache = {} + smiles = "C12C3C1C23C12C3C1C23" + num_nodes = 8 ksteps1 = [4, 6] ksteps2 = [2] ksteps3 = [6, 7] - pe1, _, cache = compute_rwse( - adj.astype(np.float32), ksteps1, num_nodes, cache, pos_type="rw_transition_probs" + # The feature names only depend on pos_type and pos_level, so the two + # rw_return_probs features can't have the same pos_level. + features = { + "rw_transition_probs": { + "pos_level": "nodepair", + "pos_type": "rw_transition_probs", + "normalization": "none", + "ksteps": ksteps1, + }, + "rw_return_probs_0": { + "pos_level": "node", + "pos_type": "rw_return_probs", + "normalization": "none", + "ksteps": ksteps2, + }, + "rw_return_probs_1": { + "pos_level": "nodepair", + "pos_type": "rw_return_probs", + "normalization": "none", + "ksteps": ksteps3, + }, + } + (pos_encoding_names, pos_encoding_tensor) = graphium_cpp.positional_feature_options_to_tensor( + features ) - pe2, _, cache = compute_rwse( - adj.astype(np.float32), ksteps2, num_nodes, cache, pos_type="rw_return_probs" + tensors, _, _ = graphium_cpp.featurize_smiles( + smiles, + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_onehot + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_float + False, # has_conformer + torch.tensor(data=[], dtype=torch.int64), # edge_property_list + pos_encoding_tensor, + True, # duplicate_edges + False, # add_self_loop + False, # explicit_H=False + False, # use_bonds_weights + True, # offset_carbon + 7, # torch float64 + 0, # mask_nan_style_int + 0, # mask_nan_value ) - pe3, _, cache = compute_rwse( - adj.astype(np.float32), ksteps3, num_nodes, cache, pos_type="rw_return_probs" - ) + pe1 = tensors[4] + pe2 = tensors[5] + pe3 = tensors[6] - self.assertTrue(all([k in cache["ksteps"] for k in ksteps1 + ksteps2 + ksteps3])) self.assertTrue(pe1.shape, np.zeros((num_nodes, num_nodes, len(ksteps1)))) self.assertTrue(pe2.shape, np.zeros((num_nodes, len(ksteps2)))) self.assertTrue(pe3.shape, np.zeros((num_nodes, len(ksteps3)))) diff --git a/tests/test_pe_spectral.py b/tests/test_pe_spectral.py index 400eb9630..5c66e6f8b 100644 --- a/tests/test_pe_spectral.py +++ b/tests/test_pe_spectral.py @@ -1,12 +1,12 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ @@ -17,40 +17,75 @@ """ import numpy as np -import networkx as nx +import torch import unittest as ut -from graphium.features.spectral import compute_laplacian_pe +import graphium +import graphium_cpp + + +def get_pe_tensors(smiles, pos_encoding_tensor): + tensors, _, _ = graphium_cpp.featurize_smiles( + smiles, + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_onehot + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_float + False, # has_conformer + torch.tensor(data=[], dtype=torch.int64), # edge_property_list + pos_encoding_tensor, + True, # duplicate_edges + False, # add_self_loop + False, # explicit_H=False + False, # use_bonds_weights + True, # offset_carbon + 7, # torch float64 + 0, # mask_nan_style_int + 0, # mask_nan_value + ) + return tensors class test_pe_spectral(ut.TestCase): - # 2 disconnected 3 cliques - adj1 = np.zeros((6, 6)) - adj_3clq = 1 - np.eye(3) - adj1[:3, :3] = adj_3clq - adj1[3:, 3:] = adj_3clq + def test_for_connected_vs_disconnected_graph(self): + # 2 disconnected 3 cliques + smiles1 = "C1CC1.C1CC1" - # 3-clique - adj2 = 1 - np.eye(6) + # 6-clique (have to use S instead of C, because RDKit doesn't accept a carbon having 6 explicit bonds) + smiles2 = "S1234S567S189S251S368S4791" - def test_for_connected_vs_disconnected_graph(self): + num_atoms = 6 num_pos = 3 - # test if pe works identically on connected vs disconnected graphs - eigvals_pe1, _, _, cache = compute_laplacian_pe(self.adj1, num_pos, cache={}) - eigvals_pe1 = np.real(eigvals_pe1).astype(np.float32) - _, eigvecs_pe1, _, _ = compute_laplacian_pe(self.adj1, num_pos, cache=cache) + features = { + "laplacian_eigval": { + "pos_level": "node", + "pos_type": "laplacian_eigval", + "normalization": "none", + "num_pos": num_pos, + "disconnected_comp": True, + }, + "laplacian_eigvec": { + "pos_level": "node", + "pos_type": "laplacian_eigvec", + "normalization": "none", + "num_pos": num_pos, + "disconnected_comp": True, + }, + } + (pos_encoding_names, pos_encoding_tensor) = graphium_cpp.positional_feature_options_to_tensor( + features + ) - # We expect to cache 4 objects in when running the functon for the first time - self.assertEqual(len(cache.keys()), 4) - - eigvals_pe2, _, _, _ = compute_laplacian_pe(self.adj2, num_pos, cache={}) - eigvals_pe2 = np.real(eigvals_pe2).astype(np.float32) - _, eigvecs_pe2, _, _ = compute_laplacian_pe(self.adj2, num_pos, cache={}) + # test if pe works identically on connected vs disconnected graphs + tensors1 = get_pe_tensors(smiles1, pos_encoding_tensor) + eigvals_pe1 = tensors1[4] + eigvecs_pe1 = tensors1[5] + tensors2 = get_pe_tensors(smiles2, pos_encoding_tensor) + eigvals_pe2 = tensors2[4] + eigvecs_pe2 = tensors2[5] np.testing.assert_array_almost_equal(2 * eigvals_pe1, eigvals_pe2) - self.assertListEqual(list(eigvals_pe2.shape), [self.adj2.shape[0], num_pos]) - self.assertListEqual(list(eigvecs_pe2.shape), [self.adj2.shape[0], num_pos]) + self.assertListEqual(list(eigvals_pe2.shape), [num_atoms, num_pos]) + self.assertListEqual(list(eigvecs_pe2.shape), [num_atoms, num_pos]) if __name__ == "__main__": diff --git a/tests/test_pos_transfer_funcs.py b/tests/test_pos_transfer_funcs.py index 5062cbe46..188c6b0e3 100644 --- a/tests/test_pos_transfer_funcs.py +++ b/tests/test_pos_transfer_funcs.py @@ -1,51 +1,166 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ - """ Unit tests for the positional encodings in graphium/features/* """ import numpy as np -import networkx as nx +import torch import unittest as ut +import math + +import graphium +import graphium_cpp + -from graphium.features.spectral import compute_laplacian_pe -from graphium.features.transfer_pos_level import ( - node_to_edge, - node_to_nodepair, - edge_to_nodepair, - nodepair_to_node, - nodepair_to_edge, - graph_to_node, -) +def get_tensors(smiles, pos_encoding_tensor): + tensors, _, _ = graphium_cpp.featurize_smiles( + smiles, + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_onehot + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_float + False, # has_conformer + torch.tensor(data=[], dtype=torch.int64), # edge_property_list + pos_encoding_tensor, + True, # duplicate_edges + False, # add_self_loop + False, # explicit_H=False + False, # use_bonds_weights + True, # offset_carbon + 7, # torch float64 + 0, # mask_nan_style_int + 0, # mask_nan_value + ) + return tensors class test_pos_transfer_funcs(ut.TestCase): - # 4-barbell - G = nx.barbell_graph(4, 0) - adj = nx.to_numpy_array(G) - num_nodes, num_feat = 8, 5 - node_pe = np.random.rand(num_nodes, num_feat) - - def test_different_pathways_from_node_to_edge(self): - edge_pe1, _ = node_to_edge(self.node_pe, self.adj, {}) - nodepair_pe1 = node_to_nodepair(self.node_pe, self.num_nodes) - edge_pe2, _ = nodepair_to_edge(nodepair_pe1, self.adj, {}) - nodepair_pe2, _ = edge_to_nodepair(edge_pe1, self.adj, self.num_nodes, {}) - edge_pe3, _ = nodepair_to_edge(nodepair_pe2, self.adj, {}) - np.testing.assert_array_almost_equal(edge_pe1, edge_pe2) - np.testing.assert_array_almost_equal(edge_pe1, edge_pe3) + + def test_different_transfers(self): + smiles = "CCCC" + + ksteps = [2, 4] + features = { + "a": { + "pos_level": "node", + "pos_type": "rw_return_probs", + "normalization": "none", + "ksteps": ksteps, + }, + "b": { + "pos_level": "edge", + "pos_type": "rw_return_probs", + "normalization": "none", + "ksteps": ksteps, + }, + "c": { + "pos_level": "nodepair", + "pos_type": "rw_return_probs", + "normalization": "none", + "ksteps": ksteps, + }, + "e": {"pos_level": "node", "pos_type": "graphormer", "normalization": "none"}, + "f": {"pos_level": "edge", "pos_type": "graphormer", "normalization": "none"}, + "d": {"pos_level": "nodepair", "pos_type": "graphormer", "normalization": "none"}, + } + + (pos_encoding_names, pos_encoding_tensor) = graphium_cpp.positional_feature_options_to_tensor( + features + ) + + tensors = get_tensors(smiles, pos_encoding_tensor) + node_probs = tensors[4] + edge_probs = tensors[5] + nodepair_probs = tensors[6] + node_dists = tensors[7] + edge_dists = tensors[8] + nodepair_dists = tensors[9] + + print(f"node_probs =\n{node_probs}\n") + print(f"edge_probs =\n{edge_probs}\n") + print(f"nodepair_probs =\n{nodepair_probs}\n") + print(f"node_dists =\n{node_dists}\n") + print(f"edge_dists =\n{edge_dists}\n") + print(f"nodepair_dists =\n{nodepair_dists}\n") + + expected_node_probs = [ + [0.5, 0.375], + [0.75, 0.6875], + [0.75, 0.6875], + [0.5, 0.375], + ] + # sum for each node value and absolute difference for each node value, for each half-edge + expected_edge_probs = [ + [1.25, 1.0625, 0.25, 0.3125], + [1.25, 1.0625, 0.25, 0.3125], + [1.5, 1.375, 0.0, 0.0], + [1.5, 1.375, 0.0, 0.0], + [1.25, 1.0625, 0.25, 0.3125], + [1.25, 1.0625, 0.25, 0.3125], + ] + # sum for each node value and absolute difference for each node value, for each node pair + expected_nodepair_probs = [ + [ + [1.0000, 0.7500, 0.0000, 0.0000], + [1.2500, 1.0625, 0.2500, 0.3125], + [1.2500, 1.0625, 0.2500, 0.3125], + [1.0000, 0.7500, 0.0000, 0.0000], + ], + [ + [1.2500, 1.0625, 0.2500, 0.3125], + [1.5000, 1.3750, 0.0000, 0.0000], + [1.5000, 1.3750, 0.0000, 0.0000], + [1.2500, 1.0625, 0.2500, 0.3125], + ], + [ + [1.2500, 1.0625, 0.2500, 0.3125], + [1.5000, 1.3750, 0.0000, 0.0000], + [1.5000, 1.3750, 0.0000, 0.0000], + [1.2500, 1.0625, 0.2500, 0.3125], + ], + [ + [1.0000, 0.7500, 0.0000, 0.0000], + [1.2500, 1.0625, 0.2500, 0.3125], + [1.2500, 1.0625, 0.2500, 0.3125], + [1.0000, 0.7500, 0.0000, 0.0000], + ], + ] + self.assertEqual(node_probs.tolist(), expected_node_probs) + self.assertEqual(edge_probs.tolist(), expected_edge_probs) + self.assertEqual(nodepair_probs.tolist(), expected_nodepair_probs) + + expected_nodepair_dists = [ + [0.0, 1.0, 2.0, 3.0], + [1.0, 0.0, 1.0, 2.0], + [2.0, 1.0, 0.0, 1.0], + [3.0, 2.0, 1.0, 0.0], + ] + # Select half-edge node pairs + expected_edge_dists = [[1.0], [1.0], [1.0], [1.0], [1.0], [1.0]] + # Minimum of column, minimum of row, mean of column, mean of row, + # stdev of column, stdev of row, for each node + # stdev here uses n for normalization instead of n-1 + stdev_a = math.sqrt((1.5 * 1.5 + 0.5 * 0.5 + 0.5 * 0.5 + 1.5 * 1.5) / 4) + stdev_b = math.sqrt((1.0 * 1.0 + 1.0 * 1.0) / 4) + expected_node_dists = [ + [0.0, 0.0, 1.5, 1.5, stdev_a, stdev_a], + [0.0, 0.0, 1.0, 1.0, stdev_b, stdev_b], + [0.0, 0.0, 1.0, 1.0, stdev_b, stdev_b], + [0.0, 0.0, 1.5, 1.5, stdev_a, stdev_a], + ] + np.testing.assert_array_almost_equal(node_dists.tolist(), expected_node_dists) + self.assertEqual(edge_dists.tolist(), expected_edge_dists) + self.assertEqual(nodepair_dists.tolist(), expected_nodepair_dists) if __name__ == "__main__": diff --git a/tests/test_positional_encoders.py b/tests/test_positional_encoders.py index 166929ba2..66148487f 100644 --- a/tests/test_positional_encoders.py +++ b/tests/test_positional_encoders.py @@ -1,12 +1,12 @@ """ -------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind. -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Valence Labs, Recursion Pharmaceuticals, Graphcore Limited, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions. -------------------------------------------------------------------------------- """ @@ -18,19 +18,40 @@ import numpy as np import unittest as ut -from copy import deepcopy from rdkit import Chem import datamol as dm import torch -from scipy.sparse import coo_matrix +from torch_geometric.data import Data + +import graphium +import graphium_cpp -from graphium.features.featurizer import GraphDict -from graphium.features.positional_encoding import graph_positional_encoder from graphium.nn.encoders import laplace_pos_encoder, mlp_encoder, signnet_pos_encoder + # TODO: Test the MLP_encoder and signnet_pos_encoder +def get_pe_tensors(smiles, pos_encoding_tensor): + tensors, _, _ = graphium_cpp.featurize_smiles( + smiles, + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_onehot + torch.tensor(data=[], dtype=torch.int64), # atom_property_list_float + False, # has_conformer + torch.tensor(data=[], dtype=torch.int64), # edge_property_list + pos_encoding_tensor, + True, # duplicate_edges + False, # add_self_loop + False, # explicit_H=False + False, # use_bonds_weights + True, # offset_carbon + 7, # torch float64 + 0, # mask_nan_style_int + 0, # mask_nan_value + ) + return tensors + + class test_positional_encoder(ut.TestCase): smiles = [ "C", @@ -44,22 +65,34 @@ class test_positional_encoder(ut.TestCase): adjs = [Chem.rdmolops.GetAdjacencyMatrix(mol) for mol in mols] def test_laplacian_eigvec_eigval(self): - for ii, adj in enumerate(deepcopy(self.adjs)): + for ii, mol in enumerate(self.smiles): + adj = self.adjs[ii] for num_pos in [1, 2, 4]: # Can't test too much eigs because of multiplicities for disconnected_comp in [True, False]: err_msg = f"adj_id={ii}, num_pos={num_pos}, disconnected_comp={disconnected_comp}" - # returns a dictionary of computed pe - pos_kwargs = { - "pos_type": "laplacian_eigvec", - "num_pos": num_pos, - "disconnected_comp": disconnected_comp, - "pos_level": "node", + features = { + "laplacian_eigval": { + "pos_type": "laplacian_eigval", + "num_pos": num_pos, + "disconnected_comp": disconnected_comp, + "pos_level": "node", + }, + "laplacian_eigvec": { + "pos_type": "laplacian_eigvec", + "num_pos": num_pos, + "disconnected_comp": disconnected_comp, + "pos_level": "node", + }, } - num_nodes = adj.shape[0] - eigvecs, cache = graph_positional_encoder(adj, num_nodes, pos_kwargs=pos_kwargs) - pos_kwargs["pos_type"] = "laplacian_eigval" - eigvals, cache = graph_positional_encoder(adj, num_nodes, pos_kwargs=pos_kwargs) + ( + pos_encoding_names, + pos_encoding_tensor, + ) = graphium_cpp.positional_feature_options_to_tensor(features) + + tensors = get_pe_tensors(mol, pos_encoding_tensor) + eigvals = tensors[4] + eigvecs = tensors[5] self.assertEqual(list(eigvecs.shape), [adj.shape[0], num_pos], msg=err_msg) self.assertEqual(list(eigvals.shape), [adj.shape[0], num_pos], msg=err_msg) @@ -74,7 +107,10 @@ def test_laplacian_eigvec_eigval(self): true_num_pos = min(num_pos, len(true_eigvals)) true_eigvals, true_eigvecs = true_eigvals[:true_num_pos], true_eigvecs[:, :true_num_pos] - if not ("." in self.smiles[ii]): + if not ("." in mol): + print( + f"About to test eigvecs for smiles {mol}, num_pos {num_pos}, disconnected_comp {disconnected_comp}" + ) np.testing.assert_array_almost_equal( np.abs(true_eigvecs), np.abs(eigvecs[:, :true_num_pos]), @@ -88,13 +124,22 @@ def test_laplacian_eigvec_eigval(self): # didn't actually check the exact computation result because the code was adapted def test_rwse(self): - for ii, adj in enumerate(deepcopy(self.adjs)): + for ii, mol in enumerate(self.smiles): + adj = self.adjs[ii] for ksteps in [1, 2, 4]: err_msg = f"adj_id={ii}, ksteps={ksteps}" num_nodes = adj.shape[0] pos_kwargs = {"pos_type": "rw_return_probs", "ksteps": ksteps, "pos_level": "node"} - rwse_embed, cache = graph_positional_encoder(adj, num_nodes, pos_kwargs=pos_kwargs) + features = { + "rw_return_probs": pos_kwargs, + } + (pos_encoding_names, pos_encoding_tensor) = graphium_cpp.positional_feature_options_to_tensor( + features + ) + tensors = get_pe_tensors(mol, pos_encoding_tensor) + rwse_embed = tensors[4] + self.assertEqual(list(rwse_embed.shape), [num_nodes, ksteps], msg=err_msg) # TODO: work in progress @@ -105,23 +150,32 @@ def test_rwse(self): """ def test_laplacian_eigvec_with_encoder(self): - for ii, adj in enumerate(deepcopy(self.adjs)): + for ii, mol in enumerate(self.smiles): for num_pos in [2, 4, 8]: # Can't test too much eigs because of multiplicities for disconnected_comp in [True, False]: for model_type in ["Transformer", "DeepSet", "MLP"]: err_msg = f"adj_id={ii}, num_pos={num_pos}, disconnected_comp={disconnected_comp}" - # returns a dictionary of computed pe - pos_kwargs = { - "pos_type": "laplacian_eigvec", - "num_pos": num_pos, - "disconnected_comp": disconnected_comp, - "pos_level": "node", + features = { + "laplacian_eigval": { + "pos_type": "laplacian_eigval", + "num_pos": num_pos, + "disconnected_comp": disconnected_comp, + "pos_level": "node", + }, + "laplacian_eigvec": { + "pos_type": "laplacian_eigvec", + "num_pos": num_pos, + "disconnected_comp": disconnected_comp, + "pos_level": "node", + }, } - num_nodes = adj.shape[0] - eigvecs, cache = graph_positional_encoder(adj, num_nodes, pos_kwargs=pos_kwargs) - pos_kwargs["pos_type"] = "laplacian_eigval" - eigvals, cache = graph_positional_encoder(adj, num_nodes, pos_kwargs=pos_kwargs) + ( + pos_encoding_names, + pos_encoding_tensor, + ) = graphium_cpp.positional_feature_options_to_tensor(features) + + tensors = get_pe_tensors(mol, pos_encoding_tensor) input_keys = ["laplacian_eigvec", "laplacian_eigval"] in_dim = num_pos @@ -129,16 +183,17 @@ def test_laplacian_eigvec_with_encoder(self): out_dim = 64 num_layers = 1 - eigvecs = torch.from_numpy(eigvecs) - eigvals = torch.from_numpy(eigvals) - - g = GraphDict( - { - "adj": coo_matrix(adj), - "data": {"laplacian_eigval": eigvals, "laplacian_eigvec": eigvecs}, - } + num_nodes = tensors[2].size(0) + data_dict = { + # "feat": tensors[2], + # "edge_feat": tensors[3], + "laplacian_eigval": tensors[4].float(), + "laplacian_eigvec": tensors[5].float(), + } + # Create the PyG graph object `Data` + data = Data( + edge_index=tensors[0], edge_weight=tensors[1], num_nodes=num_nodes, **data_dict ) - batch = g.make_pyg_graph() encoder = laplace_pos_encoder.LapPENodeEncoder( input_keys=input_keys, @@ -153,7 +208,7 @@ def test_laplacian_eigvec_with_encoder(self): first_normalization=None, ) - hidden_embed = encoder(batch, key_prefix=None) + hidden_embed = encoder(data, key_prefix=None) assert "node" in hidden_embed.keys() self.assertEqual(list(hidden_embed["node"].shape), [num_nodes, out_dim], msg=err_msg) diff --git a/tests/test_positional_encodings.py b/tests/test_positional_encodings.py deleted file mode 100644 index 89bf355a4..000000000 --- a/tests/test_positional_encodings.py +++ /dev/null @@ -1,92 +0,0 @@ -""" --------------------------------------------------------------------------------- -Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. - -Use of this software is subject to the terms and conditions outlined in the LICENSE file. -Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without -warranties of any kind. - -Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. -Refer to the LICENSE file for the full terms and conditions. --------------------------------------------------------------------------------- -""" - - -""" -Unit tests for the positional encodings in graphium/features/* -""" - -import numpy as np -import networkx as nx -import unittest as ut - -# from graphium.features.spectral import compute_laplacian_positional_eigvecs # TODO: add tests -# from graphium.features.rw import compute_rwse # TODO: add tests -from graphium.features.electrostatic import compute_electrostatic_interactions -from graphium.features.commute import compute_commute_distances -from graphium.features.graphormer import compute_graphormer_distances - - -class test_positional_encodings(ut.TestCase): - # Test graphs - adj_dict = {} - max_dict = {} - - # 6-ring - adj = np.asarray( - [ - [0, 1, 0, 0, 0, 1], - [1, 0, 1, 0, 0, 0], - [0, 1, 0, 1, 0, 0], - [0, 0, 1, 0, 1, 0], - [0, 0, 0, 1, 0, 1], - [1, 0, 0, 0, 1, 0], - ] - ) - adj_dict["6-ring"] = adj - max_dict["6-ring"] = 3 - - # 5-path - G = nx.path_graph(5) - adj = nx.to_numpy_array(G) - adj_dict["5-path"] = adj - max_dict["5-path"] = 4 - - # 4-clique - adj = 1 - np.eye(4) - adj_dict["4-clique"] = adj - max_dict["4-clique"] = 1 - - # 4-barbell - H = nx.barbell_graph(4, 0) - adj = nx.to_numpy_array(H) - adj_dict["4-barbell"] = adj - max_dict["4-barbell"] = 3 - - def test_dimensions(self): - for _, adj in self.adj_dict.items(): - pe, _, _ = compute_electrostatic_interactions(adj, cache={}) - self.assertEqual(pe.shape, adj.shape) - - pe, _, _ = compute_graphormer_distances(adj, adj.shape[0], cache={}) - self.assertEqual(pe.shape, adj.shape) - - pe, _, _ = compute_commute_distances(adj, adj.shape[0], cache={}) - self.assertEqual(pe.shape, adj.shape) - - def test_symmetry(self): - for _, adj in self.adj_dict.items(): - pe, _, _ = compute_graphormer_distances(adj, adj.shape[0], cache={}) - np.testing.assert_array_almost_equal(pe, pe.T) - - pe, _, _ = compute_commute_distances(adj, adj.shape[0], cache={}) - np.testing.assert_array_almost_equal(pe, pe.T) - - def test_max_dist(self): - for key, adj in self.adj_dict.items(): - pe, _, _ = compute_graphormer_distances(adj, adj.shape[0], cache={}) - np.testing.assert_array_almost_equal(pe.max(), self.max_dict[key]) - - -if __name__ == "__main__": - ut.main() diff --git a/tests/test_predictor.py b/tests/test_predictor.py index 1ef69775f..3d3b8648c 100644 --- a/tests/test_predictor.py +++ b/tests/test_predictor.py @@ -29,7 +29,7 @@ def test_parse_loss_fun(self): preds = torch.rand(10, 5) target = (torch.rand(10, 5) > 0.5).to(preds.dtype) for this_loss in losses: - loss_fun = EvalOptions.parse_loss_fun(this_loss) + loss_name, loss_fun = EvalOptions.parse_loss_fun(this_loss) loss = loss_fun(preds, target) diff --git a/tests/test_predictor_summaries.py b/tests/test_predictor_summaries.py new file mode 100644 index 000000000..4e946d68b --- /dev/null +++ b/tests/test_predictor_summaries.py @@ -0,0 +1,276 @@ +""" +-------------------------------------------------------------------------------- +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. + +Use of this software is subject to the terms and conditions outlined in the LICENSE file. +Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without +warranties of any kind. + +Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Refer to the LICENSE file for the full terms and conditions. +-------------------------------------------------------------------------------- +""" + + +""" +Unit tests for the file graphium/trainer/predictor.py +""" + +import torch +from torch import nn +from torchmetrics import MeanAbsoluteError, PearsonCorrCoef +from copy import deepcopy +import unittest as ut + +from graphium.trainer.predictor_summaries import MultiTaskSummary, STDMetric, GradientNormMetric + +class SimpleNN(nn.Module): +# Define a simple neural network with 2 layers + def __init__(self, in_dim=10, out_dim=1): + super(SimpleNN, self).__init__() + torch.random.manual_seed(42) + # Define the first layer with 10 input features and 5 output features + self.layer1 = nn.Linear(in_dim, 5) + # Define the second layer with 5 input features and 1 output feature + self.layer2 = nn.Linear(5, out_dim) + + def forward(self, x): + # Pass the input through the first layer + if x.ndim == 1: + x = x.unsqueeze(-1) + x = torch.relu(self.layer1(x)) + # Pass the output of the first layer through the second layer + x = self.layer2(x) + return x + + +class SimpleDictNN(nn.Module): + def __init__(self, task_list, in_dim=10, out_dim=1): + super(SimpleDictNN, self).__init__() + torch.random.manual_seed(42) + self.dict_nn = nn.ModuleDict({task: SimpleNN(in_dim, out_dim) for task in task_list}) + + def forward(self, x): + return {task: self.dict_nn[task](x[task]) for task in self.dict_nn.keys()} + + +def simple_nn_grad_step(model, inputs, targets): + # Initialize the optimizer and loss function + loss_fn = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + # Perform a gradient step + optimizer.zero_grad() + outputs = model(inputs) + if isinstance(outputs, dict): + loss = sum([loss_fn(outputs[task], targets[task]) for task in outputs.keys()]) + else: + loss = loss_fn(outputs, targets) + loss.backward() + optimizer.step() + return model + +class test_TaskSummary(ut.TestCase): + + def test_std_metric(self): + + # Generate random data + torch.random.manual_seed(42) + rand = torch.rand(100, 1) + + # Compute expected values for STD + expected_std = torch.std(rand, correction=0) + + # Compute std metric + std_metric = STDMetric() + std_metric.update(rand) + std_metric_val = std_metric.compute() + std_metric.reset() + + self.assertAlmostEqual(std_metric_val.item(), expected_std.item(), places=5) + + # Check multiple updates + std_metric.update(rand[:10]) + std_metric.update(rand[10:25]) + std_metric.update(rand[25:]) + std_metric_val = std_metric.compute() + std_metric.reset() + + self.assertAlmostEqual(std_metric_val.item(), expected_std.item(), places=5) + + # Add some correction + expected_std = torch.std(rand, correction=1) + std_metric = STDMetric(correction=1) + std_metric.update(rand) + std_metric_val = std_metric.compute() + std_metric.reset() + + self.assertAlmostEqual(std_metric_val.item(), expected_std.item(), places=5) + + # Add some nans + rand[[3, 5, 11, 23, 42, 56, 78, 99]] = float('nan') + expected_std = torch.std(rand[~rand.isnan()], correction=0) + + std_metric = STDMetric(nan_strategy='ignore', correction=0) + std_metric.update(rand) + std_metric_val = std_metric.compute() + std_metric.reset() + + self.assertAlmostEqual(std_metric_val.item(), expected_std.item(), places=5) + + def test_gradient_norm_metric(self): + + # Generate random data + torch.random.manual_seed(42) + LEN = 10000 + inputs = torch.rand(LEN, 10) + targets = torch.rand(LEN, 1) + + # Compute expected values for gradient norm + model = SimpleNN() + model = simple_nn_grad_step(model, inputs, targets) + expected_grad_norm = torch.norm(torch.stack([torch.norm(param.grad) for param in model.parameters()])) + + # Compute gradient norm metric + model = SimpleNN() + model = simple_nn_grad_step(model, inputs, targets) + grad_norm_metric = GradientNormMetric() + grad_norm_metric.update(model) + grad_norm_metric_val = grad_norm_metric.compute() + grad_norm_metric.reset() + + self.assertAlmostEqual(grad_norm_metric_val.item(), expected_grad_norm.item(), places=5) + + # Compute gradient norm metric with many update steps + grad_norm_metric = GradientNormMetric() + model = SimpleNN() + model = simple_nn_grad_step(model, inputs[:50], targets[:50]) + grad_norm_metric.update(model) + model = SimpleNN() + model = simple_nn_grad_step(model, inputs[50:400], targets[50:400]) + grad_norm_metric.update(model) + model = SimpleNN() + model = simple_nn_grad_step(model, inputs[400:], targets[400:]) + grad_norm_metric.update(model) + + grad_norm_metric_val = grad_norm_metric.compute() + grad_norm_metric.reset() + + self.assertAlmostEqual(grad_norm_metric_val.item(), expected_grad_norm.item(), places=1) + + def assertDictTensorAlmostEqual(self, dict1, dict2, places=7): + dict1 = deepcopy(dict1) + dict1 = {key: dict1[key] for key in sorted(dict1.keys())} + dict2 = deepcopy(dict2) + dict2 = {key: dict2[key] for key in sorted(dict2.keys())} + for key in dict1.keys(): + dict1[key] = round(dict1[key].item(), places) + for key in dict2.keys(): + dict2[key] = round(dict2[key].item(), places) + self.assertDictEqual(dict1, dict2) + + + def test_multi_task_summary(self): + + # Generate random data + torch.random.manual_seed(42) + targets = torch.rand(100, 3) + preds = torch.rand(100, 3) + 0.4 * targets + targets = {f"task{i+1}": targets[:, i] for i in range(targets.shape[1])} + preds = {f"task{i+1}": preds[:, i] for i in range(preds.shape[1])} + + task_metrics = { + "task1": {'mae': MeanAbsoluteError(), 'pearson': PearsonCorrCoef()}, + "task2": {'pearson': PearsonCorrCoef()}, + "task3": {'mae': MeanAbsoluteError()} + } + + expected_dict = {} + for task, metrics in task_metrics.items(): + for metric_name, metric in metrics.items(): + metric.update(preds[task], targets[task]) + expected_val = metric.compute() + metric.reset() + expected_dict[f"{task}/{metric_name}/val"] = expected_val + + + # Test the metrics on validation step + summary_val = MultiTaskSummary(task_metrics, step_name="val", compute_mean=False, compute_std=False) + summary_val.update(preds, targets) + summary_dict = summary_val.compute() + self.assertDictTensorAlmostEqual(summary_dict, expected_dict, places=5) + + # Test the metric reset + summary_val.reset() + summary_val.update(preds, targets) + summary_dict = summary_val.compute() + self.assertDictTensorAlmostEqual(summary_dict, expected_dict, places=5) + + # Test multiple batches + summary_val.reset() + preds1 = {key: preds[key][:10] for key in preds.keys()} + targets1 = {key: targets[key][:10] for key in targets.keys()} + preds2 = {key: preds[key][10:25] for key in preds.keys()} + targets2 = {key: targets[key][10:25] for key in targets.keys()} + preds3 = {key: preds[key][25:] for key in preds.keys()} + targets3 = {key: targets[key][25:] for key in targets.keys()} + + summary_val.update(preds1, targets1) + summary_val.update(preds2, targets2) + summary_val.update(preds3, targets3) + summary_dict = summary_val.compute() + self.assertDictTensorAlmostEqual(summary_dict, expected_dict, places=5) + + # Test the mean and std computation + summary_val = MultiTaskSummary(task_metrics, step_name="val", compute_mean=True, compute_std=True) + summary_val.update(preds, targets) + summary_dict = summary_val.compute() + expected_dict_mean_std = {} + for task in task_metrics.keys(): + expected_dict_mean_std[f"{task}/mean_preds/val"] = preds[task].mean() + expected_dict_mean_std[f"{task}/std_preds/val"] = preds[task].std(correction=0) + expected_dict_mean_std[f"{task}/mean_target/val"] = targets[task].mean() + expected_dict_mean_std[f"{task}/std_target/val"] = targets[task].std(correction=0) + expected_dict_mean_std.update(expected_dict) + self.assertDictTensorAlmostEqual(summary_dict, expected_dict_mean_std, places=5) + + # Test the mean and std computation with multiple batches + summary_val.reset() + summary_val.update(preds1, targets1) + summary_val.update(preds2, targets2) + summary_val.update(preds3, targets3) + summary_dict = summary_val.compute() + self.assertDictTensorAlmostEqual(summary_dict, expected_dict_mean_std, places=5) + + # Test the training step doesn't return anything when no metrics on training set are selected + summary_train = MultiTaskSummary(task_metrics, step_name="train", task_metrics_on_training_set=None, compute_mean=False, compute_std=False) + summary_train.update(preds, targets) + summary_train = summary_train.compute() + self.assertDictEqual(summary_train, {}) + + # Test the training step returns only the mae + task_metrics_on_training_set = {"task1": ["mae"], "task2": None, "task3": "mae"} + summary_train = MultiTaskSummary(task_metrics, step_name="train", task_metrics_on_training_set=task_metrics_on_training_set, compute_mean=False, compute_std=False) + summary_train.update(preds, targets) + summary_dict = summary_train.compute() + expected_dict_mae = {key: value for key, value in expected_dict.items() if "mae" in key} + expected_dict_mae = {key.replace("/val", "/train"): value for key, value in expected_dict_mae.items()} + self.assertDictTensorAlmostEqual(summary_dict, expected_dict_mae, places=5) + + # Test the training step returns only the mae with multiple steps + summary_train = MultiTaskSummary(task_metrics, step_name="train", task_metrics_on_training_set=task_metrics_on_training_set, compute_mean=False, compute_std=False) + summary_train.update(preds1, targets1) + summary_train.update(preds2, targets2) + summary_train.update(preds3, targets3) + summary_dict = summary_train.compute() + self.assertDictTensorAlmostEqual(summary_dict, expected_dict_mae, places=5) + + # Test grad_norm not available in "val" step + summary_val = MultiTaskSummary(task_metrics, step_name="val", compute_mean=False, compute_std=False) + summary_val.update(preds, targets) + summary_dict = summary_val.compute() + self.assertNotIn("grad_norm", summary_dict.keys()) + + +if __name__ == "__main__": + ut.main() diff --git a/tests/test_pyg_layers.py b/tests/test_pyg_layers.py index 03498eb35..52caec668 100644 --- a/tests/test_pyg_layers.py +++ b/tests/test_pyg_layers.py @@ -231,7 +231,6 @@ def test_pnamessagepassinglayer(self): self.assertEqual(bg2.feat.shape[1], self.out_dim * layer.out_dim_factor) self.assertTrue((bg2.edge_feat == self.bg.edge_feat).all) - @pytest.mark.skip_ipu def test_dimenetlayer(self): from graphium.nn.encoders.bessel_pos_encoder import BesselSphericalPosEncoder @@ -311,7 +310,7 @@ def test_preprocess3Dfeaturelayer(self): # bias: [batch, num_heads, nodes, nodes] # node_feature: [total_nodes, embed_dim] bias, node_feature = layer.forward( - bg, max_num_nodes_per_graph=4, on_ipu=False, positions_3d_key="positions_3d" + bg, max_num_nodes_per_graph=4, positions_3d_key="positions_3d" ) self.assertEqual(bias.size(), torch.Size([2, num_heads, 4, 4])) self.assertFalse(np.isnan(bias.detach().numpy()).any()) diff --git a/tests/test_training.py b/tests/test_training.py index 3ac31fc35..d9696f9a1 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -17,10 +17,25 @@ import sys import subprocess import os -from unittest.mock import patch +import shutil +import unittest as ut -class TestCLITraining: +import hydra +from hydra.core.global_hydra import GlobalHydra +from omegaconf import OmegaConf +import os + +from graphium.config._loader import ( + load_accelerator, + load_architecture, + load_datamodule, + load_metrics, + load_predictor, + load_trainer, +) + +class test_CLITraining(): @classmethod def setup_class(cls): print("Setting up the test class...") @@ -49,7 +64,7 @@ def setup_class(cls): print("Data has been successfully downloaded.") - def call_cli_with_overrides(self, acc_type: str, acc_prec: str, load_type: str) -> None: + def call_cli_with_overrides(self, acc_type: str, acc_prec: str) -> None: overrides = [ f"accelerator={acc_type}", "tasks=toymix", @@ -76,11 +91,8 @@ def call_cli_with_overrides(self, acc_type: str, acc_prec: str, load_type: str) "+datamodule.args.task_specific_args.zinc.sample_size=1000", "trainer.trainer.check_val_every_n_epoch=1", f"trainer.trainer.precision={acc_prec}", - f"datamodule.args.dataloading_from={load_type}", ] - if acc_type == "ipu": - overrides.append("accelerator.ipu_config=['useIpuModel(True)']") - overrides.append("accelerator.ipu_inference_config=['useIpuModel(True)']") + # Backup the original sys.argv original_argv = sys.argv.copy() @@ -93,20 +105,93 @@ def call_cli_with_overrides(self, acc_type: str, acc_prec: str, load_type: str) # Restore the original sys.argv sys.argv = original_argv - @pytest.mark.parametrize("load_type", ["RAM", "disk"]) - def test_cpu_cli_training(self, load_type): - self.call_cli_with_overrides("cpu", "32", load_type) - - @pytest.mark.ipu - @pytest.mark.skip - @pytest.mark.parametrize("load_type", ["RAM", "disk"]) - def test_ipu_cli_training(self, load_type): - with patch("poptorch.ipuHardwareIsAvailable", return_value=True): - with patch("lightning_graphcore.accelerator._IPU_AVAILABLE", new=True): - import poptorch - - assert poptorch.ipuHardwareIsAvailable() - from lightning_graphcore.accelerator import _IPU_AVAILABLE - - assert _IPU_AVAILABLE is True - self.call_cli_with_overrides("ipu", "16-true", load_type) + def test_cpu_cli_training(self): + self.call_cli_with_overrides("cpu", "32") + + +def initialize_hydra(config_path, job_name="app"): + if GlobalHydra.instance().is_initialized(): + GlobalHydra.instance().clear() + hydra.initialize(config_path=config_path, job_name=job_name) + +def compose_main_config(config_dir): + initialize_hydra(config_dir) + # Compose the main configuration + main_config = hydra.compose(config_name="main") + return main_config + +def compose_task_config(config_dir, task_name): + task_config_dir = os.path.join(config_dir, "tasks") + initialize_hydra(task_config_dir, job_name="compose_task") + # Compose the specific task configuration + task_config = hydra.compose(config_name=task_name) + return task_config + +class test_TrainToymix(ut.TestCase): + def test_train_toymix(self): + pytest.skip("Skipping for now because of necessity of download") + + # Load the main configuration for toymix + CONFIG_DIR = "../expts/hydra-configs/" + cfg = compose_main_config(CONFIG_DIR) + cfg = OmegaConf.to_container(cfg, resolve=True) + cfg.pop("tasks") + + # Adapt the configuration to reduce the time it takes to run the test, less samples, less epochs + cfg["constants"]["max_epochs"] = 4 + cfg["trainer"]["trainer"]["check_val_every_n_epoch"] = 1 + cfg["trainer"]["trainer"]["max_epochs"] = 4 + + cfg["datamodule"]["args"]["batch_size_training"] = 20 + cfg["datamodule"]["args"]["batch_size_inference"] = 20 + cfg["datamodule"]["args"]["task_specific_args"]["zinc"]["sample_size"] = 300 + cfg["datamodule"]["args"]["task_specific_args"]["qm9"]["sample_size"] = 300 + cfg["datamodule"]["args"]["task_specific_args"]["tox21"]["sample_size"] = 300 + + + # Initialize the accelerator + cfg, accelerator_type = load_accelerator(cfg) + + # If the data_cache directory exists, delete it for the purpose of the test + data_cache = cfg["datamodule"]["args"]["processed_graph_data_path"] + if os.path.exists(data_cache): + shutil.rmtree(data_cache) + + # Load and initialize the dataset + datamodule = load_datamodule(cfg, accelerator_type) + + # Initialize the network + model_class, model_kwargs = load_architecture( + cfg, + in_dims=datamodule.in_dims, + ) + + datamodule.prepare_data() + + metrics = load_metrics(cfg) + + predictor = load_predictor( + cfg, + model_class, + model_kwargs, + metrics, + datamodule.get_task_levels(), + accelerator_type, + datamodule.featurization, + datamodule.task_norms, + ) + + metrics_on_progress_bar = predictor.get_metrics_on_progress_bar + trainer = load_trainer(cfg, accelerator_type, metrics_on_progress_bar=metrics_on_progress_bar) + + predictor.set_max_nodes_edges_per_graph(datamodule, stages=["train", "val"]) + + # Run the model training + trainer.fit(model=predictor, datamodule=datamodule) + trainer.test(model=predictor, datamodule=datamodule) + +if __name__ == "__main__": + config_dir = "../expts/hydra-configs/" # Path to your config directory + test_CLITraining.setup_class() + + ut.main() diff --git a/tests/test_utils.py b/tests/test_utils.py index b6a7b171c..537f35775 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -22,7 +22,6 @@ import unittest as ut import gzip -from graphium.utils.read_file import file_opener from graphium.utils.tensor import ( nan_mad, nan_mean, @@ -150,32 +149,6 @@ def test_nan_mad(self): np.testing.assert_almost_equal(torch_mad.numpy(), numpy_mad, decimal=4, err_msg=err_msg) -def test_file_opener(tmp_path): - # Create a temporary file - txt_file = tmp_path / "test.txt" - txt_file.write_text("Hello, World!") - - # Test opening file in read mode - with file_opener(txt_file, "r") as f: - assert f.read() == "Hello, World!" - - # Test opening file in write mode - with file_opener(txt_file, "w") as f: - f.write("New text") - - with file_opener(txt_file, "r") as f: - assert f.read() == "New text" - - # Create a temporary gzip file - gzip_file = tmp_path / "test.txt.gz" - with gzip.open(gzip_file, "wt") as f: - f.write("Hello, Gzip!") - - # Test opening gzip file in read mode - with file_opener(gzip_file, "r") as f: - assert f.read() == "Hello, Gzip!" - - class test_SafeRun(ut.TestCase): def test_safe_run(self): # Error is caught @@ -205,7 +178,7 @@ def test_safe_run(self): print("This is not an error") -class TestTensorFp16ToFp32(ut.TestCase): +class test_TensorFp16ToFp32(ut.TestCase): def test_tensor_fp16_to_fp32(self): # Create a tensor tensor = torch.randn(10, 10).half()