@@ -400,6 +400,7 @@ def __init__(
400
400
self .feature_info = []
401
401
402
402
self .use_cls_token = use_cls_token
403
+ self .global_pool = 'token' if use_cls_token else 'avg'
403
404
404
405
dpr = [x .tolist () for x in torch .linspace (0 , drop_path_rate , sum (depths )).split (depths )]
405
406
@@ -448,6 +449,21 @@ def __init__(
448
449
self .head = nn .Linear (dims [- 1 ], num_classes ) if num_classes > 0 else nn .Identity ()
449
450
450
451
452
+
453
+ @torch .jit .ignore
454
+ def get_classifier (self ) -> nn .Module :
455
+ return self .head
456
+
457
+ def reset_classifier (self , num_classes : int , global_pool = None ) -> None :
458
+ self .num_classes = num_classes
459
+ if global_pool is not None :
460
+ assert global_pool in ('' , 'avg' , 'token' )
461
+ if global_pool == 'token' and not self .use_cls_token :
462
+ assert False , 'Model not configured to use class token'
463
+ self .global_pool = global_pool
464
+ self .head = nn .Linear (self .num_features , num_classes ) if num_classes > 0 else nn .Identity ()
465
+
466
+
451
467
def _forward_features (self , x : torch .Tensor ) -> torch .Tensor :
452
468
# nn.Sequential forward can't accept tuple intermediates
453
469
# TODO grad checkpointing
@@ -457,12 +473,13 @@ def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
457
473
return x
458
474
459
475
def forward_features (self , x : torch .Tensor ) -> torch .Tensor :
476
+ # get feature map, not always used
460
477
x = self ._forward_features (x )
461
478
462
479
return x [0 ] if self .use_cls_token else x
463
480
464
481
def forward_head (self , x : torch .Tensor ) -> torch .Tensor :
465
- if self .use_cls_token :
482
+ if self .global_pool == 'token' :
466
483
return self .head (self .norm (x [1 ].flatten (1 )))
467
484
else :
468
485
return self .head (self .norm (x .mean (dim = (2 ,3 ))))
0 commit comments