Skip to content

Commit 6c31ec1

Browse files
committed
Update cvt.py
1 parent 7d68ef7 commit 6c31ec1

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

timm/models/cvt.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch.nn.functional as F
1818

1919
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
20-
from timm.layers import ConvNormAct, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn
20+
from timm.layers import ConvNormAct, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn, nchw_to
2121
from ._builder import build_model_with_cfg
2222
from ._registry import generate_default_cfgs, register_model
2323

@@ -447,17 +447,30 @@ def __init__(
447447
self.norm = norm_layer(dims[-1])
448448
self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()
449449

450-
def forward(self, x: torch.Tensor) -> torch.Tensor:
451450

451+
def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
452+
# nn.Sequential forward can't accept tuple intermediates
453+
# TODO grad checkpointing
452454
for stage in self.stages:
453455
x = stage(x)
454-
455456

457+
return x
458+
459+
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
460+
x = self._forward_features(x)
461+
462+
return x[0] if self.use_cls_token else x
463+
464+
def forward_head(self, x: torch.Tensor) -> torch.Tensor:
456465
if self.use_cls_token:
457466
return self.head(self.norm(x[1].flatten(1)))
458467
else:
459468
return self.head(self.norm(x.mean(dim=(2,3))))
460469

470+
def forward(self, x: torch.Tensor) -> torch.Tensor:
471+
x = self._forward_features(x)
472+
x = self.forward_head(x)
473+
return x
461474

462475

463476
def checkpoint_filter_fn(state_dict, model):

0 commit comments

Comments
 (0)