Skip to content

Commit 644ce53

Browse files
committed
fix non-persistent params
1 parent 3591fe9 commit 644ce53

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

timm/models/_builder.py

+12
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,19 @@ def build_model_with_cfg(
429429
else:
430430
model = model_cls(cfg=model_cfg, **kwargs)
431431
if pretrained:
432+
# .to_empty() also re-allocates params/buffers that already live on CPU, leaving them with uninitialized storage.
433+
# this is problematic for non-persistent buffers, since they don't get loaded
434+
# from pretrained weights later (not part of state_dict). Hence, we have
435+
# to save them before calling .to_empty() and fill them back after.
436+
buffers = {k: v for k, v in model.named_buffers() if not v.is_meta}
432437
model.to_empty(device="cpu")
438+
for k, v in model.named_buffers():
439+
if k in buffers:
440+
v.data = buffers[k]
441+
442+
# Alternatively, rely on the internal method ._apply():
443+
# model._apply(lambda t: torch.empty_like(t, device="cpu") if t.is_meta else t)
444+
433445
model.pretrained_cfg = pretrained_cfg
434446
model.default_cfg = model.pretrained_cfg # alias for backwards compat
435447

0 commit comments

Comments
 (0)