@@ -297,131 +297,109 @@ def _apply_learned_naflex_pos_embed(
             size_to_indices[k].append(bi)
 
         # Handle each batch element separately with its own grid size
+        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2)  # B,C,H,W
         for k, batch_indices in size_to_indices.items():
             h, w = k
             #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
             # Interpolate only once for this (h, w)
             if (h == orig_h) and (w == orig_w):
-                pos_embed_flat = self.pos_embed.reshape(orig_h * orig_w, -1)
+                pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
             else:
-                pos_embed_resized = F.interpolate(
-                    self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
+                pos_embed_flat = F.interpolate(
+                    pos_embed_nchw,
                     size=(h, w),
                     mode=self.pos_embed_interp_mode,
                     align_corners=False,
                     antialias=True,
-                )
-                pos_embed_flat = pos_embed_resized.permute(0, 2, 3, 1).reshape(h * w, -1)
+                ).flatten(2).transpose(1, 2)
 
-            seq_len = min(x.shape[1], pos_embed_flat.shape[0])
-            x[batch_indices, :seq_len].add_(pos_embed_flat[:seq_len])
+            seq_len = min(x.shape[1], pos_embed_flat.shape[1])
+            x[batch_indices, :seq_len].add_(pos_embed_flat[:, :seq_len])
 
     def _apply_learned_pos_embed(
             self,
             x: torch.Tensor,
             grid_size: List[int],
     ):
         orig_h, orig_w = self.pos_embed.shape[1:3]
-        if grid_size[0] != orig_h or grid_size[1] != orig_w:
+        if grid_size[0] == orig_h and grid_size[1] == orig_w:
+            # No resize needed, just flatten
+            pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
+        else:
             # Resize if needed - directly using F.interpolate
-            pos_embed = F.interpolate(
+            pos_embed_flat = F.interpolate(
                 self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
                 size=grid_size,
                 mode=self.pos_embed_interp_mode,
                 align_corners=False,
                 antialias=True,
-            )
-            # Convert back and flatten
-            pos_embed = pos_embed.permute(0, 2, 3, 1)
-            pos_embed = pos_embed.reshape(1, grid_size[0] * grid_size[1], -1)
-
-        else:
-            # No resize needed, just flatten
-            pos_embed = self.pos_embed.reshape(1, orig_h * orig_w, -1)
+            ).flatten(2).transpose(1, 2)
 
-        x.add_(pos_embed)
+        x.add_(pos_embed_flat)
 
 
 @register_notrace_function
 def create_attention_mask(
-        patch_valid: torch.Tensor,
-        num_prefix_tokens: int = 0,
-        dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
-    """Create attention mask from patch type information.
-
-    Used for NaFlex mode to handle variable token counts and padding tokens.
-
-    Args:
-        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
-        num_prefix_tokens: Number of prefix tokens (class token, register tokens)
-        dtype: Dtype of the attention mask
-
-    Returns:
-        Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
-        or None if patch_type is None
-    """
-    patch_valid = patch_valid.to(torch.bool)
-    B = patch_valid.shape[0]
-
-    if num_prefix_tokens > 0:
-        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
-        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
-
-    mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
-    mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
-
-    return mask_float
-
-
-@register_notrace_function
-def create_attention_mask2(
         patch_valid: torch.Tensor,
         num_prefix_tokens: int = 0,
+        symmetric: bool = True,
         q_len: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
-) -> Optional[torch.Tensor]:
-    """Create expanded attention mask from patch validity info.
+) -> torch.Tensor:
+    """Creates an attention mask from patch validity information.
+
+    Supports two modes controlled by `symmetric`:
+    1. `symmetric=True` (default): Creates a symmetric mask of shape
+       [B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if
+       both token i and token j are valid. Suitable for standard self-attention.
+    2. `symmetric=False`: Creates a potentially non-square mask of shape
+       [B, 1, q_len, kv_len]. An attention pair (q, k) is allowed only if
+       the key/value token k is valid. Query token validity is not checked
+       in the mask itself. Useful for cross-attention or specific self-attention
+       implementations where `q_len` can be specified.
 
     Used for NaFlex mode to handle variable token counts and padding tokens.
 
     Args:
-        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding
+        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding.
         num_prefix_tokens: Number of prefix tokens (class token, register tokens)
-        q_len: Length override for query sequence
-        dtype: Dtype of the attention mask
+            to prepend, which are always considered valid.
+        symmetric: If True, create a symmetric mask.
+            If False, create an expanded mask based only on key/value validity.
+        q_len: Query sequence length override. Only used when `symmetric` is False.
+            Defaults to the key/value sequence length (`kv_len`) if None.
+        dtype: Dtype of the output attention mask (e.g., torch.float32).
 
     Returns:
-        Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
-        or None if patch_type is None
+        Attention mask tensor. Additive mask (-inf for masked, 0 for unmasked).
+        Shape is [B, 1, seq_len, seq_len] if symmetric=True,
+        or [B, 1, q_len, kv_len] if symmetric=False.
     """
-    patch_valid = patch_valid.bool()
-    B, kv_len = patch_valid.shape
+    patch_valid = patch_valid.bool()  # Ensure boolean type
+    B, N = patch_valid.shape
+    kv_len = N  # Initial key/value length is the number of patches
 
+    # Prepend prefix tokens if any
     if num_prefix_tokens > 0:
-        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
+        # Create prefix validity tensor on the same device as patch_valid
+        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool)
+        # Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N]
         patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
-        kv_len = patch_valid.shape[1]
-
-    q_len = q_len if q_len is not None else kv_len
-
-    mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len).to(dtype)
-    mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
-
-    return mask_float
+        kv_len += num_prefix_tokens  # Update total key/value sequence length
 
+    if symmetric:
+        # Symmetric mask is True where BOTH query and key are valid
+        mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)
+        mask_bool = mask_bool.unsqueeze(1)  # Add head dimension: [B, 1, seq_len, seq_len]
+    else:
+        # Expanded mask based only on key/value validity
+        q_len = q_len or kv_len
+        mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len)
 
-@register_notrace_function
-def create_pool_mask(
-        patch_valid: torch.Tensor,
-        dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
-    patch_valid = patch_valid.bool()
-    mask_bool = patch_valid[:, None, None, :]
+    # Create the float mask and apply masking using additive mask convention
     mask_float = torch.zeros_like(mask_bool, dtype=dtype)
-    mask_float.masked_fill_(~mask_bool, torch.finfo(mask_float.dtype).min)
+    # Fill with negative infinity where mask_bool is False (masked positions)
+    mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min)
 
     return mask_float
 
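# ---------------------------------------------------------------------------
# [Editor's sketch, not part of the diff] Minimal standalone illustration of
# the pos-embed resize path above: a learned table stored as [1, H, W, C] is
# interpolated to a new grid and flattened to [1, H*W, C]. It also checks that
# the new `.flatten(2).transpose(1, 2)` form matches a permute/reshape of the
# same NCHW tensor. Tensor sizes and the 'bicubic' mode are example values.
import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 16, 16, 384)         # [1, orig_h, orig_w, C]
new_h, new_w = 12, 20                           # target grid for one NaFlex batch group

pos_embed_nchw = pos_embed.permute(0, 3, 1, 2)  # [1, C, H, W] for F.interpolate
resized = F.interpolate(
    pos_embed_nchw,
    size=(new_h, new_w),
    mode='bicubic',
    align_corners=False,
    antialias=True,
)

old_style = resized.permute(0, 2, 3, 1).reshape(1, new_h * new_w, -1)  # NCHW -> NHWC -> [1, N, C]
new_style = resized.flatten(2).transpose(1, 2)                         # NCHW -> [1, N, C]
assert torch.allclose(old_style, new_style)
assert new_style.shape == (1, new_h * new_w, 384)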
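# ---------------------------------------------------------------------------
# [Editor's sketch, not part of the diff] Shape check for the unified
# create_attention_mask above, re-implemented inline so it runs standalone.
# Two padded sequences of 7 patches and a single prefix (class) token are
# assumed example values.
import torch

patch_valid = torch.tensor([
    [True, True, True, True, False, False, False],   # 4 valid patches, 3 padding
    [True, True, True, True, True, True, True],      # all 7 valid
])
B, N = patch_valid.shape
num_prefix_tokens = 1

# Prefix tokens are always valid; prepend them to the validity mask.
prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool)
valid = torch.cat([prefix_valid, patch_valid], dim=1)    # [B, kv_len]
kv_len = N + num_prefix_tokens

# symmetric=True: pair (i, j) allowed only if both tokens are valid.
sym_bool = (valid.unsqueeze(-1) & valid.unsqueeze(1)).unsqueeze(1)
sym_mask = torch.zeros_like(sym_bool, dtype=torch.float32)
sym_mask.masked_fill_(~sym_bool, torch.finfo(torch.float32).min)
assert sym_mask.shape == (B, 1, kv_len, kv_len)

# symmetric=False with q_len=1: only key/value validity is checked,
# e.g. a single pooling query attending over all tokens.
q_len = 1
exp_bool = valid[:, None, None, :].expand(B, 1, q_len, kv_len)
exp_mask = torch.zeros_like(exp_bool, dtype=torch.float32)
exp_mask.masked_fill_(~exp_bool, torch.finfo(torch.float32).min)
assert exp_mask.shape == (B, 1, 1, kv_len)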
@@ -809,7 +787,12 @@ def _pool(
     ) -> torch.Tensor:
         if self.attn_pool is not None:
             # For attention pooling, we need to pass the mask for NaFlex models
-            attn_mask = create_pool_mask(patch_valid, dtype=x.dtype)
+            attn_mask = create_attention_mask(
+                patch_valid,
+                symmetric=False,
+                q_len=1,
+                dtype=x.dtype,
+            )
             x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
             return x
 
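# ---------------------------------------------------------------------------
# [Editor's sketch, not part of the diff] Why symmetric=False with q_len=1
# fits attention pooling: a single pooling query attends over the patch
# tokens, and the additive [B, 1, 1, kv_len] mask pushes padding keys to -inf.
# The probe/head setup here is illustrative, not timm's AttentionPoolLatent.
import torch
import torch.nn.functional as F

B, N, C, num_heads = 2, 7, 64, 4
x = torch.randn(B, N, C)                               # patch tokens (post-encoder)
patch_valid = torch.tensor([[True, True, True, True, False, False, False],
                            [True, True, True, True, True, True, True]])

# Additive mask of shape [B, 1, q_len=1, kv_len=N]
attn_mask = torch.zeros(B, 1, 1, N)
attn_mask.masked_fill_(~patch_valid[:, None, None, :], torch.finfo(torch.float32).min)

probe = torch.randn(1, 1, C).expand(B, -1, -1)         # single pooling query per sample
q = probe.reshape(B, 1, num_heads, C // num_heads).transpose(1, 2)   # [B, heads, 1, d]
k = x.reshape(B, N, num_heads, C // num_heads).transpose(1, 2)       # [B, heads, N, d]
v = k  # reuse keys as values to keep the sketch short
pooled = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
assert pooled.shape == (B, num_heads, 1, C // num_heads)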
@@ -839,7 +822,7 @@ def _pool(
 
             # For max pooling with mask
             masked_x = x.clone()
-            masked_x[~patch_valid] = -1e4  # torch.finfo(masked_x.dtype).min
+            masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
             masked_max = masked_x.max(dim=1)[0]
 
             # Combine average and max
@@ -876,9 +859,7 @@ def forward(
         Returns:
             Model output tensor
         """
-        if isinstance(x, torch.Tensor):
-            patches = x
-        else:
+        if isinstance(x, Dict):
             # Handle dictionary input from NaFlex collator
             patch_coord = x['patch_coord']
             patch_valid = x['patch_valid']
@@ -893,6 +874,8 @@ def forward(
             # patch = patch.reshape(3, h*16, w*16)
             # from torchvision.utils import save_image
             # save_image(patch, f'patch_{i}.jpg', normalize=True)
+        else:
+            patches = x
 
         # Create attention mask if patch_type is provided
         if patch_valid is not None:
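# ---------------------------------------------------------------------------
# [Editor's sketch, not part of the diff] Rough shape of the dictionary input
# the forward() branch above expects from the NaFlex collator. The
# 'patch_coord' and 'patch_valid' keys appear in the diff; the 'patches' key
# and the exact shapes are assumptions for illustration.
import torch

B, max_patches, patch_dim = 2, 64, 3 * 16 * 16
naflex_input = {
    'patches': torch.randn(B, max_patches, patch_dim),                 # flattened patch pixels
    'patch_coord': torch.zeros(B, max_patches, 2, dtype=torch.long),   # (y, x) grid coordinates
    'patch_valid': torch.ones(B, max_patches, dtype=torch.bool),       # False for padding
}
# model(naflex_input) would take the Dict branch above; a plain [B, 3, H, W]
# tensor would fall through to the `patches = x` else branch.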