Commit 3591fe9
Merge branch 'main' into fast_load
2 parents: ad788be + 131518c

25 files changed: +949 -300 lines

README.md  +15 -2
@@ -12,6 +12,19 @@
 
 ## What's New
 
+## Dec 31, 2024
+* `convnext_nano` 384x384 ImageNet-12k pretrain & fine-tune. https://huggingface.co/models?search=convnext_nano%20r384
+* Add AIM-v2 encoders from https://github.com/apple/ml-aim, see on Hub: https://huggingface.co/models?search=timm%20aimv2
+* Add PaliGemma2 encoders from https://github.com/google-research/big_vision to existing PaliGemma, see on Hub: https://huggingface.co/models?search=timm%20pali2
+* Add missing L/14 DFN2B 39B CLIP ViT, `vit_large_patch14_clip_224.dfn2b_s39b`
+* Fix existing RmsNorm layer to match standard formulation, use PT 2.5 impl when possible. Move old impl to `SimpleNorm` layer, it's LN w/o centering or bias. There were only two `timm` models using it, and they have been updated.
+* Allow override of `cache_dir` arg for model creation
+* Pass through `trust_remote_code` for HF datasets wrapper
+* `inception_next_atto` model added by creator
+* Adan optimizer caution, and Lamb decoupled weight decay options
+* Some feature_info metadata fixed by https://github.com/brianhou0208
+* All OpenCLIP and JAX (CLIP, SigLIP, Pali, etc) model weights that used load time remapping were given their own HF Hub instances so that they work with `hf-hub:` based loading, and thus will work with new Transformers `TimmWrapperModel`
+
 ## Nov 28, 2024
 * More optimizers
 * Add MARS optimizer (https://arxiv.org/abs/2411.10438, https://github.com/AGI-Arena/MARS)
@@ -248,7 +261,7 @@ Add a set of new very well trained ResNet & ResNet-V2 18/34 (basic block) weight
 ### April 11, 2024
 * Prepping for a long overdue 1.0 release, things have been stable for a while now.
 * Significant feature that's been missing for a while, `features_only=True` support for ViT models with flat hidden states or non-std module layouts (so far covering `'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'`)
-* Above feature support achieved through a new `forward_intermediates()` API that can be used with a feature wrapping module or direclty.
+* Above feature support achieved through a new `forward_intermediates()` API that can be used with a feature wrapping module or directly.
 ```python
 model = timm.create_model('vit_base_patch16_224')
 final_feat, intermediates = model.forward_intermediates(input)
@@ -486,7 +499,7 @@ Included optimizers available via `timm.optim.create_optimizer_v2` factory method
 * `madgrad` an implementation of MADGRAD adapted from https://github.com/facebookresearch/madgrad - https://arxiv.org/abs/2101.11075
 * `mars` MARS optimizer from https://github.com/AGI-Arena/MARS - https://arxiv.org/abs/2411.10438
 * `nadam` an implementation of Adam w/ Nesterov momentum
-* `nadamw` an impementation of AdamW (Adam w/ decoupled weight-decay) w/ Nesterov momentum. A simplified impl based on https://github.com/mlcommons/algorithmic-efficiency
+* `nadamw` an implementation of AdamW (Adam w/ decoupled weight-decay) w/ Nesterov momentum. A simplified impl based on https://github.com/mlcommons/algorithmic-efficiency
 * `novograd` by [Masashi Kimura](https://github.com/convergence-lab/novograd) - https://arxiv.org/abs/1905.11286
 * `radam` by [Liyuan Liu](https://github.com/LiyuanLucasLiu/RAdam) - https://arxiv.org/abs/1908.03265
 * `rmsprop_tf` adapted from PyTorch RMSProp by myself. Reproduces much improved Tensorflow RMSProp behaviour

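As a side note on two of the README items above (the `cache_dir` override and the `forward_intermediates()` API), a minimal usage sketch might look like the following; the model name and cache path are illustrative and not part of this commit:

```python
import torch
import timm

# Create a pretrained model, overriding where Hub weights are cached
# (the cache_dir value here is just an example path).
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    cache_dir='/tmp/timm-cache',
).eval()

# forward_intermediates() returns the final features plus per-block intermediates.
x = torch.randn(1, 3, 224, 224)
final_feat, intermediates = model.forward_intermediates(x)
print(final_feat.shape, [t.shape for t in intermediates])
```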
UPGRADING.md  +3 -3

@@ -1,10 +1,10 @@
 # Upgrading from previous versions
 
-I generally try to maintain code interface and especially model weight compability across many `timm` versions. Sometimes there are exceptions.
+I generally try to maintain code interface and especially model weight compatibility across many `timm` versions. Sometimes there are exceptions.
 
 ## Checkpoint remapping
 
-Pretrained weight remapping is handled by `checkpoint_filter_fn` in a model implementation module. This remaps old pretrained checkpoints to new, and also 3rd party (original) checkpoints to `timm` format if the model was modified when brough into `timm`.
+Pretrained weight remapping is handled by `checkpoint_filter_fn` in a model implementation module. This remaps old pretrained checkpoints to new, and also 3rd party (original) checkpoints to `timm` format if the model was modified when brought into `timm`.
 
 The `checkpoint_filter_fn` is automatically called when loading pretrained weights via `pretrained=True`, but they can be called manually if you call the fn directly with the current model instance and old state dict.
 

@@ -19,6 +19,6 @@ Many changes were made since the 0.6.x stable releases. They were previewed in 0
 * The pretrained_tag is the specific weight variant (different head) for the architecture.
 * Using only `architecture` defaults to the first weights in the default_cfgs for that model architecture.
 * In adding pretrained tags, many model names that existed to differentiate were renamed to use the tag (ex: `vit_base_patch16_224_in21k` -> `vit_base_patch16_224.augreg_in21k`). There are deprecation mappings for these.
-* A number of models had their checkpoints remaped to match architecture changes needed to better support `features_only=True`, there are `checkpoint_filter_fn` methods in any model module that was remapped. These can be passed to `timm.models.load_checkpoint(..., filter_fn=timm.models.swin_transformer_v2.checkpoint_filter_fn)` to remap your existing checkpoint.
+* A number of models had their checkpoints remapped to match architecture changes needed to better support `features_only=True`, there are `checkpoint_filter_fn` methods in any model module that was remapped. These can be passed to `timm.models.load_checkpoint(..., filter_fn=timm.models.swin_transformer_v2.checkpoint_filter_fn)` to remap your existing checkpoint.
 * The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for `timm` weights. Model cards include link to papers, original source, license.
 * Previous 0.6.x can be cloned from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch or installed via pip with version.

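For reference, the manual remapping path described in UPGRADING.md can be exercised roughly like this; the checkpoint path and model choice are illustrative:

```python
import timm
from timm.models import load_checkpoint
from timm.models.swin_transformer_v2 import checkpoint_filter_fn

# Build the current architecture, then load an old/original checkpoint,
# letting the module's checkpoint_filter_fn remap keys to the new layout.
model = timm.create_model('swinv2_tiny_window8_256', pretrained=False)
load_checkpoint(model, 'old_swinv2_checkpoint.pth', filter_fn=checkpoint_filter_fn)
```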
hfdocs/source/quickstart.mdx  +8 -8

@@ -164,14 +164,14 @@ First we'll need an image to do inference on. Here we load a picture of a leaf f
 >>> import requests
 >>> from PIL import Image
 >>> from io import BytesIO
->>> url = 'https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg'
+>>> url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg'
 >>> image = Image.open(requests.get(url, stream=True).raw)
 >>> image
 ```
 
 Here's the image we loaded:
 
-<img src="https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg" alt="An Image from a link" width="300"/>
+<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg" alt="An Image from a link" width="300"/>
 
 Now, we'll create our model and transforms again. This time, we make sure to set our model in evaluation mode.
 

@@ -211,7 +211,7 @@ Now we'll find the top 5 predicted class indexes and values using `torch.topk`.
 ```py
 >>> values, indices = torch.topk(probabilities, 5)
 >>> indices
-tensor([162, 166, 161, 164, 167])
+tensor([281, 282, 285, 673, 670])
 ```
 
 If we check the imagenet labels for the top index, we can see what the model predicted...

@@ -220,9 +220,9 @@ If we check the imagenet labels for the top index, we can see what the model pre
 >>> IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
 >>> IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
 >>> [{'label': IMAGENET_1k_LABELS[idx], 'value': val.item()} for val, idx in zip(values, indices)]
-[{'label': 'beagle', 'value': 0.8486220836639404},
- {'label': 'Walker_hound, Walker_foxhound', 'value': 0.03753996267914772},
- {'label': 'basset, basset_hound', 'value': 0.024628572165966034},
- {'label': 'bluetick', 'value': 0.010317106731235981},
- {'label': 'English_foxhound', 'value': 0.006958036217838526}]
+[{'label': 'tabby, tabby_cat', 'value': 0.5101025700569153},
+ {'label': 'tiger_cat', 'value': 0.22490699589252472},
+ {'label': 'Egyptian_cat', 'value': 0.1835290789604187},
+ {'label': 'mouse, computer_mouse', 'value': 0.006752475164830685},
+ {'label': 'motor_scooter, scooter', 'value': 0.004942195490002632}]
 ```

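For orientation, the surrounding quickstart flow that produces `probabilities` is only partially visible in these hunks; a condensed sketch following the quickstart's earlier steps (model name and transform helpers assumed from that context) looks roughly like this:

```python
import requests
import timm
import torch
from PIL import Image

url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# Pretrained classifier in eval mode, with its matching preprocessing transform.
model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()
config = timm.data.resolve_data_config({}, model=model)
transform = timm.data.create_transform(**config)

with torch.no_grad():
    out = model(transform(image).unsqueeze(0))
probabilities = torch.nn.functional.softmax(out[0], dim=0)
values, indices = torch.topk(probabilities, 5)
```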
tests/test_layers.py  +40 -1

@@ -1,7 +1,8 @@
+import pytest
 import torch
 import torch.nn as nn
 
-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d, MultiQueryAttentionV2
 
 import importlib
 import os

@@ -119,3 +120,41 @@ def test_get_act_fn_none():
     assert get_act_fn(None) is None
     assert get_act_fn('') is None
 
+
+@pytest.mark.parametrize("dim", [128])
+@pytest.mark.parametrize("dim_out", [128, 256])
+@pytest.mark.parametrize("use_m", [True, False])
+def test_mqa_v2(dim, dim_out, use_m):
+    mqa = MultiQueryAttentionV2(dim, dim_out)
+
+    x = torch.randn(1, dim, 32, 48)
+    if use_m:
+        m = torch.randn(1, dim, 16, 24)
+    else:
+        m = None
+
+    y = mqa(x, m=m)
+
+    assert (y.shape) == (1, dim_out, 32, 48)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("expand_first", [True, False])
+@pytest.mark.parametrize("head_first", [True, False])
+@pytest.mark.parametrize("attn_mask", [True, False])
+def test_attn2d(bias, expand_first, head_first, attn_mask):
+    x = torch.randn(1, 128, 32, 48)
+    attn = Attention2d(
+        128, 128, num_heads=4, bias=bias, expand_first=expand_first, head_first=head_first
+    )
+
+    if attn_mask:
+        mask = torch.randint(0, 1, size=(32 * 48, 32 * 48), dtype=torch.float32)
+    else:
+        mask = None
+
+    o1 = attn(x, mask)
+    attn.fused_attn = False
+    o2 = attn(x, mask)
+
+    assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"

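One point worth spelling out (not part of the test itself): the mask accepted by `Attention2d` is an additive float mask, per the "assumes mask is float and in correct shape" note in the layer. A realistic mask would typically be built as zeros with `-inf` at blocked positions, e.g.:

```python
import torch

# Additive float attention mask of shape (L, L): 0 where attention is allowed,
# -inf where it should be blocked. `keep` here is a placeholder boolean pattern.
L = 32 * 48
keep = torch.ones(L, L, dtype=torch.bool)
mask = torch.zeros(L, L, dtype=torch.float32).masked_fill(~keep, float('-inf'))
```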
tests/test_models.py  +5 -5

@@ -53,13 +53,13 @@
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
     'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2'
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*'
 ]
 
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
-    'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
+    'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
     'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
 ]

@@ -72,11 +72,11 @@
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
         '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', '*huge*', '*giant*', '*gigantic*',
-        '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560']
-    NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*']
+        '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560', '*_1b_*', '*_3b_*']
+    NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*', '*_1b_*', '*_3b_*']
 else:
     EXCLUDE_FILTERS = ['*enormous*']
-    NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*']
+    NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*']
 
 EXCLUDE_JIT_FILTERS = ['hiera_*']
 

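The `'*_1b_*'` / `'*_3b_*'` patterns above are fnmatch-style globs; outside the test suite, the same style of filtering can be reproduced with `timm.list_models` (a small sketch, not part of the commit):

```python
import timm

# List aimv2 model names while skipping the large 1B/3B variants,
# mirroring the exclude filters used by the tests above.
names = timm.list_models('aimv2*', exclude_filters=['*_1b_*', '*_3b_*'])
print(names[:5])
```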
timm/layers/__init__.py  +1 -1

@@ -34,7 +34,7 @@
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same

timm/layers/attention2d.py  +9 -9

@@ -59,24 +59,24 @@ def _reshape_input(self, t):
 
     def forward(self, x, m: Optional[torch.Tensor] = None):
         """Run layer computation."""
-        s = x.shape
-        m = m or x
+        b, _, h, w = x.shape
+        m = m if m is not None else x
 
         reshaped_x = self._reshape_input(x)
         reshaped_m = self._reshape_input(m)
 
         q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
         k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)
 
-        attn = torch.einsum('bnhk,bmk->bnhm', q, k)
+        attn = torch.einsum('bnhk,bmk->bnhm', q, k) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
         v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
         o = torch.einsum('bnhm,bmv->bnhv', attn, v)
-        result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj)
+        result = torch.einsum('bnhv,dhv->bdn', o, self.out_proj)
         result = self.proj_drop(result)
-        return result.reshape(s)
+        return result.reshape(b, -1, h, w)
 
 
 class MultiQueryAttention2d(nn.Module):

@@ -312,7 +312,6 @@ def __init__(
         self.num_heads = num_heads
         self.dim_head = dim_attn // num_heads
         self.head_first = head_first
-        self.scale = num_heads ** -0.5
         self.fused_attn = use_fused_attn()
 
         self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)

@@ -337,14 +336,15 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
                 dropout_p=self.attn_drop.p if self.training else 0.,
             ).transpose(-1, -2).reshape(B, -1, H, W)
         else:
-            q = q * self.scale
-            attn = q.transpose(-2, -1) @ k
+            q = q.transpose(-1, -2)
+            v = v.transpose(-1, -2)
+            attn = q @ k * q.size(-1) ** -0.5
             if attn_mask is not None:
                 # NOTE: assumes mask is float and in correct shape
                 attn = attn + attn_mask
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
-            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
+            x = (attn @ v).transpose(-1, -2).reshape(B, -1, H, W)
 
         x = self.proj(x)
         x = self.proj_drop(x)

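To make the reworked unfused branch easier to follow, here is a standalone sketch (not code from this commit): with channels-first `(B, heads, dim_head, N)` tensors, transposing `q` and `v` and scaling by `dim_head ** -0.5` reproduces what `F.scaled_dot_product_attention` computes in the fused branch:

```python
import torch
import torch.nn.functional as F

B, H, D, N = 2, 4, 32, 49  # batch, heads, head dim, tokens (H*W)
q = torch.randn(B, H, D, N)
k = torch.randn(B, H, D, N)
v = torch.randn(B, H, D, N)

# Unfused path: transpose to (B, H, N, D), scale, softmax over keys, apply to values.
qt, vt = q.transpose(-1, -2), v.transpose(-1, -2)
attn = (qt @ k * qt.size(-1) ** -0.5).softmax(dim=-1)
manual = (attn @ vt).transpose(-1, -2)

# Fused path: SDPA with all tensors in (B, H, N, D) layout.
fused = F.scaled_dot_product_attention(qt, k.transpose(-1, -2), vt).transpose(-1, -2)
print(torch.allclose(manual, fused, atol=1e-5))  # matches up to numeric tolerance
```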
timm/layers/create_norm.py  +3 -1

@@ -10,7 +10,7 @@
 
 import torch.nn as nn
 
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
 from torchvision.ops.misc import FrozenBatchNorm2d
 
 _NORM_MAP = dict(

@@ -23,6 +23,8 @@
     layernorm2d=LayerNorm2d,
     rmsnorm=RmsNorm,
     rmsnorm2d=RmsNorm2d,
+    simplenorm=SimpleNorm,
+    simplenorm2d=SimpleNorm2d,
     frozenbatchnorm2d=FrozenBatchNorm2d,
 )
 _NORM_TYPES = {m for n, m in _NORM_MAP.items()}

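With the two new registry entries, the simple-norm layers become reachable by name through this module's factory helpers (a brief sketch; assumes the `create_norm_layer` factory defined alongside `_NORM_MAP`):

```python
import torch
from timm.layers import create_norm_layer

# Request the new 2d variant by its registry key; num_features sized for a conv feature map.
norm = create_norm_layer('simplenorm2d', 64)
y = norm(torch.randn(2, 64, 7, 7))
```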
timm/layers/fast_norm.py  +62 -6

@@ -24,6 +24,8 @@
 has_apex_rmsnorm = False
 
 
+has_torch_rms_norm = hasattr(F, 'rms_norm')
+
 # fast (ie lower precision LN) can be disabled with this flag if issues crop up
 _USE_FAST_NORM = False  # defaulting to False for now
 

@@ -75,7 +77,6 @@ def fast_group_norm(
     if is_autocast_enabled(x.device.type):
         # normally native AMP casts GN inputs to float32
         # here we use the low precision autocast dtype
-        # FIXME what to do re CPU autocast?
         dt = get_autocast_dtype(x.device.type)
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 

@@ -101,7 +102,6 @@ def fast_layer_norm(
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
         dt = get_autocast_dtype(x.device.type)
-        # FIXME what to do re CPU autocast?
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
     with torch.amp.autocast(device_type=x.device.type, enabled=False):

@@ -115,15 +115,16 @@ def rms_norm(
     eps: float = 1e-5,
 ):
     norm_ndim = len(normalized_shape)
+    v = x.pow(2)
     if torch.jit.is_scripting():
         # ndim = len(x.shape)
         # dims = list(range(ndim - norm_ndim, ndim))  # this doesn't work on pytorch <= 1.13.x
         # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
         assert norm_ndim == 1
-        v = torch.var(x, dim=-1).unsqueeze(-1)  # ts crashes with -ve dim + keepdim=True
+        v = torch.mean(v, dim=-1).unsqueeze(-1)  # ts crashes with -ve dim + keepdim=True
     else:
         dims = tuple(range(-1, -norm_ndim - 1, -1))
-        v = torch.var(x, dim=dims, keepdim=True)
+        v = torch.mean(v, dim=dims, keepdim=True)
     x = x * torch.rsqrt(v + eps)
     if weight is not None:
         x = x * weight

@@ -146,5 +147,60 @@ def fast_rms_norm(
         else:
             return fused_rms_norm_affine(x, weight, normalized_shape, eps)
 
-    # fallback
-    return rms_norm(x, normalized_shape, weight, eps)
+    if is_autocast_enabled(x.device.type):
+        # normally native AMP casts LN inputs to float32
+        # apex LN does not, this is behaving like Apex
+        dt = get_autocast_dtype(x.device.type)
+        x, weight = x.to(dt), weight.to(dt)
+
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
+        if has_torch_rms_norm:
+            x = F.rms_norm(x, normalized_shape, weight, eps)
+        else:
+            x = rms_norm(x, normalized_shape, weight, eps)
+
+    return x
+
+
+def simple_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+):
+    norm_ndim = len(normalized_shape)
+    if torch.jit.is_scripting():
+        # ndim = len(x.shape)
+        # dims = list(range(ndim - norm_ndim, ndim))  # this doesn't work on pytorch <= 1.13.x
+        # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
+        assert norm_ndim == 1
+        v = torch.var(x, dim=-1).unsqueeze(-1)  # ts crashes with -ve dim + keepdim=True
+    else:
+        dims = tuple(range(-1, -norm_ndim - 1, -1))
+        v = torch.var(x, dim=dims, keepdim=True)
+    x = x * torch.rsqrt(v + eps)
+    if weight is not None:
+        x = x * weight
+    return x
+
+
+def fast_simple_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # this must be by itself, cannot merge with has_apex_rmsnorm
+        return simple_norm(x, normalized_shape, weight, eps)
+
+    if is_autocast_enabled(x.device.type):
+        # normally native AMP casts LN inputs to float32
+        # apex LN does not, this is behaving like Apex
+        dt = get_autocast_dtype(x.device.type)
+        x, weight = x.to(dt), weight.to(dt)
+
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
+        x = simple_norm(x, normalized_shape, weight, eps)
+    return x
+
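The functional difference between the fixed `rms_norm` and the preserved-old-behaviour `simple_norm` comes down to mean-of-squares vs. variance (a standalone illustration, not part of the commit):

```python
import torch

x = torch.randn(2, 8, 16)
eps = 1e-5

# Fixed RMSNorm: normalize by the root mean square of x over the last dim.
rms = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

# SimpleNorm (the old RmsNorm behaviour): divide by the std (variance is computed
# around the mean), but never subtract the mean from x itself, i.e. LN without
# centering or bias.
simple = x * torch.rsqrt(x.var(dim=-1, keepdim=True) + eps)

# The two only coincide when x is zero-mean along the normalized dimension.
print(torch.allclose(rms, simple, atol=1e-4))  # generally False
```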