
Commit 362f13e

Generalized Ising implementation (#1)
* Dump before implementation of proper cavity var optimization
* Dirty but more or less works
* Update README
* Update README #2
* Clean up tests folder
* Saturday dump
* Remove dirty example
* Clean up solvers
* Clean solvers and utils
* Update setup.py
* Update README and tests
* Clean up deq
* Clean modules
* Update README
* Simplify single layer example and remove plot from debug in deq
1 parent f18b679 commit 362f13e

9 files changed: +394 -219 lines changed

README.md (+21 -7)
@@ -2,15 +2,27 @@
 
 ## Deep Implicit Attention
 
-_The return of the Boltzmann machine_
+Experimental implementation of deep implicit attention in PyTorch.
 
----
+**Summary:** Using deep equilibrium networks to implicitly solve a set of self-consistent mean-field equations of a random Ising model implements attention as a collective response 🤗 and provides insight into the transformer architecture, connecting it to mean-field theory, message-passing algorithms, and Boltzmann machines.
 
-Experimental implementation of deep implicit attention in PyTorch.
+**Blog post (in preparation): _Deep Implicit Attention: A Mean-Field Theory Perspective on Attention Mechanisms_**
+
+## To-do
 
-**Key idea:** Use deep equilibrium networks to implicitly solve a set of self-consistent mean-field equations of a random Ising model: attention as a collective response 🤗.
+### Modules
+- [x] Add a `GeneralizedIsingGaussianAdaTAP` module implementing the adaptive TAP mean-field equations for an Ising-like vector model with standard multivariate Gaussian priors over spins
+- [ ] Figure out the analytical Gibbs free energy for `GeneralizedIsingGaussianAdaTAP` and implement it to be able to use it as a stand-alone loss function
+- [ ] Look into making the parameters of the multivariate Gaussian priors in `GeneralizedIsingGaussianAdaTAP` trainable
+- [ ] Add a `VanillaSoftmaxAttention` module which reproduces vanilla softmax attention, i.e. implementing coupling weights between spins which depend solely on linear transformations of the external sources (queries/keys) and replacing the self-correction term with a parametrized position-wise feed-forward network
 
-**Blog post (in preparation):** <a href="https://mcbal.github.io/">Deep Implicit Attention: A Mean-Field Theory Perspective on Attention Mechanisms</a>
+### Models
+- [ ] Add a `DeepImplicitAttentionTransformer` model
+- [ ] Add a `DeepImplicitAttentionViT` model
+
+### Miscellaneous
+- [ ] Add additional fixed-point / root solvers (e.g. Broyden)
+- [ ] Add examples (MNIST, sequence tasks, ...)
 
 ## Setup
 
@@ -30,7 +42,9 @@ $ python -m unittest
 
 See `tests` for now until `examples` folder is populated.
 
-## Selection of references
+## References
+
+### Selection of literature
 On variational inference, iterative approximation algorithms, expectation propagation, mean-field methods and belief propagation:
 - [Expectation Propagation](https://arxiv.org/abs/1409.6179) (2014) by Jack Raymond, Andre Manoel, Manfred Opper
 
@@ -48,7 +62,7 @@ On deep equilibrium networks:
 - [Chapter 4: Deep Equilibrium Models](https://implicit-layers-tutorial.org/deep_equilibrium_models/) of the [Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond](http://implicit-layers-tutorial.org/), created by Zico Kolter, David Duvenaud, and Matt Johnson
 
 
-## Code inspiration
+### Code inspiration
 
 - http://implicit-layers-tutorial.org/
 - https://github.com/locuslab/deq
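The README summary above frames attention as the collective response of a random Ising model whose mean-field equations are solved implicitly by a deep equilibrium network. As a rough schematic only (scalar spins and a generic prior mean function f; the `GeneralizedIsingGaussianAdaTAP` module works with vector spins and multivariate Gaussian priors, so the actual equations differ in detail), adaptive-TAP-style self-consistency reads:

```latex
% Schematic scalar-spin form, not the exact equations implemented in the repo.
m_i = f\Big( x_i + \sum_{j \neq i} J_{ij}\, m_j - V_i\, m_i \Big)
```

Here the external sources x_i play the role of the input tokens, the couplings J_ij mediate the attention-like interaction between sites, and V_i m_i is the Onsager/cavity self-correction; the deep equilibrium solver iterates this update to a fixed point m*.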

deep_implicit_attention/deq.py (+38 -24)
@@ -11,44 +11,48 @@
 class _DEQModule(nn.Module, metaclass=ABCMeta):
     def __init__(self):
         super().__init__()
-        self.state_shape = None
+        self.shapes = None
 
     def pack_state(self, z_list):
-        """Transform list of batched tensors into batch of vectors."""
-        self.state_shape = [t.shape[1:] for t in z_list]
+        """
+        Transform list of batched tensors into batch of vectors.
+        """
+        self.shapes = [t.shape[1:] for t in z_list]
         bsz = z_list[0].shape[0]
-        z = torch.cat([elem.reshape(bsz, -1) for elem in z_list], dim=1)
+        z = torch.cat([t.reshape(bsz, -1) for t in z_list], dim=1)
         return z
 
     def unpack_state(self, z):
-        """Transform batch of vectors into list of batched tensors according to `state_shape`."""
-        assert self.state_shape is not None
+        """
+        Transform batch of vectors into list of batched tensors.
+        """
+        assert self.shapes is not None
         bsz, z_list = z.shape[0], []
-        start_idx, end_idx = 0, reduce(lambda x, y: x * y, self.state_shape[0])
-        for i in range(len(self.state_shape)):
-            z_list.append(z[:, start_idx:end_idx].view(bsz, *self.state_shape[i]))
-            if i < len(self.state_shape) - 1:
+        start_idx, end_idx = 0, reduce(lambda x, y: x * y, self.shapes[0])
+        for i in range(len(self.shapes)):
+            z_list.append(z[:, start_idx:end_idx].view(bsz, *self.shapes[i]))
+            if i < len(self.shapes) - 1:
                 start_idx = end_idx
-                end_idx += reduce(lambda x, y: x * y, self.state_shape[i + 1])
+                end_idx += reduce(lambda x, y: x * y, self.shapes[i + 1])
         return z_list
 
     @abstractmethod
-    def get_initial_guess(self, x):
+    def _initial_guess(self, x):
         """Return an initial guess for the fixed-point state based on shape of `x`."""
         pass
 
     @abstractmethod
     def forward(self, z, x, *args):
-        """Implement (z_{n}, x) -> z_{n+1}."""
+        """Implement f(z_{n}, x) -> z_{n+1}."""
         pass
 
 
 class DEQFixedPoint(nn.Module):
     _default_kwargs = {
-        "solver_fwd_max_iter": 30,
-        "solver_fwd_tol": 1e-4,
-        "solver_bwd_max_iter": 30,
-        "solver_bwd_tol": 1e-4,
+        'solver_fwd_max_iter': 30,
+        'solver_fwd_tol': 1e-4,
+        'solver_bwd_max_iter': 30,
+        'solver_bwd_tol': 1e-4,
     }
 
     def __init__(self, fun, solver, output_elements=[0], **kwargs):
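As a quick illustration of the packing helpers changed in this hunk, the sketch below mirrors what `pack_state`/`unpack_state` do, using made-up tensor shapes (standalone code, not taken from the repo):

```python
import torch

# Two state tensors per batch element (shapes are invented for the example).
bsz = 4
z_list = [torch.randn(bsz, 8, 16), torch.randn(bsz, 8)]

# pack_state: remember the per-tensor shapes and flatten into one vector per batch element.
shapes = [t.shape[1:] for t in z_list]
z = torch.cat([t.reshape(bsz, -1) for t in z_list], dim=1)  # shape (4, 8 * 16 + 8) = (4, 136)

# unpack_state: slice the packed vector back into tensors using the remembered shapes.
unpacked, start = [], 0
for shape in shapes:
    numel = 1
    for d in shape:
        numel *= d
    unpacked.append(z[:, start:start + numel].view(bsz, *shape))
    start += numel

assert all(torch.equal(a, b) for a, b in zip(z_list, unpacked))
```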
@@ -60,14 +64,20 @@ def __init__(self, fun, solver, output_elements=[0], **kwargs):
         self.kwargs.update(**kwargs)
 
     def _fixed_point(self, z0, x, *args, **kwargs):
+        """Find fixed-point of `fun` given `z0` and `x`."""
+
         # Compute forward pass: find equilibrium state
         with torch.no_grad():
             out = self.solver(
                 lambda z: self.fun(z, x, *args),
                 z0,
-                **filter_kwargs(kwargs, "solver_fwd_"),
+                **filter_kwargs(kwargs, 'solver_fwd_'),
             )
-            z, _ = out["result"], out["rel_trace"]
+            z = out['result']
+            if kwargs.get('debug', False):
+                print(f"{out['rel_trace'][0]} -> {out['rel_trace'][-1]}")
+                # from .utils import log_plot
+                # log_plot(out['rel_trace'])
 
         if self.training:
             # Re-engage autograd tape at equilibrium state
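`DEQFixedPoint` treats the injected `solver` as a callable returning a dict with `'result'` and `'rel_trace'` entries, since that is how `_fixed_point` consumes its output. Below is a minimal sketch of a solver obeying that contract, using plain fixed-point iteration and assuming that `filter_kwargs` strips the `solver_fwd_`/`solver_bwd_` prefixes before forwarding `max_iter`/`tol`; the solvers actually shipped in the repo (and the planned Broyden solver) may look different:

```python
import torch

def toy_solver(fun, z0, max_iter=30, tol=1e-4):
    """Naive fixed-point iteration returning the dict layout read by `_fixed_point`."""
    z, rel_trace = z0, []
    for _ in range(max_iter):
        z_next = fun(z)
        # Relative residual used both as stopping criterion and as the debug trace.
        rel = ((z_next - z).norm() / (1e-8 + z_next.norm())).item()
        rel_trace.append(rel)
        z = z_next
        if rel < tol:
            break
    return {'result': z, 'rel_trace': rel_trace}
```

Passing `debug=True` at call time then prints the first and last relative residuals of the forward solve, a cheap convergence check.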
@@ -78,23 +88,27 @@ def _fixed_point(self, z0, x, *args, **kwargs):
 
             def backward_hook(grad):
                 out = self.solver(
-                    lambda y: autograd.grad(fun_bwd, z_bwd, y, retain_graph=True)[0]
+                    lambda y: autograd.grad(
+                        fun_bwd, z_bwd, y, retain_graph=True)[0]
                     + grad,
                     torch.zeros_like(grad),
-                    **filter_kwargs(kwargs, "solver_bwd_"),
+                    **filter_kwargs(kwargs, 'solver_bwd_'),
                 )
-                g, _ = out["result"], out["rel_trace"]
+                g = out['result']
+                # [DEBUG] insert statements here for backward pass inspection
                 return g
 
             z.register_hook(backward_hook)
 
         return z
 
     def forward(self, x, *args, **kwargs):
+        # Merge default kwargs with incoming runtime kwargs.
+        kwargs = {**self.kwargs, **kwargs}
         # Get list of initial guess tensors and reshape into a batch of vectors
-        z0 = self.fun.pack_state(kwargs.get("z0", self.fun.get_initial_guess(x)))
+        z0 = self.fun.pack_state(self.fun._initial_guess(x))
         # Find equilibrium vectors
-        z_star = self._fixed_point(z0, x, *args, **self.kwargs)
+        z_star = self._fixed_point(z0, x, *args, **kwargs)
         # Return (subset of) list of tensors of original input shapes
         out = [self.fun.unpack_state(z_star)[i] for i in self.output_elements]
         return out[0] if len(out) == 1 else out
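Putting the pieces together, here is a hedged end-to-end sketch of how a `_DEQModule` subclass is wrapped by `DEQFixedPoint`. The toy update rule is invented for illustration, and `toy_solver` refers to the sketch above, not to a solver shipped in the repo:

```python
import torch
from deep_implicit_attention.deq import DEQFixedPoint, _DEQModule

class ToyModule(_DEQModule):
    """Toy contraction map z <- 0.5 * tanh(z) + x, purely for illustration."""

    def _initial_guess(self, x):
        # Single state tensor, initialized at zero, with the same shape as the input.
        return [torch.zeros_like(x)]

    def forward(self, z, x, *args):
        (z,) = self.unpack_state(z)
        z_next = 0.5 * torch.tanh(z) + x
        return self.pack_state([z_next])

# Override a default solver kwarg at construction; pass `debug=True` per call.
deq = DEQFixedPoint(ToyModule(), toy_solver, solver_fwd_tol=1e-6)
x = torch.randn(4, 10)
z_star = deq(x, debug=True)  # equilibrium state; implicit-gradient hook attached in training mode
```

In the actual repo the wrapped module would be something like `GeneralizedIsingGaussianAdaTAP`, with the attention-like couplings living inside the fixed-point update.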

deep_implicit_attention/models.py (+5 -1)

@@ -1 +1,5 @@
-"""TODO: BoltzmannAgent, BoltzmannTransformer"""
+"""
+TODO:
+- Add a `DeepImplicitAttentionTransformer` model
+- Add a `DeepImplicitAttentionViT` model
+"""
