8
8
from .utils import batched_eye , batched_eye_like
9
9
10
10
11
- class DeepImplicitAttention (_DEQModule ):
11
+ class DEQMeanFieldAttention (_DEQModule ):
12
12
"""Deep implicit attention.
13
13
14
14
Attention as a fixed-point mean-field response of an Ising-like vector
@@ -36,6 +36,10 @@ class DeepImplicitAttention(_DEQModule):
36
36
norm of tensor |weight| ~ O(1).
37
37
weight_training (bool):
38
38
Allow coupling weights to be trained. (default: `True`).
39
+ weight_sym_internal (bool):
40
+ Symmetrize internal indices of weight tensor. (default: `False`).
41
+ weight_sym_sites (bool):
42
+ Symmetrize site indices of weight tensor. (default: `False`).
39
43
lin_response (bool):
40
44
Toggle linear response correction to mean-field (default: `True`).
41
45
"""
@@ -46,6 +50,8 @@ def __init__(
46
50
dim ,
47
51
weight_init_std = None ,
48
52
weight_training = True ,
53
+ weight_sym_internal = False ,
54
+ weight_sym_sites = False ,
49
55
lin_response = True ,
50
56
):
51
57
super ().__init__ ()
@@ -60,6 +66,9 @@ def __init__(
60
66
),
61
67
training = weight_training ,
62
68
)
69
+ self .weight_sym_internal = weight_sym_internal
70
+ self .weight_sym_sites = weight_sym_sites
71
+
63
72
if lin_response :
64
73
self .correction = FeedForward (dim ) # no dropout
65
74
self .lin_response = lin_response
@@ -73,28 +82,28 @@ def _init_weight(self, num_spins, dim, init_std, training):
73
82
else :
74
83
self .register_buffer ('_weight' , weight )
75
84
76
def weight(self):
    """Return the coupling weight tensor, optionally symmetrized, with the
    on-site (block-diagonal) couplings projected out so no spin couples to
    itself.
    """
    n, d = self._weight.size(0), self._weight.size(2)
    W = self._weight
    if self.weight_sym_internal:
        # Symmetrize the local (internal) indices at every site pair.
        W = 0.5 * (W + W.transpose(2, 3))
    if self.weight_sym_sites:
        # Symmetrize the site indices.
        W = 0.5 * (W + W.transpose(0, 1))
    # Build a block-diagonal mask over sites and zero it out of the weight.
    diag = batched_eye(d ** 2, n, device=W.device, dtype=W.dtype)
    diag = rearrange(diag, '(a b) i j -> i j a b', a=d, b=d)
    return (1.0 - diag) * W
97
98
99
def count_params(self):
    """Return the number of independent coupling parameters.

    The site factor counts ordered off-diagonal site pairs, N*(N-1),
    halved when the site indices are symmetrized. The internal factor
    counts d**2 entries per coupling block, reduced to d*(d+1)/2 when
    the internal indices are symmetrized. Integer arithmetic (`// 2`)
    is used so the result is always an exact `int` rather than the
    `float` the previous `0.5 *` formulation produced.
    """
    num_spins, dim = self._weight.size(0), self._weight.size(2)
    if self.weight_sym_sites:
        site_factor = num_spins * (num_spins - 1) // 2
    else:
        site_factor = num_spins * (num_spins - 1)
    if self.weight_sym_internal:
        internal_factor = dim * (dim + 1) // 2
    else:
        internal_factor = dim ** 2
    return site_factor * internal_factor
106
+
98
107
def _initial_guess (self , x ):
99
108
"""Return initial guess tensors."""
100
109
bsz , N , d = x .shape
@@ -119,14 +128,13 @@ def forward(self, z, x, *args):
119
128
120
129
spin_mean = torch .einsum (
121
130
'i j c d, b j d -> b i c' , self .weight (), spin_mean ) + x
122
-
123
131
if self .lin_response :
124
132
spin_mean = spin_mean - self .correction (spin_mean )
125
133
126
134
return self .pack_state ([spin_mean ])
127
135
128
136
129
- class ExplicitDeepImplicitAttention (_DEQModule ):
137
+ class DEQAdaTAPMeanFieldAttention (_DEQModule ):
130
138
"""Ising-like vector model with multivariate Gaussian prior over spins.
131
139
132
140
Generalization of the application of the adaptive TAP mean-field approach
@@ -162,6 +170,10 @@ class ExplicitDeepImplicitAttention(_DEQModule):
162
170
norm of tensor |weight| ~ O(1).
163
171
weight_training (bool):
164
172
Allow coupling weights to be trained. (default: `True`).
173
+ weight_sym_internal (bool):
174
+ Symmetrize internal indices of weight tensor. (default: `True`).
175
+ weight_sym_sites (bool):
176
+ Symmetrize site indices of weight tensor. (default: `True`).
165
177
lin_response (bool):
166
178
Toggle linear response correction to mean-field (default: `True`).
167
179
"""
@@ -172,6 +184,8 @@ def __init__(
172
184
dim ,
173
185
weight_init_std = None ,
174
186
weight_training = True ,
187
+ weight_sym_internal = True ,
188
+ weight_sym_sites = True ,
175
189
lin_response = True ,
176
190
):
177
191
super ().__init__ ()
@@ -186,11 +200,15 @@ def __init__(
186
200
),
187
201
training = weight_training ,
188
202
)
203
+ self .weight_sym_internal = weight_sym_internal
204
+ self .weight_sym_sites = weight_sym_sites
205
+
189
206
self .register_buffer (
190
207
'spin_prior_inv_var' ,
191
208
batched_eye_like (
192
209
torch .zeros (num_spins , dim , dim ))
193
210
)
211
+
194
212
self .lin_response = lin_response
195
213
196
214
def _init_weight (self , num_spins , dim , init_std , training ):
@@ -202,7 +220,7 @@ def _init_weight(self, num_spins, dim, init_std, training):
202
220
else :
203
221
self .register_buffer ('_weight' , weight )
204
222
205
- def weight (self , symmetrize_internal = True , symmetrize_sites = True ):
223
+ def weight (self ):
206
224
"""
207
225
Return symmetrized and traceless weight tensor.
208
226
@@ -214,16 +232,24 @@ def weight(self, symmetrize_internal=True, symmetrize_sites=True):
214
232
"""
215
233
num_spins , dim = self ._weight .size (0 ), self ._weight .size (2 )
216
234
weight = self ._weight
217
- if symmetrize_internal : # local dofs at every site
235
+ if self . weight_sym_internal :
218
236
weight = 0.5 * (weight + weight .permute ([0 , 1 , 3 , 2 ]))
219
- if symmetrize_sites : # between sites
237
+ if self . weight_sym_sites :
220
238
weight = 0.5 * (weight + weight .permute ([1 , 0 , 2 , 3 ]))
221
239
mask = batched_eye (dim ** 2 , num_spins ,
222
240
device = weight .device , dtype = weight .dtype )
223
241
mask = rearrange (mask , '(a b) i j -> i j a b' , a = dim , b = dim )
224
- weight = (1.0 - mask ) * weight # zeros on sites' block-diagonal
242
+ weight = (1.0 - mask ) * weight
225
243
return weight
226
244
245
def count_params(self):
    """Return the number of independent coupling parameters.

    The site factor counts ordered off-diagonal site pairs, N*(N-1),
    halved when the site indices are symmetrized. The internal factor
    counts d**2 entries per coupling block, reduced to d*(d+1)/2 when
    the internal indices are symmetrized. Integer arithmetic (`// 2`)
    is used so the result is always an exact `int` rather than the
    `float` the previous `0.5 *` formulation produced.
    """
    num_spins, dim = self._weight.size(0), self._weight.size(2)
    if self.weight_sym_sites:
        site_factor = num_spins * (num_spins - 1) // 2
    else:
        site_factor = num_spins * (num_spins - 1)
    if self.weight_sym_internal:
        internal_factor = dim * (dim + 1) // 2
    else:
        internal_factor = dim ** 2
    return site_factor * internal_factor
252
+
227
253
def _initial_guess (self , x ):
228
254
"""Return initial guess tensors."""
229
255
bsz , N , d = x .shape
@@ -243,7 +269,7 @@ def _spin_mean_var(self, x, cav_mean, cav_var):
243
269
inv_var = self .spin_prior_inv_var - cav_var
244
270
prefactor = torch .solve (batched_eye_like (inv_var ), inv_var ).solution
245
271
spin_mean = torch .einsum (
246
- 'n d e, b n d -> b n e' , prefactor , (cav_mean + x )
272
+ 'i d e, b i d -> b i e' , prefactor , (cav_mean + x )
247
273
)
248
274
spin_var = prefactor
249
275
return spin_mean , spin_var
@@ -275,8 +301,8 @@ def forward(self, z, x, *args):
275
301
weight = self .weight ()
276
302
277
303
cav_mean = torch .einsum (
278
- 'n m d e, b m e -> b n d' , weight , spin_mean
279
- ) - torch .einsum ('b n d e, b n d -> b n e' , cav_var , spin_mean )
304
+ 'i j d e, b j e -> b i d' , weight , spin_mean
305
+ ) - torch .einsum ('b i d e, b i d -> b i e' , cav_var , spin_mean )
280
306
281
307
spin_mean , spin_var = self ._spin_mean_var (x , cav_mean , cav_var [0 ])
282
308
0 commit comments