|
17 | 17 | import torch.nn.functional as F
|
18 | 18 |
|
19 | 19 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
20 |
| -from timm.layers import ConvNormAct, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn |
| 20 | +from timm.layers import ConvNormAct, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn, nchw_to |
21 | 21 | from ._builder import build_model_with_cfg
|
22 | 22 | from ._registry import generate_default_cfgs, register_model
|
23 | 23 |
|
@@ -447,17 +447,30 @@ def __init__(
|
447 | 447 | self.norm = norm_layer(dims[-1])
|
448 | 448 | self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()
|
449 | 449 |
|
450 |
| - def forward(self, x: torch.Tensor) -> torch.Tensor: |
451 | 450 |
|
def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
    """Feed ``x`` through each entry of ``self.stages`` in order.

    Returns whatever the last stage returns, with no post-processing.
    NOTE(review): other methods index this result with ``[0]`` / ``[1]``
    when ``self.use_cls_token`` is set, so the value is presumably a tuple
    in that configuration rather than a single tensor -- confirm against
    the stage implementation.
    """
    # An explicit loop is used instead of nn.Sequential because
    # nn.Sequential's forward can't accept tuple intermediates.
    # TODO grad checkpointing
    out = x
    for stage_module in self.stages:
        out = stage_module(out)
    return out
| 458 | + |
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
    """Return the output of the stage stack for ``x``.

    When ``self.use_cls_token`` is truthy, the stage output is indexed and
    only element 0 is returned; otherwise the stage output is returned
    unchanged.
    """
    stage_out = self._forward_features(x)
    if self.use_cls_token:
        # Element 1 is consumed by forward_head, not exposed here.
        return stage_out[0]
    return stage_out
| 463 | + |
def forward_head(self, x: torch.Tensor) -> torch.Tensor:
    """Reduce the stage output to per-sample vectors, then apply norm and head.

    With ``self.use_cls_token`` set, ``x`` is indexed and element 1 is
    flattened from dim 1 onward. Otherwise ``x`` is averaged over dims 2
    and 3 (assumes a 4-D tensor with spatial dims last -- TODO confirm).
    The reduced value is passed through ``self.norm`` and then ``self.head``.

    NOTE(review): in the cls-token configuration this expects the indexable
    stage output, not the single element returned by ``forward_features``.
    """
    if not self.use_cls_token:
        reduced = x.mean(dim=(2, 3))
    else:
        reduced = x[1].flatten(1)
    return self.head(self.norm(reduced))
|
460 | 469 |
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the full model: stage stack followed by the head.

    Uses ``_forward_features`` (the unindexed stage output) rather than
    ``forward_features`` so that ``forward_head`` receives everything the
    stages produced.
    """
    return self.forward_head(self._forward_features(x))
461 | 474 |
|
462 | 475 |
|
463 | 476 | def checkpoint_filter_fn(state_dict, model):
|
|
0 commit comments