Skip to content

Commit 490d222

Browse files
committed
Fix issue taking device from V before V exists
1 parent 875c19d commit 490d222

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

timm/optim/kron.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -306,10 +306,16 @@ def step(self, closure=None):
306306
exprA, exprGs, _ = exprs
307307
Q = state["Q"]
308308
if self.deterministic:
309-
torch_rng = torch.Generator(device=V.device).manual_seed(self.rng.randint(0, 2 ** 31))
309+
torch_rng = torch.Generator(device=debiased_momentum.device)
310+
torch_rng.manual_seed(self.rng.randint(0, 2 ** 31))
310311
else:
311312
torch_rng = None
312-
V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device=debiased_momentum.device)
313+
V = torch.randn(
314+
debiased_momentum.shape,
315+
generator=torch_rng,
316+
dtype=precond_dtype,
317+
device=debiased_momentum.device,
318+
)
313319
G = debiased_momentum if momentum_into_precond_update else grad
314320

315321
A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)

0 commit comments

Comments
 (0)