Add DAG-GFlowNet (Bayesian Structure learning, Deleu et al., 2022) #296
base: master
Conversation
…to follow torchgfn conventions
The error below is resolved in #299 @josephdviviano

You can reproduce the error by running the following on this branch: `python tutorials/examples/train_bayesian_structure.py`
Some intermediate comments.
```python
        return self._log_prior


class FairPrior(BasePrior):
```
Can we get docstrings explaining the math behind all priors, please?
The priors are taken directly from the original repo without modifications. (And there's no docstring there, either.) My understanding of this task is just as bad as yours.
Perhaps we should ask the author for assistance with this?
```python
    return order


def sample_from_linear_gaussian(
```
What is this used for? The return data structure is strange, and I would expect the topological sort to be potentially slow.
> what is this used for?

This function generates the data from the true DAG, and the data is required for the reward function (scorer).

> the return datastruct is strange

Could you elaborate more?

> I would expect the topological sort to be potentially slow?

This code is almost identical to the one in the original repo. And this is called only once at the initialization of the scorer object, so I think it will be okay.
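For context, ancestral sampling from a linear Gaussian DAG only needs one topological sort up front: visit nodes in an order where parents come first, then draw each node as a linear function of its already-sampled parents plus Gaussian noise. Here is a minimal self-contained sketch of that idea (the function name and data layout are illustrative, not the PR's actual helper):

```python
import numpy as np


def sample_linear_gaussian(parents, weights, noise_var, num_samples, rng):
    """Ancestral sampling from a linear Gaussian DAG (illustrative sketch).

    parents:   dict node -> list of parent nodes (must form a DAG)
    weights:   dict (parent, node) -> edge coefficient
    noise_var: dict node -> Gaussian noise variance
    Returns a dict node -> array of shape (num_samples,).
    """
    # Kahn-style topological sort: repeatedly emit a node with no
    # unprocessed parents. This runs once, before any sampling.
    remaining = {v: set(ps) for v, ps in parents.items()}
    order, ready = [], [v for v, ps in remaining.items() if not ps]
    while ready:
        v = ready.pop()
        order.append(v)
        for w, ps in remaining.items():
            if v in ps:
                ps.remove(v)
                if not ps:
                    ready.append(w)

    # Sample each node given its (already sampled) parents.
    samples = {}
    for v in order:
        mean = sum(weights[(u, v)] * samples[u] for u in parents[v])
        samples[v] = mean + rng.normal(0.0, np.sqrt(noise_var[v]), size=num_samples)
    return samples
```

Since the sort is linear in the number of edges and runs only once at scorer initialization, it should indeed not be a bottleneck.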
```python
    return ld


class BaseScore(ABC):
```
Is this base class necessary?
Yes, if we want to support other scores later (e.g., BDe score). Note that this structure (BaseScore - BGeScore) was also brought over from the original repo.
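For what it's worth, the value of the abstract base class is a stable interface that a future score (such as BDe) can implement without touching callers. A toy sketch of the pattern (the method name and subclasses here are illustrative, not the actual `BaseScore`/`BGeScore` API):

```python
from abc import ABC, abstractmethod


class BaseScore(ABC):
    """Interface shared by decomposable DAG scores (illustrative)."""

    @abstractmethod
    def local_score(self, node: int, parents: tuple) -> float:
        """Log-score contribution of `node` given its parent set."""


class UniformScore(BaseScore):
    """Toy subclass standing in for a future score such as BDe."""

    def local_score(self, node: int, parents: tuple) -> float:
        return 0.0  # every parent set scores the same
```

Callers then depend only on `BaseScore`, so swapping in another score is a one-line change at construction time.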
```python
        self._log_prior = all_parents * math.log(p) + (
            self.num_variables - all_parents - 1
        ) * math.log1p(-p)
        return self._log_prior
```
This is the second definition from here, right?
https://en.wikipedia.org/wiki/Erd%C5%91s%E2%80%93R%C3%A9nyi_model#Definition
Well, I'm not sure either. This is taken from the original repo without modification.
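Reading the expression itself, it does appear to match the second (edge-probability) definition: per node, it is the log-probability of a parent set when each of the other `num_variables - 1` nodes is a parent independently with probability `p`, i.e. a G(n, p)-style model. A small standalone check of that reading (function name is mine, not from the PR):

```python
import math


def parent_set_log_prior(num_variables: int, num_parents: int, p: float) -> float:
    """Log-prior of one node's parent set when each of the other
    (num_variables - 1) nodes is independently a parent with
    probability p: log[p^k * (1-p)^(n-1-k)] for k = num_parents."""
    num_absent = num_variables - num_parents - 1
    return num_parents * math.log(p) + num_absent * math.log1p(-p)
```

Summing `exp(parent_set_log_prior(n, k, p))` over all parent sets of each size (with binomial multiplicities) gives exactly 1, which is what the G(n, p) reading predicts.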
Pull Request Overview
This PR adds support for DAG-GFlowNet to reproduce Bayesian Structure learning as described in Deleu et al. (2022) by introducing a new BayesianStructure environment along with several helper modules for scoring, sampling, priors, and graph generation. Additional updates improve device propagation, platform‐specific multiprocessing handling, and overall consistency of the codebase.
- Updated probability distribution computation in modules to include parameter assertions and refined exploration handling.
- Added new helper modules under gfn/gym/helpers/bayesian_structure for scoring (BGeScore), sampling, priors, graph generation, and data factories.
- Refactored various environment, actions, and container functions to consistently propagate the device and enhance multiprocessing compatibility.
Reviewed Changes
Copilot reviewed 34 out of 34 changed files in this pull request and generated no comments.
File | Description
---|---
`src/gfn/modules.py` | Refined the probability distribution function with parameter assertions and logic.
`src/gfn/gym/hypergrid.py` | Updated the multiprocessing start method to check the platform and force device usage.
`src/gfn/gym/helpers/bayesian_structure/*.py` | Introduced new scorer, sampling, priors, graph generation, evaluation, and factories.
`src/gfn/gym/graph_building.py` | Improved node/edge creation with explicit device handling and refactored action logic.
`src/gfn/(containers \| env \|` |
`pyproject.toml` | Added pgmpy dependency with a version constraint.
Comments suppressed due to low confidence (2)

`src/gfn/gym/helpers/bayesian_structure/scores.py:189`
- [nitpick] Consider renaming `tmp_var` to a more descriptive name (e.g., `adjusted_sample_size`) to improve readability.
```python
tmp_var = self.num_samples + self.alpha_w - self.num_nodes + num_parents
```
`src/gfn/gym/graph_building.py:348`
- [nitpick] Although noted by the TODO comment, replacing the hard-coded upper limit `10` with a configurable parameter would enhance flexibility and clarity of the code.
```python
n_nodes = np.random.randint(1, 10)  # TODO: make the max n_nodes a parameter
```
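On the second nitpick, the TODO could be resolved along these lines (the function and parameter names below are hypothetical, just to illustrate threading the limit through as an argument):

```python
import numpy as np


def sample_initial_num_nodes(max_n_nodes: int = 10, rng=None) -> int:
    """Draw the initial node count in [1, max_n_nodes), replacing the
    hard-coded upper limit of 10 with a configurable parameter."""
    rng = np.random.default_rng() if rng is None else rng
    return int(rng.integers(1, max_n_nodes))
```

The default preserves the current behavior, so existing callers need no changes.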
I didn't check the helpers yet.
This is a very strong PR. Very well written, and it follows the spirit of the library to perfection! I agree that we need to investigate why it is slow.

Before merging, could you please modify README.md? In the section "Other environments available in the package include:", it would be nice to describe this environment (and other graph environments while you're at it :)) as coming from the Deleu et al. paper. In fact, this is our least toy example (so kudos for that, Sanghyeok and Abhijith!), and I strongly believe it should be highlighted/advertised more. You should even provide some example commands that could be run (both at the top of the file and in the README).

What do you think?
```diff
@@ -49,6 +49,7 @@ tox = { version = "*", optional = true }

 # scripts dependencies.
 matplotlib = { version = "*", optional = true }
+pgmpy = { version = "<1.0.0", optional = true }
```
Any reason we want this to be `<1.0.0`?
There are some differences between versions 0.x and 1.0.0:
- Differences in attribute names (see the error below).
- `LinearGaussianCPD` in version 0.x accepts a variance as input, but in 1.0.0 it accepts a standard deviation.

Since the original DAG-GFN repo uses version 0.x, I believe it's reasonable to follow their settings. Additionally, pgmpy was upgraded to 1.0.0 only recently (Apr 1st), and I'm reluctant to adopt it yet.
```
Traceback (most recent call last):
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/tutorials/examples/train_bayesian_structure.py", line 391, in <module>
    main(args)
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/tutorials/examples/train_bayesian_structure.py", line 188, in main
    scorer, _, gt_graph = get_scorer(
                          ^^^^^^^^^^^
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/src/gfn/gym/helpers/bayesian_structure/factories.py", line 75, in get_scorer
    graph, data, score = get_data(
                         ^^^^^^^^^
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/src/gfn/gym/helpers/bayesian_structure/factories.py", line 48, in get_data
    data = sample_from_linear_gaussian(graph, num_samples=num_samples, rng=rng)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sanghyeok/GFN/torchgfn-dag/torchgfn/src/gfn/gym/helpers/bayesian_structure/sampling.py", line 46, in sample_from_linear_gaussian
    cpd.mean[0], cpd.variance, size=(num_samples,)
    ^^^^^^^^
AttributeError: 'LinearGaussianCPD' object has no attribute 'mean'
```
PR Summary

This PR includes the `BayesianStructure` env to reproduce DAG-GFlowNet. You can find a script for training with an MLP in `tutorials/examples/train_bayesian_structure.py`.

Issues

TODO in some following PRs