Skip to content

Commit 4d1b21a

Browse files
committed
Update cvt.py
1 parent 2975318 commit 4d1b21a

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

timm/models/cvt.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ def __init__(
400400
self.feature_info = []
401401

402402
self.use_cls_token = use_cls_token
403+
self.global_pool = 'token' if use_cls_token else 'avg'
403404

404405
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
405406

@@ -448,6 +449,21 @@ def __init__(
448449
self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()
449450

450451

452+
453+
@torch.jit.ignore
454+
def get_classifier(self) -> nn.Module:
455+
return self.head
456+
457+
def reset_classifier(self, num_classes: int, global_pool = None) -> None:
458+
self.num_classes = num_classes
459+
if global_pool is not None:
460+
assert global_pool in ('', 'avg', 'token')
461+
if global_pool == 'token' and not self.use_cls_token:
462+
assert False, 'Model not configured to use class token'
463+
self.global_pool = global_pool
464+
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
465+
466+
451467
def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
452468
# nn.Sequential forward can't accept tuple intermediates
453469
# TODO grad checkpointing
@@ -457,12 +473,13 @@ def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
457473
return x
458474

459475
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
476+
# get feature map, not always used
460477
x = self._forward_features(x)
461478

462479
return x[0] if self.use_cls_token else x
463480

464481
def forward_head(self, x: torch.Tensor) -> torch.Tensor:
465-
if self.use_cls_token:
482+
if self.global_pool == 'token':
466483
return self.head(self.norm(x[1].flatten(1)))
467484
else:
468485
return self.head(self.norm(x.mean(dim=(2,3))))

0 commit comments

Comments
 (0)