Commit b049a5c

Merge remote-tracking branch 'origin/master' into norm_norm_norm
2 parents: 61d3493 + 7cdd164

10 files changed: +1130 −37 lines

README.md

Lines changed: 8 additions & 0 deletions

@@ -23,6 +23,12 @@

 ## What's New

+### Feb 2, 2022
+* [Chris Hughes](https://github.com/Chris-hughes10) posted an exhaustive run-through of `timm` on his blog yesterday. Well worth a read. [Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055)
+* I'm currently prepping to merge the `norm_norm_norm` branch back to master (ver 0.6.x) in the next week or so.
+  * The changes are more extensive than usual and may destabilize and break some model API use (aiming for full backwards compat). So, beware `pip install git+https://github.com/rwightman/pytorch-image-models` installs!
+  * `0.5.x` releases and a `0.5.x` branch will remain stable, with a cherry pick or two, until the dust clears. Recommend sticking to the pypi install for a bit if you want stability.
+
 ### Jan 14, 2022
 * Version 0.5.4 w/ release to be pushed to pypi. It's been a while since the last pypi update, and riskier changes will be merged to the main branch soon...
 * Add ConvNeXt models w/ weights from the official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features

@@ -410,6 +416,8 @@

 My current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics.

+[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
+
 [timmdocs](https://fastai.github.io/timmdocs/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.

 [paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`.

docs/feature_extraction.md

Lines changed: 1 addition & 1 deletion

@@ -145,7 +145,7 @@ torch.Size([2, 1512, 7, 7])

 ### Select specific feature levels or limit the stride

-There are to additional creation arguments impacting the output features.
+There are two additional creation arguments impacting the output features.

 * `out_indices` selects which indices to output
 * `output_stride` limits the feature output stride of the network (also works in classification mode BTW)
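
A minimal sketch of how these two arguments combine, assuming the `resnet50` model name (any `features_only`-capable model works the same way):

```python
import timm
import torch

# Take features from levels 1-3 and cap the network stride at 8 (applied via dilation).
model = timm.create_model('resnet50', features_only=True, out_indices=(1, 2, 3), output_stride=8)

features = model(torch.randn(2, 3, 224, 224))
for f in features:
    print(f.shape)  # deeper levels are held at stride 8 instead of 16/32
```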

tests/test_models.py

Lines changed: 3 additions & 2 deletions

@@ -34,8 +34,9 @@

     EXCLUDE_FILTERS = [
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
-        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*']
-    NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*']
+        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
+        'swin*giant*']
+    NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*']
 else:
     EXCLUDE_FILTERS = []
     NON_STD_EXCLUDE_FILTERS = ['vit_gi*']
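
These entries are shell-style wildcards. A rough sketch of how such filters exclude the largest models from the test run (the candidate names below are illustrative; `timm.list_models` also accepts an `exclude_filters` argument):

```python
import fnmatch

EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*']
candidates = ['vit_base_patch16_224', 'vit_huge_patch14_224', 'swinv2_cr_giant_224']

# Keep only models that match none of the exclusion patterns.
kept = [m for m in candidates if not any(fnmatch.fnmatch(m, p) for p in EXCLUDE_FILTERS)]
print(kept)  # ['vit_base_patch16_224']
```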

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@

 from .senet import *
 from .sknet import *
 from .swin_transformer import *
+from .swin_transformer_v2_cr import *
 from .tnt import *
 from .tresnet import *
 from .twins import *
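
With this import in place, the new Swin Transformer V2 (CR) variants are registered with timm's model registry and become discoverable by name; the wildcard pattern below is an assumption about their naming, for illustration only:

```python
import timm

# List the newly registered Swin V2 (CR) variants (name pattern assumed).
print(timm.list_models('swin*v2_cr*'))
```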

timm/models/layers/ml_decoder.py

Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@

from typing import Optional

import torch
from torch import nn, Tensor
from torch.nn.modules.transformer import _get_activation_fn


def add_ml_decoder_head(model):
    if hasattr(model, 'global_pool') and hasattr(model, 'fc'):  # most CNN models, like Resnet50
        model.global_pool = nn.Identity()
        del model.fc
        num_classes = model.num_classes
        num_features = model.num_features
        model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
    elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'):  # EfficientNet
        model.global_pool = nn.Identity()
        del model.classifier
        num_classes = model.num_classes
        num_features = model.num_features
        model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
    elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name():  # hasattr(model, 'head')
        del model.head
        num_classes = model.num_classes
        num_features = model.num_features
        model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
    else:
        print("Model code is not currently aligned with ml-decoder")
        exit(-1)
    if hasattr(model, 'drop_rate'):  # ML-Decoder has inner dropout
        model.drop_rate = 0
    return model


class TransformerDecoderLayerOptimal(nn.Module):
    def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5) -> None:
        super(TransformerDecoderLayerOptimal, self).__init__()
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = torch.nn.functional.relu
        super(TransformerDecoderLayerOptimal, self).__setstate__(state)

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        # The usual decoder self-attention block is omitted; the queries only pass
        # through dropout + norm before cross-attending to the memory.
        tgt = tgt + self.dropout1(tgt)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt


@torch.jit.script
class GroupFC(object):
    def __init__(self, embed_len_decoder: int):
        self.embed_len_decoder = embed_len_decoder

    def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
        # Each group query i has its own projection w_i mapping the decoder
        # embedding to duplicate_factor class logits.
        for i in range(self.embed_len_decoder):
            h_i = h[:, i, :]
            w_i = duplicate_pooling[i, :, :]
            out_extrap[:, i, :] = torch.matmul(h_i, w_i)


class MLDecoder(nn.Module):
    def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
        super(MLDecoder, self).__init__()
        embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
        if embed_len_decoder > num_classes:
            embed_len_decoder = num_classes

        # switching to 768 initial embeddings
        decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
        self.embed_standart = nn.Linear(initial_num_features, decoder_embedding)

        # decoder
        decoder_dropout = 0.1
        num_layers_decoder = 1
        dim_feedforward = 2048
        layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding,
                                                      dim_feedforward=dim_feedforward, dropout=decoder_dropout)
        self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder)

        # non-learnable queries
        self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding)
        self.query_embed.requires_grad_(False)

        # group fully-connected
        self.num_classes = num_classes
        self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999)  # ~ceil(num_classes / embed_len_decoder)
        self.duplicate_pooling = torch.nn.Parameter(
            torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor))
        self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
        torch.nn.init.xavier_normal_(self.duplicate_pooling)
        torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
        self.group_fc = GroupFC(embed_len_decoder)

    def forward(self, x):
        if len(x.shape) == 4:  # CNN feature map, e.g. [bs, 2048, 7, 7]
            embedding_spatial = x.flatten(2).transpose(1, 2)
        else:  # token sequence, e.g. [bs, 197, 768]
            embedding_spatial = x
        embedding_spatial_768 = self.embed_standart(embedding_spatial)
        embedding_spatial_768 = torch.nn.functional.relu(embedding_spatial_768, inplace=True)

        bs = embedding_spatial_768.shape[0]
        query_embed = self.query_embed.weight
        # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
        tgt = query_embed.unsqueeze(1).expand(-1, bs, -1)  # no memory allocation with expand
        h = self.decoder(tgt, embedding_spatial_768.transpose(0, 1))  # [embed_len_decoder, batch, 768]
        h = h.transpose(0, 1)

        out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
        self.group_fc(h, self.duplicate_pooling, out_extrap)
        h_out = out_extrap.flatten(1)[:, :self.num_classes]
        h_out += self.duplicate_pooling_bias
        logits = h_out
        return logits
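
A minimal usage sketch of `add_ml_decoder_head`; the `resnet50` model name and 80-class count are illustrative, not prescribed by the commit:

```python
import timm
import torch
from timm.models.layers.ml_decoder import add_ml_decoder_head

# Replace the ResNet50 pooling + fc classifier with an ML-Decoder head.
model = timm.create_model('resnet50', num_classes=80)
model = add_ml_decoder_head(model)

# global_pool is now Identity, so the head receives the [bs, C, H, W] feature map.
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 80])
```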
