swin v2 adding padding to shifted window attention breaks the algorithm #2438
cc @rwightman if you have a chance to look at this

@alita-moore hmm, yeah, might be a concern. Have you compared the results... force a situation where the padding is needed (it's not usually active) and then see how the accuracy compares with the padding applied before vs. after the shift?
I'll give it a try, but I don't think it would be a controlled experiment, because applying the padding before the shift is not what the model was trained on.

The performance is worse when applying the padding before the shift. I don't have a very scientific setup right now, though. I'll see what I can do about training with the padding applied before instead of after, and then compare the performance.
@alita-moore the models weren't trained with that padding. It won't be active unless you use resize inputs, set
sorry, I didn't clarify. I'm training a separate model which does not use the pre-trained weights. Our resolution is such that the padding is applied, so it affects our model. I'm currently running training with the padding added before the shift; I'll update once the model has trained for a few hours.

@alita-moore okay, thanks for clarifying, that makes more sense. It would also be possible to check with the normal pretrained weights; they'll still validate reasonably if the sizes that force the padding aren't too far off the original. Since the original wasn't trained with any padding, you'd get a signal of how it impacts accuracy... and should see one of the two orderings validate worse.
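For reference, a rough sketch of that kind of check, assuming a timm version whose SwinV2 models accept an img_size override at creation time; the model name and the 288 px size are illustrative choices that make the stage grids (72, 36, 18, 9) non-divisible by the 8x8 window, so the padding path is active:

```python
import torch
import timm

# Hypothetical setup: load pretrained weights at an input size whose stage
# resolutions are not divisible by the window size, so the padding branch
# in _attn() is exercised. Model name and img_size are illustrative choices.
model = timm.create_model(
    'swinv2_tiny_window8_256',
    pretrained=True,
    img_size=288,   # 288/4 = 72 -> stage grids 72, 36, 18, 9; 36, 18 and 9 need padding for window 8
)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 288, 288))
print(out.shape)  # torch.Size([1, 1000])

# From here, run a normal ImageNet validation (e.g. timm's validate.py with
# --img-size 288) once with the stock pad-after-shift code and once with the
# pad-before-shift patch, and compare top-1.
```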
we had to update how the padding reversals are applied as well; here's the current version we're using:

```python
def _attn(self, x: torch.Tensor) -> torch.Tensor:
    B, H, W, C = x.shape

    # pad to a multiple of the window size *before* the cyclic shift
    pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
    pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
    x = torch.nn.functional.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    _, Hp, Wp, _ = x.shape

    # cyclic shift
    has_shift = any(self.shift_size)
    if has_shift:
        shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
    else:
        shifted_x = x

    # partition windows
    x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    x_windows = x_windows.view(-1, self.window_area, C)  # nW*B, window_size*window_size, C

    # W-MSA/SW-MSA
    if getattr(self, 'dynamic_mask', False):
        attn_mask = self.get_attn_mask(shifted_x)
    else:
        attn_mask = self.attn_mask
    attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

    # merge windows
    attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
    shifted_x = window_reverse(attn_windows, self.window_size, (Hp, Wp))  # B H' W' C

    # reverse cyclic shift, then crop the padding back off
    if has_shift:
        shifted_x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
    x = shifted_x[:, :H, :W, :].contiguous()
    return x
```
after further investigation I've determined that the differences in performance we were seeing were mostly because I had initially not flipped the order in which the roll is reversed and the padding is sliced off. With the updated implementation there are small but largely insignificant changes in performance. We trained with the updated code and found that it does not perform significantly differently; we're still training, so this might change. Regardless, it's still probably worth fixing.
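For what it's worth, a tiny standalone sketch of that ordering point, with toy sizes assumed: the reverse roll has to happen before the crop back to (H, W), otherwise tokens that were wrapped onto the padded border are lost.

```python
import torch

# Round-trip check for the pad-before-shift ordering used in the snippet above:
# pad -> roll -> (attention would go here) -> reverse roll -> crop
# should give back exactly the tokens we started with. Sizes are arbitrary.
H, W, C = 7, 7, 1                 # 7 is not divisible by the window size
window, shift = (4, 4), (2, 2)

x = torch.arange(H * W, dtype=torch.float32).reshape(1, H, W, C)

pad_h = (window[0] - H % window[0]) % window[0]
pad_w = (window[1] - W % window[1]) % window[1]
xp = torch.nn.functional.pad(x, (0, 0, 0, pad_w, 0, pad_h))

rolled = torch.roll(xp, shifts=(-shift[0], -shift[1]), dims=(1, 2))

# Correct order: reverse the roll first, then crop away the padding.
recovered = torch.roll(rolled, shifts=shift, dims=(1, 2))[:, :H, :W, :]
assert torch.equal(recovered, x)

# Wrong order (crop, then reverse the roll): tokens that were rolled onto the
# padded border are thrown away and zeros end up inside the visible region.
broken = torch.roll(rolled[:, :H, :W, :], shifts=shift, dims=(1, 2))
assert not torch.equal(broken, x)
```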
shouldn't the attention mask be adjusted in either case? I feel ensuring the mask makes sense is more important than pad-before vs. pad-after shift, as both will alter the validity of the mask, no? If there is little to no measurable impact, then it shouldn't be fixed, as it breaks backward compat.
what do you mean exactly by attention mask in this case?
Here:
pytorch-image-models/timm/models/swin_transformer_v2.py, lines 380 to 383 in a49b020
pytorch-image-models/timm/models/swin_transformer_v2.py, line 292 in a49b020
The attention mask is generated based on the shift_size. However, the applied padding would offset these values such that the generated mask does not account for the shifted values. Meaning, patches are being included in the attention calculation when they should not be. Consider the following x:
After shifting the windows you get:

If the window size is 2, then it would apply padding like so:

Because the shifted window attention mask is calculated from x at this point, the calculated attn mask would only mask out the added padding tokens, not the shifted values. In this particular example, the shifted values do not attend to each other inappropriately, but in the case of a larger grid (e.g. 3x3) you would see cases where tokens such as 7 might attend to a token such as 1; see the sketch after this comment for a concrete instance.
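Since the grids from the example above aren't reproduced here, this is a small standalone sketch of the scenario described, under assumed sizes (a 5x5 token grid, 4x4 windows, shift 2), showing which original tokens end up sharing a window when the padding is applied after the shift:

```python
import torch

# Illustrative sketch (assumed sizes, not the grids from the original post):
# token id = 10*row + col in the original 5x5 grid, and -1 marks padding.
H = W = 5
window, shift = 4, 2

ids = torch.arange(H).view(H, 1) * 10 + torch.arange(W).view(1, W)
x = ids.view(1, H, W, 1)

shifted = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))                      # cyclic shift first
pad_h = (window - H % window) % window
pad_w = (window - W % window) % window
padded = torch.nn.functional.pad(shifted, (0, 0, 0, pad_w, 0, pad_h), value=-1)    # then pad

# Plain window partition into non-overlapping 4x4 windows
B, Hp, Wp, C = padded.shape
wins = (padded.view(B, Hp // window, window, Wp // window, window, C)
              .permute(0, 1, 3, 2, 4, 5)
              .reshape(-1, window, window))

print(wins[0])
# tensor([[22, 23, 24, 20],
#         [32, 33, 34, 30],
#         [42, 43, 44, 40],
#         [ 2,  3,  4,  0]])
# The first window mixes tokens from the right edge (col 4) with tokens wrapped
# around from the left edge (col 0), and bottom rows with the wrapped top row.
# A standard SW-MSA mask built for the padded 8x8 grid (region boundaries at
# -window and -shift of the padded size) puts all 16 of these positions in a
# single region, so nothing in it stops e.g. token 44 from attending to token 0.
```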