@@ -65,7 +65,7 @@ def _fixed_point(self, z0, x, *args, **kwargs):
65
65
out = self .solver (
66
66
lambda z : self .fun (z , x , * args ),
67
67
z0 ,
68
- ** filter_kwargs (kwargs , "solver_fwd " ),
68
+ ** filter_kwargs (kwargs , "solver_fwd_ " ),
69
69
)
70
70
z , _ = out ["result" ], out ["rel_trace" ]
71
71
@@ -81,7 +81,7 @@ def backward_hook(grad):
81
81
lambda y : autograd .grad (fun_bwd , z_bwd , y , retain_graph = True )[0 ]
82
82
+ grad ,
83
83
torch .zeros_like (grad ),
84
- ** filter_kwargs (kwargs , "solver_bwd " ),
84
+ ** filter_kwargs (kwargs , "solver_bwd_ " ),
85
85
)
86
86
g , _ = out ["result" ], out ["rel_trace" ]
87
87
return g
@@ -94,7 +94,7 @@ def forward(self, x, *args, **kwargs):
94
94
# Get list of initial guess tensors and reshape into a batch of vectors
95
95
z0 = self .fun .pack_state (kwargs .get ("z0" , self .fun .get_initial_guess (x )))
96
96
# Find equilibrium vectors
97
- self .kwargs .update (** kwargs )
98
- z_star = self ._fixed_point (z0 , x , * args , ** kwargs )
97
+ z_star = self ._fixed_point (z0 , x , * args , ** self .kwargs )
99
98
# Return (subset of) list of tensors of original input shapes
100
- return [self .fun .unpack_state (z_star )[i ] for i in self .output_elements ]
99
+ out = [self .fun .unpack_state (z_star )[i ] for i in self .output_elements ]
100
+ return out [0 ] if len (out ) == 1 else out
0 commit comments