
swin v2 adding padding to shifted window attention breaks the algorithm #2438

alita-moore opened this issue Feb 11, 2025 · 11 comments

alita-moore commented Feb 11, 2025

Here:

pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
_, Hp, Wp, _ = shifted_x.shape
the shifted window attention applies padding after shifting the values.
def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
assumes that the input has already been rolled by shift_size. However, the padding applied after the roll offsets those values, so the generated mask no longer lines up with the shifted values. That means patches are included in the attention calculation when they should not be.

consider the following x

x = [[1,2,3],[4,5,6],[7,8,9]]

after shifting the windows you get

x = [[2,3,1],[5,6,4],[8,9,7]]

if the window size is 2, padding is applied like so

x = [[2,3,1,0],[5,6,4,0],[8,9,7,0],[0,0,0,0]]

Because the shifted window attention mask is calculated from x at this point, the calculated attn mask only masks out the added padding tokens, not the shifted values. In this particular example the shifted values do not attend to each other inappropriately, but in the case of a larger grid (e.g. 3x3) you would see cases where a token such as 7 attends to a token such as 1.
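For concreteness, here is a small runnable sketch of the example above using torch.roll and F.pad directly. It only shifts along W to mirror the example (the real block shifts along both H and W), and none of the names here are timm's API:

import torch
import torch.nn.functional as F

x = torch.arange(1, 10).reshape(1, 3, 3, 1)  # B, H, W, C holding the values 1..9

# current order in timm: roll first, then pad
shift_then_pad = F.pad(torch.roll(x, shifts=-1, dims=2), (0, 0, 0, 1, 0, 1))
print(shift_then_pad[0, :, :, 0])
# tensor([[2, 3, 1, 0],
#         [5, 6, 4, 0],
#         [8, 9, 7, 0],
#         [0, 0, 0, 0]])
# The wrapped column (1, 4, 7) sits at column index 2, but a mask built for the
# padded (4-wide) map assumes the wrapped content sits in the last column.

# proposed order: pad first, then roll
pad_then_shift = torch.roll(F.pad(x, (0, 0, 0, 1, 0, 1)), shifts=-1, dims=2)
print(pad_then_shift[0, :, :, 0])
# tensor([[2, 3, 0, 1],
#         [5, 6, 0, 4],
#         [8, 9, 0, 7],
#         [0, 0, 0, 0]])
# Here the wrapped column ends up where the shift mask expects it to be.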

@alita-moore (Author)

cc @rwightman if you have a chance to look at this

@rwightman (Collaborator)

@alita-moore hmm, yeah, might be a concern. Have you compared the results... force a situation where the padding is needed (it's not usually active) and then see how the accuracy compares with the padding before vs after the shift?

@alita-moore (Author)

I'll give it a try, but I don't think it's a controlled experiment, because applying the padding before the shift is not what the model was trained on.

@alita-moore (Author)

The performance is worse when applying the padding before. I don't have a very scientific setup right now, though. I'll see what I can do about training with the padding applied before the shift instead of after, and then compare the performance.

@rwightman (Collaborator)

@alita-moore the models weren't trained with that padding. It won't be active unless you resize inputs, set strict_img_size=False, always_partition=True, etc. These are non-standard settings that allow flexibility for some applications, but they were not used for the default training sizes; those had windows that lined up with the feature map sizes.
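For illustration, a hypothetical way to force the padding path on a pretrained SwinV2. This assumes timm.create_model forwards img_size, strict_img_size, and always_partition through to the SwinTransformerV2 constructor; treat the model name, sizes, and kwarg forwarding as assumptions rather than a confirmed recipe:

import timm
import torch

model = timm.create_model(
    'swinv2_tiny_window8_256',
    pretrained=True,           # optional; only needed to test against pretrained weights
    img_size=252,              # 252 / patch_size 4 = 63, not a multiple of window 8, so padding kicks in
    strict_img_size=False,     # allow an input size other than the pretrained 256
    always_partition=True,     # keep shifted-window partitioning even when windows don't tile evenly
)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 252, 252))
print(out.shape)  # e.g. torch.Size([1, 1000]) for an ImageNet-1k head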

@alita-moore (Author)

Sorry, I didn't clarify: I'm training a separate model that does not use the pretrained weights. Our resolution is such that the padding is applied, so it affects our model. I'm currently running training with the padding added before the shift; I'll update once we have the model trained for a few hours.

@rwightman (Collaborator)

@alita-moore okay, thanks for clarifying, that makes more sense. It would also be possible to check with the normal pretrained weights; they'll still validate reasonably if the sizes that force the padding aren't too far off the original. Since the original wasn't trained with any padding, you'd get a signal on how it impacts accuracy... and should see one variant validating worse.


alita-moore commented Feb 11, 2025

We had to update how the padding reversal is applied as well; here's the current version we're using:

def _attn(self, x: torch.Tensor) -> torch.Tensor:
    B, H, W, C = x.shape

    # pad to a multiple of the window size *before* the cyclic shift
    pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
    pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
    x = torch.nn.functional.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    _, Hp, Wp, _ = x.shape

    # cyclic shift
    has_shift = any(self.shift_size)
    if has_shift:
        shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
    else:
        shifted_x = x

    # partition windows
    x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    x_windows = x_windows.view(-1, self.window_area, C)  # nW*B, window_size*window_size, C

    # W-MSA/SW-MSA
    if getattr(self, 'dynamic_mask', False):
        attn_mask = self.get_attn_mask(shifted_x)
    else:
        attn_mask = self.attn_mask
    attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

    # merge windows
    attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
    shifted_x = window_reverse(attn_windows, self.window_size, (Hp, Wp))  # B Hp Wp C

    # reverse the cyclic shift first, then crop off the padding
    if has_shift:
        shifted_x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
    x = shifted_x[:, :H, :W, :].contiguous()

    return x

@alita-moore (Author)

After further investigation I've determined that the performance differences we were seeing were primarily due to the fact that I had initially not flipped the order in which the roll is reversed and the padding is cropped off (see the sketch at the end of this comment).

With the updated implementation there are small but largely insignificant changes in performance. We trained with the updated code and found that it does not perform significantly differently; we're still training, so this might change.

Regardless, it's still probably worth fixing.
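A minimal sketch of why that reversal order matters once the padding precedes the shift (illustrative values only, shifting along W as in the earlier example; this mirrors the end of the _attn method above rather than quoting it):

import torch
import torch.nn.functional as F

H = W = 3
x = torch.arange(1, 10).reshape(1, H, W, 1)
padded_shifted = torch.roll(F.pad(x, (0, 0, 0, 1, 0, 1)), shifts=-1, dims=2)

# correct: reverse the roll first, then crop off the padding
good = torch.roll(padded_shifted, shifts=1, dims=2)[:, :H, :W, :]
print(torch.equal(good, x))  # True

# wrong: crop first, then reverse the roll -> the wrapped column (1, 4, 7) is lost
bad = torch.roll(padded_shifted[:, :H, :W, :], shifts=1, dims=2)
print(torch.equal(bad, x))  # False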


rwightman commented Feb 12, 2025

Shouldn't the attention mask be adjusted in either case? I feel ensuring the mask makes sense is more important than padding before vs after the shift, as both will alter the validity of the mask, no?

If there is little to no measurable impact, then it shouldn't be fixed, as it breaks backward compat.

@alita-moore (Author)

what do you mean exactly by attention mask in this case?
