Skip to content

Add DAG-GFlowNet (Bayesian Structure learning, Deleu et al., 2022) #296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 48 commits into
base: master
Choose a base branch
from

Conversation

hyeok9855
Copy link
Collaborator

@hyeok9855 hyeok9855 commented Mar 26, 2025

PR Summary

This PR includes the BayesianStructure env to reproduce DAG-GFlowNet.

You can find a script for training with an MLP in tutorials/examples/train_bayesian_structure.py.

Issues

  • Training is very slow.

TODO in some following PRs

@hyeok9855 hyeok9855 marked this pull request as draft March 26, 2025 14:30
@hyeok9855
Copy link
Collaborator Author

hyeok9855 commented Mar 28, 2025

The below error is resolved in #299


@josephdviviano
An error occurs when using the replay buffer:

python tutorials/examples/train_bayesian_structure.py

  3%|███████▊                                                                                                                                                                                                                                                           | 30/1000 [05:01<2:42:32, 10.05s/it]
Traceback (most recent call last):
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/tutorials/examples/train_bayesian_structure.py", line 519, in <module>
    main(args)
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/tutorials/examples/train_bayesian_structure.py", line 397, in main
    loss = gflownet.loss(env, training_samples, recalculate_all_logprobs=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/src/gfn/gflownet/trajectory_balance.py", line 70, in loss
    _, _, scores = self.get_trajectories_scores(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/src/gfn/gflownet/base.py", line 204, in get_trajectories_scores
    log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs(
                                               ^^^^^^^^^^^^^^^^^^^^^
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/src/gfn/gflownet/base.py", line 185, in get_pfs_and_pbs
    return get_trajectory_pfs_and_pbs(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/src/gfn/utils/prob_calculations.py", line 45, in get_trajectory_pfs_and_pbs
    log_pf_trajectories = get_trajectory_pfs(
                          ^^^^^^^^^^^^^^^^^^^
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/src/gfn/utils/prob_calculations.py", line 72, in get_trajectory_pfs
    raise AssertionError("Something wrong happening with log_pf evaluations")

The assertion is in get_trajectory_pfs of src/gfn/utils/prob_calculations.py; here.

You can reproduce the error by running the following on this branch:

python tutorials/examples/train_bayesian_structure.py

Copy link
Collaborator

@josephdviviano josephdviviano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some intermediate comments.

return self._log_prior


class FairPrior(BasePrior):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we get docstrings explaining the math behind all priors please?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The priors are taken directly from the original repo without modifications. (And there's no docstring there, either.) My understanding of this task is just as bad as yours.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we should ask the author for assistance with this?

return order


def sample_from_linear_gaussian(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this used for? the return datastruct is strange and I would expect the topological sort to be potentially slow?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this used for?

This function generates the data from the true DAG, and the data is required for the reward function (scorer).

the return datastruct is strange

Could you elaborate more?

I would expect the topological sort to be potentially slow?

This code is almost identical to the one in the original repo. And this is called only once at the initialization of the scorer object, so I think it will be okay.

return ld


class BaseScore(ABC):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this base class necessary?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if we want to support other scores later (e.g., BDe score). Note that this structure (BaseScore - BGeScore) was also brought over from the original repo.

self._log_prior = all_parents * math.log(p) + (
self.num_variables - all_parents - 1
) * math.log1p(-p)
return self._log_prior
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I'm not sure either. This is taken from the original repo without modification.

@saleml saleml mentioned this pull request Apr 15, 2025
@hyeok9855 hyeok9855 requested a review from Copilot April 17, 2025 09:44
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds support for DAG-GFlowNet to reproduce Bayesian Structure learning as described in Deleu et al. (2022) by introducing a new BayesianStructure environment along with several helper modules for scoring, sampling, priors, and graph generation. Additional updates improve device propagation, platform‐specific multiprocessing handling, and overall consistency of the codebase.

  • Updated probability distribution computation in modules to include parameter assertions and refined exploration handling.
  • Added new helper modules under gfn/gym/helpers/bayesian_structure for scoring (BGeScore), sampling, priors, graph generation, and data factories.
  • Refactored various environment, actions, and container functions to consistently propagate the device and enhance multiprocessing compatibility.

Reviewed Changes

Copilot reviewed 34 out of 34 changed files in this pull request and generated no comments.

Show a summary per file
File Description
src/gfn/modules.py Refined the probability distribution function with parameter assertions and logic.
src/gfn/gym/hypergrid.py Updated the multiprocessing start method to check the platform and force device usage.
src/gfn/gym/helpers/bayesian_structure/*.py Introduced new scorer, sampling, priors, graph generation, evaluation, and factories.
src/gfn/gym/graph_building.py Improved node/edge creation with explicit device handling and refactored action logic.
src/gfn/(containers env
pyproject.toml Added pgmpy dependency with a version constraint.
Comments suppressed due to low confidence (2)

src/gfn/gym/helpers/bayesian_structure/scores.py:189

  • [nitpick] Consider renaming 'tmp_var' to a more descriptive name (e.g., 'adjusted_sample_size') to improve readability.
tmp_var = self.num_samples + self.alpha_w - self.num_nodes + num_parents

src/gfn/gym/graph_building.py:348

  • [nitpick] Although noted by the TODO comment, replacing the hard-coded upper limit '10' with a configurable parameter would enhance flexibility and clarity of the code.
n_nodes = np.random.randint(1, 10)  # TODO: make the max n_nodes a parameter

Copy link
Collaborator

@younik younik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't check the helpers yet

Copy link
Collaborator

@saleml saleml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very strong PR. Very well written, and follows the spirit of the library to perfection! I agree that we need to investigate why it is slow.
Before merging, could you please modify README.md?
In the section "Other environments available in the package include:", it would be nice to talk about this environment (and other graph environments while you're at it :)), and describe it as coming from the Deleu et al. paper. In fact this is our least toy example (so kudos for that Sanghyeok and Abhijith!), and I strongly believe it should be highlighted/advertised more. You should evne provide some example of commands (both in the top of the file and in the README) that could be run.

What do you think?

@@ -49,6 +49,7 @@ tox = { version = "*", optional = true }

# scripts dependencies.
matplotlib = { version = "*", optional = true }
pgmpy = { version = "<1.0.0", optional = true }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason we want this to be <1.0.0

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some differences between versions 0.x and 1.0.0:

  1. Difference in attribute names (see the error below)
  2. LinearGaussianCPD in version 0.x accepts "variance" as input, but in 1.0.0, it accepts standard deviation.

Since the original DAG-GFN repo uses version 0.x, I believe it's reasonable to follow their settings. Additionally, pgmpy was upgraded to 1.0.0 recently (Apr 1st), and I feel reluctant to use it.

Traceback (most recent call last):
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/tutorials/examples/train_bayesian_structure.py", line 391, in <module>
    main(args)
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/tutorials/examples/train_bayesian_structure.py", line 188, in main
    scorer, _, gt_graph = get_scorer(
                          ^^^^^^^^^^^
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/src/gfn/gym/helpers/bayesian_structure/factories.py", line 75, in get_scorer
    graph, data, score = get_data(
                         ^^^^^^^^^
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/src/gfn/gym/helpers/bayesian_structure/factories.py", line 48, in get_data
    data = sample_from_linear_gaussian(graph, num_samples=num_samples, rng=rng)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/src/gfn/gym/helpers/bayesian_structure/sampling.py", line 46, in sample_from_linear_gaussian
    cpd.mean[0], cpd.variance, size=(num_samples,)
    ^^^^^^^^
AttributeError: 'LinearGaussianCPD' object has no attribute 'mean'

@hyeok9855 hyeok9855 mentioned this pull request May 14, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants