
Commit d0a8ceb

add a separate use_meta_init flag. handle the case when num_classes is changed
1 parent 644ce53 commit d0a8ceb

File tree

1 file changed: +14 -3 lines changed

timm/models/_builder.py (+14 -3)
@@ -421,14 +421,25 @@ def build_model_with_cfg(
     if 'feature_cls' in kwargs:
         feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
 
+    # use meta-device init to speed up loading pretrained weights.
+    # device context manager is only available for PyTorch>=2.0
+    # when num_classes is changed, we rely on __init__() logic to initialize head weights.
+    # thus, we can't use meta-device init in that case.
+    num_classes = 0 if features else kwargs.get("num_classes", pretrained_cfg["num_classes"])
+    use_meta_init = (
+        pretrained
+        and hasattr(torch.device("meta"), "__enter__")
+        and (num_classes == 0 or num_classes == pretrained_cfg["num_classes"])
+    )
+
     # Instantiate the model
-    meta_device = torch.device("meta")
-    with meta_device if hasattr(meta_device, "__enter__") and pretrained else nullcontext():
+    with torch.device("meta") if use_meta_init else nullcontext():
         if model_cfg is None:
             model = model_cls(**kwargs)
         else:
             model = model_cls(cfg=model_cfg, **kwargs)
-    if pretrained:
+
+    if use_meta_init:
         # .to_empty() will also move cpu params/buffers to uninitialized storage.
         # this is problematic for non-persistent buffers, since they don't get loaded
         # from pretrained weights later (not part of state_dict). hence, we have
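For context, here is a minimal, self-contained sketch of the meta-device fast-init pattern this commit puts behind the use_meta_init flag. It assumes PyTorch >= 2.0 (where torch.device works as a context manager); the model and checkpoint below are stand-ins, not timm internals.

    import torch
    import torch.nn as nn
    from contextlib import nullcontext

    pretrained = True
    # Same feature check as the diff: torch.device is a context manager on >= 2.0.
    use_meta_init = pretrained and hasattr(torch.device("meta"), "__enter__")

    # Under the meta device, parameters get no real storage, so __init__
    # skips the random-init work that the checkpoint would overwrite anyway.
    with torch.device("meta") if use_meta_init else nullcontext():
        model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

    if use_meta_init:
        # Allocate real (uninitialized) CPU storage for the meta tensors, then
        # fill everything from the checkpoint. Every param and buffer must come
        # from the state_dict, since __init__ initialized nothing.
        model.to_empty(device="cpu")
        state_dict = {k: torch.randn_like(v) for k, v in model.state_dict().items()}  # stand-in checkpoint
        model.load_state_dict(state_dict)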
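The truncated comment at the end of the hunk points at the one sharp edge of this pattern: non-persistent buffers are excluded from state_dict, so after .to_empty() nothing ever overwrites their uninitialized storage. A toy illustration (hypothetical module, not timm code):

    import torch
    import torch.nn as nn

    class WithBuffer(nn.Module):
        def __init__(self):
            super().__init__()
            # persistent=False keeps this buffer out of state_dict
            self.register_buffer("mask", torch.ones(4), persistent=False)

    with torch.device("meta"):  # PyTorch >= 2.0
        m = WithBuffer()

    m.to_empty(device="cpu")
    print("mask" in m.state_dict())  # False: never saved, never loaded
    print(m.mask)  # uninitialized values; no checkpoint entry can repair them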
