Commit c527c37

Optimizations for pos embed resize, merge different mask helper fns

1 parent ea728f6 commit c527c37

1 file changed: +66 -83 lines changed
timm/models/vision_transformer_flex.py

@@ -297,131 +297,109 @@ def _apply_learned_naflex_pos_embed(
             size_to_indices[k].append(bi)
 
         # Handle each batch element separately with its own grid size
+        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2)  # B,C,H,W
         for k, batch_indices in size_to_indices.items():
             h, w = k
             #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
             # Interpolate only once for this (h, w)
             if (h == orig_h) and (w == orig_w):
-                pos_embed_flat = self.pos_embed.reshape(orig_h * orig_w, -1)
+                pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
             else:
-                pos_embed_resized = F.interpolate(
-                    self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
+                pos_embed_flat = F.interpolate(
+                    pos_embed_nchw,
                     size=(h, w),
                     mode=self.pos_embed_interp_mode,
                     align_corners=False,
                     antialias=True,
-                )
-                pos_embed_flat = pos_embed_resized.permute(0, 2, 3, 1).reshape(h * w, -1)
+                ).flatten(2).transpose(1, 2)
 
-            seq_len = min(x.shape[1], pos_embed_flat.shape[0])
-            x[batch_indices, :seq_len].add_(pos_embed_flat[:seq_len])
+            seq_len = min(x.shape[1], pos_embed_flat.shape[1])
+            x[batch_indices, :seq_len].add_(pos_embed_flat[:, :seq_len])
 
     def _apply_learned_pos_embed(
             self,
             x: torch.Tensor,
             grid_size: List[int],
     ):
         orig_h, orig_w = self.pos_embed.shape[1:3]
-        if grid_size[0] != orig_h or grid_size[1] != orig_w:
+        if grid_size[0] == orig_h and grid_size[1] == orig_w:
+            # No resize needed, just flatten
+            pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
+        else:
             # Resize if needed - directly using F.interpolate
-            pos_embed = F.interpolate(
+            pos_embed_flat = F.interpolate(
                 self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
                 size=grid_size,
                 mode=self.pos_embed_interp_mode,
                 align_corners=False,
                 antialias=True,
-            )
-            # Convert back and flatten
-            pos_embed = pos_embed.permute(0, 2, 3, 1)
-            pos_embed = pos_embed.reshape(1, grid_size[0] * grid_size[1], -1)
-
-        else:
-            # No resize needed, just flatten
-            pos_embed = self.pos_embed.reshape(1, orig_h * orig_w, -1)
+            ).flatten(2).transpose(1, 2)
 
-        x.add_(pos_embed)
+        x.add_(pos_embed_flat)
 
 
 @register_notrace_function
 def create_attention_mask(
-        patch_valid: torch.Tensor,
-        num_prefix_tokens: int = 0,
-        dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
-    """Create attention mask from patch type information.
-
-    Used for NaFlex mode to handle variable token counts and padding tokens.
-
-    Args:
-        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
-        num_prefix_tokens: Number of prefix tokens (class token, register tokens)
-        dtype: Dtype of the attention mask
-
-    Returns:
-        Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
-        or None if patch_type is None
-    """
-    patch_valid = patch_valid.to(torch.bool)
-    B = patch_valid.shape[0]
-
-    if num_prefix_tokens > 0:
-        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
-        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
-
-    mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
-    mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
-
-    return mask_float
-
-
-@register_notrace_function
-def create_attention_mask2(
         patch_valid: torch.Tensor,
         num_prefix_tokens: int = 0,
+        symmetric: bool = True,
         q_len: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
-) -> Optional[torch.Tensor]:
-    """Create expanded attention mask from patch validity info.
+) -> torch.Tensor:
+    """Creates an attention mask from patch validity information.
+
+    Supports two modes controlled by `symmetric`:
+    1. `symmetric=True` (default): Creates a symmetric mask of shape
+       [B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if
+       both token i and token j are valid. Suitable for standard self-attention.
+    2. `symmetric=False`: Creates a potentially non-square mask of shape
+       [B, 1, q_len, kv_len]. An attention pair (q, k) is allowed only if
+       the key/value token k is valid. Query token validity is not checked
+       in the mask itself. Useful for cross-attention or specific self-attention
+       implementations where `q_len` can be specified.
 
     Used for NaFlex mode to handle variable token counts and padding tokens.
 
     Args:
-        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
+        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding.
         num_prefix_tokens: Number of prefix tokens (class token, register tokens)
-        q_len: Length override for query sequence
-        dtype: Dtype of the attention mask
+            to prepend, which are always considered valid.
+        symmetric: If True, create a symmetric mask.
+            If False, create an expanded mask based only on key/value validity.
+        q_len: Query sequence length override. Only used when `symmetric` is False.
+            Defaults to the key/value sequence length (`kv_len`) if None.
+        dtype: Dtype of the output attention mask (e.g., torch.float32).
 
     Returns:
-        Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
-        or None if patch_type is None
+        Attention mask tensor. Additive mask (-inf for masked, 0 for unmasked).
+        Shape is [B, 1, seq_len, seq_len] if symmetric=True,
+        or [B, 1, q_len, kv_len] if symmetric=False.
     """
-    patch_valid = patch_valid.bool()
-    B, kv_len = patch_valid.shape
+    patch_valid = patch_valid.bool()  # Ensure boolean type
+    B, N = patch_valid.shape
+    kv_len = N  # Initial key/value length is the number of patches
 
+    # Prepend prefix tokens if any
     if num_prefix_tokens > 0:
-        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
+        # Create prefix validity tensor on the same device/dtype base as patch_valid
+        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool)
+        # Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N]
         patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
-        kv_len = patch_valid.shape[1]
-
-    q_len = q_len if q_len is not None else kv_len
-
-    mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len).to(dtype)
-    mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
-
-    return mask_float
+        kv_len += num_prefix_tokens  # Update total key/value sequence length
 
+    if symmetric:
+        # Symmetric mask is True where BOTH query and key are valid
+        mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)
+        mask_bool = mask_bool.unsqueeze(1)  # Add head dimension: [B, 1, seq_len, seq_len]
+    else:
+        # Expanded mask
+        q_len = q_len or kv_len
+        mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len)
 
-@register_notrace_function
-def create_pool_mask(
-        patch_valid:torch.Tensor,
-        dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
-    patch_valid = patch_valid.bool()
-    mask_bool = patch_valid[:, None, None, :]
+    # Create the float mask and apply masking using additive mask convention
     mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
+    # Fill with negative infinity where mask_bool is False (masked positions)
+    mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min)
 
     return mask_float
 
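The pos embed changes above swap the permute(0, 2, 3, 1).reshape(...) flattening for .flatten(2).transpose(1, 2) applied to a cached NCHW tensor. A minimal standalone sketch (example sizes assumed, not taken from the diff) showing that both paths produce the same (1, H*W, C) layout:

# Sketch only, not part of the commit: compare old and new flattening of a 1xHxWxC pos embed.
import torch

pos_embed = torch.randn(1, 12, 16, 384)               # B,H,W,C (illustrative sizes)
nchw = pos_embed.permute(0, 3, 1, 2)                   # B,C,H,W, as cached in pos_embed_nchw
old = nchw.permute(0, 2, 3, 1).reshape(1, 12 * 16, -1)  # old permute + reshape path
new = nchw.flatten(2).transpose(1, 2)                  # new flatten + transpose path
assert torch.equal(old, new)                           # both yield (1, H*W, C) in the same order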
@@ -809,7 +787,12 @@ def _pool(
     ) -> torch.Tensor:
         if self.attn_pool is not None:
             # For attention pooling, we need to pass the mask for NaFlex models
-            attn_mask = create_pool_mask(patch_valid, dtype=x.dtype)
+            attn_mask = create_attention_mask(
+                patch_valid,
+                symmetric=False,
+                q_len=1,
+                dtype=x.dtype,
+            )
             x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
             return x
 
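With the helpers merged, the attention-pool path builds its mask via create_attention_mask(symmetric=False, q_len=1) instead of the removed create_pool_mask. A rough usage sketch, assuming the function is importable from this module as the diff defines it (batch contents are illustrative):

# Sketch only: shape of the pooling mask produced by the merged helper.
import torch
from timm.models.vision_transformer_flex import create_attention_mask  # assumed import path

patch_valid = torch.tensor([[True, True, False, False],
                            [True, True, True, True]])  # [B=2, N=4], padding in the first row
pool_mask = create_attention_mask(patch_valid, symmetric=False, q_len=1, dtype=torch.float32)
print(pool_mask.shape)  # torch.Size([2, 1, 1, 4]); padded keys hold torch.finfo(torch.float32).min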
@@ -839,7 +822,7 @@ def _pool(
 
         # For max pooling with mask
         masked_x = x.clone()
-        masked_x[~patch_valid] = -1e4  # torch.finfo(masked_x.dtype).min
+        masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
         masked_max = masked_x.max(dim=1)[0]
 
         # Combine average and max
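The max-pool branch now masks padding with the dtype's true minimum rather than -1e4, so a padded token can never win the max even when valid activations dip below -1e4. A small illustrative sketch (tensor sizes assumed, not from the diff):

# Sketch only: masked max pooling over a padded token dimension.
import torch

x = torch.randn(2, 4, 8)                         # [B, N, C], illustrative
patch_valid = torch.tensor([[True, True, False, False],
                            [True, True, True, True]])
masked_x = x.clone()
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min  # padding can never be the max
masked_max = masked_x.max(dim=1)[0]              # [B, C], max over valid tokens only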
@@ -876,9 +859,7 @@ def forward(
         Returns:
             Model output tensor
         """
-        if isinstance(x, torch.Tensor):
-            patches = x
-        else:
+        if isinstance(x, Dict):
             # Handle dictionary input from NaFlex collator
             patch_coord = x['patch_coord']
             patch_valid = x['patch_valid']
@@ -893,6 +874,8 @@ def forward(
             # patch = patch.reshape(3, h*16, w*16)
             # from torchvision.utils import save_image
             # save_image(patch, f'patch_{i}.jpg', normalize=True)
+        else:
+            patches = x
 
         # Create attention mask if patch_type is provided
         if patch_valid is not None: