File tree 1 file changed +12
-0
lines changed
1 file changed +12
-0
lines changed Original file line number Diff line number Diff line change @@ -429,7 +429,19 @@ def build_model_with_cfg(
429
429
else :
430
430
model = model_cls (cfg = model_cfg , ** kwargs )
431
431
if pretrained :
432
+ # .to_empty() will also move cpu params/buffers to uninitialized storage.
433
+ # this is problematic for non-persistent buffers, since they don't get loaded
434
+ # from pretrained weights later (not part of state_dict). hence, we have
435
+ # to save them before calling .to_empty() and fill them back after.
436
+ buffers = {k : v for k , v in model .named_buffers () if not v .is_meta }
432
437
model .to_empty (device = "cpu" )
438
+ for k , v in model .named_buffers ():
439
+ if k in buffers :
440
+ v .data = buffers [k ]
441
+
442
+ # alternative, rely on internal method ._apply()
443
+ # model._apply(lambda t: torch.empty_like(t, device="cpu") if t.is_meta else t)
444
+
433
445
model .pretrained_cfg = pretrained_cfg
434
446
model .default_cfg = model .pretrained_cfg # alias for backwards compat
435
447
You can’t perform that action at this time.
0 commit comments