Commit bff9807

Cleanup and fix tests and examples

1 parent 599d2f3 · commit bff9807
3 files changed: +57 −31 lines changed


deep_implicit_attention/attention.py

Lines changed: 49 additions & 23 deletions

@@ -8,7 +8,7 @@
 from .utils import batched_eye, batched_eye_like
 
 
-class DeepImplicitAttention(_DEQModule):
+class DEQMeanFieldAttention(_DEQModule):
     """Deep implicit attention.
 
     Attention as a fixed-point mean-field response of an Ising-like vector
@@ -36,6 +36,10 @@ class DeepImplicitAttention(_DEQModule):
             norm of tensor |weight| ~ O(1).
         weight_training (bool):
             Allow coupling weights to be trained. (default: `True`).
+        weight_sym_internal (bool):
+            Symmetrize internal indices of weight tensor. (default: `False`).
+        weight_sym_sites (bool):
+            Symmetrize site indices of weight tensor. (default: `False`).
         lin_response (bool):
             Toggle linear response correction to mean-field (default: `True`).
     """
@@ -46,6 +50,8 @@ def __init__(
         dim,
         weight_init_std=None,
         weight_training=True,
+        weight_sym_internal=False,
+        weight_sym_sites=False,
         lin_response=True,
     ):
         super().__init__()
@@ -60,6 +66,9 @@ def __init__(
             ),
             training=weight_training,
         )
+        self.weight_sym_internal = weight_sym_internal
+        self.weight_sym_sites = weight_sym_sites
+
         if lin_response:
             self.correction = FeedForward(dim)  # no dropout
         self.lin_response = lin_response
@@ -73,28 +82,28 @@ def _init_weight(self, num_spins, dim, init_std, training):
         else:
             self.register_buffer('_weight', weight)
 
-    def weight(self, symmetrize_internal=True, symmetrize_sites=True):
-        """
-        Return symmetrized and traceless weight tensor.
-
-        Note:
-            This implementation is very inefficient since it stores N^2*d^2
-            parameters but only needs N*(N-1)*d*(d+1)/4. Also look into new
-            torch parametrization functionality:
-            https://pytorch.org/tutorials/intermediate/parametrizations.html
-        """
+    def weight(self):
+        """Return symmetrized and traceless weight tensor."""
         num_spins, dim = self._weight.size(0), self._weight.size(2)
         weight = self._weight
-        if symmetrize_internal:  # local dofs at every site
+        if self.weight_sym_internal:
             weight = 0.5 * (weight + weight.permute([0, 1, 3, 2]))
-        if symmetrize_sites:  # between sites
+        if self.weight_sym_sites:
             weight = 0.5 * (weight + weight.permute([1, 0, 2, 3]))
         mask = batched_eye(dim ** 2, num_spins,
                            device=weight.device, dtype=weight.dtype)
         mask = rearrange(mask, '(a b) i j -> i j a b', a=dim, b=dim)
-        weight = (1.0 - mask) * weight  # zeros on sites' block-diagonal
+        weight = (1.0 - mask) * weight
         return weight
 
+    def count_params(self):
+        num_spins, dim = self._weight.size(0), self._weight.size(2)
+        site_factor = 0.5*num_spins * \
+            (num_spins-1) if self.weight_sym_sites else num_spins*(num_spins-1)
+        internal_factor = 0.5*dim * \
+            (dim+1) if self.weight_sym_internal else dim**2
+        return site_factor*internal_factor
+
     def _initial_guess(self, x):
         """Return initial guess tensors."""
         bsz, N, d = x.shape
@@ -119,14 +128,13 @@ def forward(self, z, x, *args):
 
         spin_mean = torch.einsum(
             'i j c d, b j d -> b i c', self.weight(), spin_mean) + x
-
         if self.lin_response:
             spin_mean = spin_mean - self.correction(spin_mean)
 
         return self.pack_state([spin_mean])
 
 
-class ExplicitDeepImplicitAttention(_DEQModule):
+class DEQAdaTAPMeanFieldAttention(_DEQModule):
     """Ising-like vector model with multivariate Gaussian prior over spins.
 
     Generalization of the application of the adaptive TAP mean-field approach
@@ -162,6 +170,10 @@ class ExplicitDeepImplicitAttention(_DEQModule):
             norm of tensor |weight| ~ O(1).
         weight_training (bool):
             Allow coupling weights to be trained. (default: `True`).
+        weight_sym_internal (bool):
+            Symmetrize internal indices of weight tensor. (default: `True`).
+        weight_sym_sites (bool):
+            Symmetrize site indices of weight tensor. (default: `True`).
         lin_response (bool):
             Toggle linear response correction to mean-field (default: `True`).
     """
@@ -172,6 +184,8 @@ def __init__(
         dim,
         weight_init_std=None,
         weight_training=True,
+        weight_sym_internal=True,
+        weight_sym_sites=True,
         lin_response=True,
     ):
         super().__init__()
@@ -186,11 +200,15 @@ def __init__(
             ),
             training=weight_training,
         )
+        self.weight_sym_internal = weight_sym_internal
+        self.weight_sym_sites = weight_sym_sites
+
         self.register_buffer(
             'spin_prior_inv_var',
             batched_eye_like(
                 torch.zeros(num_spins, dim, dim))
         )
+
         self.lin_response = lin_response
 
     def _init_weight(self, num_spins, dim, init_std, training):
@@ -202,7 +220,7 @@ def _init_weight(self, num_spins, dim, init_std, training):
         else:
             self.register_buffer('_weight', weight)
 
-    def weight(self, symmetrize_internal=True, symmetrize_sites=True):
+    def weight(self):
         """
         Return symmetrized and traceless weight tensor.
 
@@ -214,16 +232,24 @@ def weight(self, symmetrize_internal=True, symmetrize_sites=True):
         """
         num_spins, dim = self._weight.size(0), self._weight.size(2)
         weight = self._weight
-        if symmetrize_internal:  # local dofs at every site
+        if self.weight_sym_internal:
             weight = 0.5 * (weight + weight.permute([0, 1, 3, 2]))
-        if symmetrize_sites:  # between sites
+        if self.weight_sym_sites:
             weight = 0.5 * (weight + weight.permute([1, 0, 2, 3]))
         mask = batched_eye(dim ** 2, num_spins,
                            device=weight.device, dtype=weight.dtype)
         mask = rearrange(mask, '(a b) i j -> i j a b', a=dim, b=dim)
-        weight = (1.0 - mask) * weight  # zeros on sites' block-diagonal
+        weight = (1.0 - mask) * weight
         return weight
 
+    def count_params(self):
+        num_spins, dim = self._weight.size(0), self._weight.size(2)
+        site_factor = 0.5*num_spins * \
+            (num_spins-1) if self.weight_sym_sites else num_spins*(num_spins-1)
+        internal_factor = 0.5*dim * \
+            (dim+1) if self.weight_sym_internal else dim**2
+        return site_factor*internal_factor
+
     def _initial_guess(self, x):
         """Return initial guess tensors."""
         bsz, N, d = x.shape
@@ -243,7 +269,7 @@ def _spin_mean_var(self, x, cav_mean, cav_var):
         inv_var = self.spin_prior_inv_var - cav_var
         prefactor = torch.solve(batched_eye_like(inv_var), inv_var).solution
         spin_mean = torch.einsum(
-            'n d e, b n d -> b n e', prefactor, (cav_mean + x)
+            'i d e, b i d -> b i e', prefactor, (cav_mean + x)
         )
         spin_var = prefactor
         return spin_mean, spin_var
@@ -275,8 +301,8 @@ def forward(self, z, x, *args):
         weight = self.weight()
 
         cav_mean = torch.einsum(
-            'n m d e, b m e -> b n d', weight, spin_mean
-        ) - torch.einsum('b n d e, b n d -> b n e', cav_var, spin_mean)
+            'i j d e, b j e -> b i d', weight, spin_mean
+        ) - torch.einsum('b i d e, b i d -> b i e', cav_var, spin_mean)
 
         spin_mean, spin_var = self._spin_mean_var(x, cav_mean, cav_var[0])
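Note: the `count_params` method added to both classes evaluates the closed-form count that the trimmed docstring note used to quote: N·(N−1)·d·(d+1)/4 independent couplings when both symmetrizations are on, versus the N²·d² entries actually stored. A standalone sketch of the same arithmetic, useful for sanity-checking the method (`effective_params` is an illustrative name, not part of the repo's API):

```python
# Standalone restatement of count_params() from the diff above.
def effective_params(num_spins, dim, sym_sites=True, sym_internal=True):
    # Off-diagonal site pairs: N*(N-1) blocks, halved when W_ij is tied to W_ji.
    site_factor = (
        0.5 * num_spins * (num_spins - 1) if sym_sites
        else num_spins * (num_spins - 1)
    )
    # Each d x d coupling block: d^2 entries, or d*(d+1)/2 when symmetric.
    internal_factor = 0.5 * dim * (dim + 1) if sym_internal else dim ** 2
    return site_factor * internal_factor


# For the sizes used in tests/test_gradients.py (N=11, d=3), both flags on:
# 55 * 6 = 330 independent couplings, versus 11**2 * 3**2 = 1089 stored entries,
# matching the old docstring note N*(N-1)*d*(d+1)/4.
assert effective_params(11, 3) == 330
```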
examples/single_layer.py

Lines changed: 2 additions & 2 deletions

@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 
-from deep_implicit_attention.attention import DeepImplicitAttention
+from deep_implicit_attention.attention import DEQMeanFieldAttention
 from deep_implicit_attention.deq import DEQFixedPoint
 from deep_implicit_attention.solvers import anderson
 
@@ -10,7 +10,7 @@
 
 # Initialize fixed-point wrapper around model system.
 deq_attn = DEQFixedPoint(
-    DeepImplicitAttention(
+    DEQMeanFieldAttention(
         num_spins=num_spins,
         dim=dim,
         weight_init_std=1.0 / np.sqrt(num_spins * dim**2),
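For reference, the renamed example reads roughly as follows after this commit. This is a sketch: the lines of examples/single_layer.py outside the two hunks are assumed unchanged, and the input tensor and the `deq_attn(x)` call signature are assumptions, not shown in the diff.

```python
import numpy as np
import torch

from deep_implicit_attention.attention import DEQMeanFieldAttention
from deep_implicit_attention.deq import DEQFixedPoint
from deep_implicit_attention.solvers import anderson

num_spins, dim = 11, 3  # sizes as in the tests; the example's own values may differ

# Initialize fixed-point wrapper around model system.
deq_attn = DEQFixedPoint(
    DEQMeanFieldAttention(
        num_spins=num_spins,
        dim=dim,
        weight_init_std=1.0 / np.sqrt(num_spins * dim**2),
        weight_sym_internal=False,  # new flag in this commit (default False here)
        weight_sym_sites=False,     # new flag in this commit (default False here)
        lin_response=True,
    ),
    anderson,  # fixed-point solver
)

# Source injection x plays the role of token embeddings: (batch, sites, dim).
x = torch.randn(4, num_spins, dim)
out = deq_attn(x)
```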

tests/test_gradients.py

Lines changed: 6 additions & 6 deletions

@@ -5,23 +5,23 @@
 from torch.autograd import gradcheck
 
 from deep_implicit_attention.attention import (
-    DeepImplicitAttention,
-    ExplicitDeepImplicitAttention,
+    DEQMeanFieldAttention,
+    DEQAdaTAPMeanFieldAttention,
 )
 from deep_implicit_attention.deq import DEQFixedPoint
 from deep_implicit_attention.solvers import anderson
 
 
 class TestGradients(unittest.TestCase):
-    def test_explicit_deep_implicit_attention(self):
+    def test_adatap_mean_field_attention(self):
         """Run a small network with double precision."""
 
         num_spins, dim = 11, 3
 
         for lin_response in [False, True]:
             with self.subTest():
                 deq_attn = DEQFixedPoint(
-                    ExplicitDeepImplicitAttention(
+                    DEQAdaTAPMeanFieldAttention(
                         num_spins=num_spins,
                         dim=dim,
                         lin_response=lin_response,
@@ -41,15 +41,15 @@ def test_explicit_deep_implicit_attention(self):
                 )
             )
 
-    def test_deep_implicit_attention(self):
+    def test_mean_field_attention(self):
         """Run a small network with double precision."""
 
         num_spins, dim = 11, 3
 
         for lin_response in [False, True]:
             with self.subTest():
                 deq_attn = DEQFixedPoint(
-                    DeepImplicitAttention(
+                    DEQMeanFieldAttention(
                         num_spins=num_spins,
                         dim=dim,
                         lin_response=lin_response,
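The assertion body between the two hunks is not shown in the diff. A self-contained sketch of the double-precision gradcheck pattern the visible code implies; the `.double()` cast, the input shape, and the gradcheck lambda are assumptions consistent with the surrounding lines:

```python
import unittest

import torch
from torch.autograd import gradcheck

from deep_implicit_attention.attention import DEQMeanFieldAttention
from deep_implicit_attention.deq import DEQFixedPoint
from deep_implicit_attention.solvers import anderson


class TestGradientsSketch(unittest.TestCase):
    def test_mean_field_attention(self):
        """Run a small network with double precision."""
        num_spins, dim = 11, 3
        deq_attn = DEQFixedPoint(
            DEQMeanFieldAttention(num_spins=num_spins, dim=dim),
            anderson,
        ).double()  # float64 keeps gradcheck's finite differences reliable
        x = torch.randn(1, num_spins, dim, dtype=torch.double,
                        requires_grad=True)
        # Compare the implicit backward pass against numerical gradients of
        # a scalar function of the fixed-point output.
        self.assertTrue(gradcheck(lambda t: deq_attn(t).sum(), x))


if __name__ == '__main__':
    unittest.main()
```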
