File tree 1 file changed +14
-3
lines changed
1 file changed +14
-3
lines changed Original file line number Diff line number Diff line change @@ -421,14 +421,25 @@ def build_model_with_cfg(
421
421
if 'feature_cls' in kwargs :
422
422
feature_cfg ['feature_cls' ] = kwargs .pop ('feature_cls' )
423
423
424
+ # use meta-device init to speed up loading pretrained weights.
425
+ # device context manager is only available for PyTorch>=2.0
426
+ # when num_classes is changed, we rely on __init__() logic to initialize head weights.
427
+ # thus, we can't use meta-device init in that case.
428
+ num_classes = 0 if features else kwargs .get ("num_classes" , pretrained_cfg ["num_classes" ])
429
+ use_meta_init = (
430
+ pretrained
431
+ and hasattr (torch .device ("meta" ), "__enter__" )
432
+ and (num_classes == 0 or num_classes == pretrained_cfg ["num_classes" ])
433
+ )
434
+
424
435
# Instantiate the model
425
- meta_device = torch .device ("meta" )
426
- with meta_device if hasattr (meta_device , "__enter__" ) and pretrained else nullcontext ():
436
+ with torch .device ("meta" ) if use_meta_init else nullcontext ():
427
437
if model_cfg is None :
428
438
model = model_cls (** kwargs )
429
439
else :
430
440
model = model_cls (cfg = model_cfg , ** kwargs )
431
- if pretrained :
441
+
442
+ if use_meta_init :
432
443
# .to_empty() will also move cpu params/buffers to uninitialized storage.
433
444
# this is problematic for non-persistent buffers, since they don't get loaded
434
445
# from pretrained weights later (not part of state_dict). hence, we have
You can’t perform that action at this time.
0 commit comments