Skip to content

Commit f18b679

Browse files
committed
Fix kwargs parsing
1 parent 3f0049c commit f18b679

File tree

3 files changed

+9
-10
lines changed

3 files changed

+9
-10
lines changed

deep_implicit_attention/deq.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _fixed_point(self, z0, x, *args, **kwargs):
6565
out = self.solver(
6666
lambda z: self.fun(z, x, *args),
6767
z0,
68-
**filter_kwargs(kwargs, "solver_fwd"),
68+
**filter_kwargs(kwargs, "solver_fwd_"),
6969
)
7070
z, _ = out["result"], out["rel_trace"]
7171

@@ -81,7 +81,7 @@ def backward_hook(grad):
8181
lambda y: autograd.grad(fun_bwd, z_bwd, y, retain_graph=True)[0]
8282
+ grad,
8383
torch.zeros_like(grad),
84-
**filter_kwargs(kwargs, "solver_bwd"),
84+
**filter_kwargs(kwargs, "solver_bwd_"),
8585
)
8686
g, _ = out["result"], out["rel_trace"]
8787
return g
@@ -94,7 +94,7 @@ def forward(self, x, *args, **kwargs):
9494
# Get list of initial guess tensors and reshape into a batch of vectors
9595
z0 = self.fun.pack_state(kwargs.get("z0", self.fun.get_initial_guess(x)))
9696
# Find equilibrium vectors
97-
self.kwargs.update(**kwargs)
98-
z_star = self._fixed_point(z0, x, *args, **kwargs)
97+
z_star = self._fixed_point(z0, x, *args, **self.kwargs)
9998
# Return (subset of) list of tensors of original input shapes
100-
return [self.fun.unpack_state(z_star)[i] for i in self.output_elements]
99+
out = [self.fun.unpack_state(z_star)[i] for i in self.output_elements]
100+
return out[0] if len(out) == 1 else out

deep_implicit_attention/modules.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def __init__(
7171
self.prior_init_std = prior_init_std
7272

7373
self.solver = solver
74-
self.solver_tol = 1e-4
75-
self.solver_max_iter = 30
74+
self.solver_tol = solver_tol
75+
self.solver_max_iter = solver_max_iter
7676

7777
self.lin_response_correction = lin_response_correction
7878

@@ -144,4 +144,3 @@ def forward(self, z, x, *args):
144144
next_spin_mean, next_spin_var = pf * (cav_mean + x), pf
145145

146146
return self.pack_state([next_spin_mean, next_spin_var])
147-

deep_implicit_attention/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ def set_all_seeds(seed):
1212
torch.cuda.manual_seed_all(seed)
1313

1414

15-
def filter_kwargs(kwargs, str):
16-
return {k.replace(str, ""): v for k, v in kwargs if k.startswith(str)}
15+
def filter_kwargs(kwargs, prefix):
16+
return {k.replace(prefix, ""): v for k, v in kwargs.items() if k.startswith(prefix)}
1717

1818

1919
def make_traceless(X: torch.Tensor):

0 commit comments

Comments
 (0)