diff --git a/tests/test_models.py b/tests/test_models.py index 3ba3615db4..3a83b927af 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -78,7 +78,7 @@ EXCLUDE_FILTERS = ['*enormous*'] NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*'] -EXCLUDE_JIT_FILTERS = ['hiera_*'] +EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*'] TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 TARGET_BWD_SIZE = 128 diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 4b95fbd174..3eba2193da 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -8,6 +8,15 @@ from .imagenet_info import ImageNetInfo, infer_imagenet_subset from .loader import create_loader from .mixup import Mixup, FastCollateMixup +from .naflex_dataset import VariableSeqMapWrapper +from .naflex_loader import create_naflex_loader +from .naflex_transforms import ( + ResizeToSequence, + CenterCropToSequence, + RandomCropToSequence, + RandomResizedCropToSequence, + ResizeKeepRatioToSequence, +) from .readers import create_reader from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions from .real_labels import RealLabelsImagenet diff --git a/timm/data/naflex_dataset.py b/timm/data/naflex_dataset.py new file mode 100644 index 0000000000..858a182fb7 --- /dev/null +++ b/timm/data/naflex_dataset.py @@ -0,0 +1,424 @@ +""" +Dynamic Sequence Length Datasets for Variable Resolution Image Processing + +Implements two dataset wrappers: +1. DynamicSeqMapDataset - Map-style dataset that returns batches with variable sequence lengths +2. DynamicSeqIterDataset - Iterable dataset that yields batches with variable sequence lengths + +Both support: +- Pre-initialized transforms for efficiency +- Distributed training +- Multiple workers +- Variable batch sizes based on sequence length +""" + +import math +import random +import warnings +from functools import partial +from typing import Any, Iterator, List, Tuple, Dict, Optional, Union, Callable + +import torch +from torch.utils.data import Dataset, IterableDataset, DataLoader +from torchvision import transforms +from PIL import Image + + +from .naflex_transforms import Patchify, patchify + + +def calculate_batch_size( + tokens_per_batch: int, + seq_len: int, + max_size: Optional[int] = None, + divisor: int = 1, + rounding: str ='floor', +): + """Calculate batch size based on sequence length with divisibility constraints.""" + # Calculate raw batch size based on sequence length + raw_batch_size = tokens_per_batch / seq_len + + # Apply divisibility with specified rounding method + if divisor > 1: + if rounding == 'floor': + batch_size = math.floor(raw_batch_size / divisor) * divisor + elif rounding == 'ceil': + batch_size = math.ceil(raw_batch_size / divisor) * divisor + else: # 'round' is the default + batch_size = round(raw_batch_size / divisor) * divisor + else: + # If no divisor specified, just use integer division + batch_size = int(raw_batch_size) + + # Ensure batch size is valid + batch_size = max(1, batch_size) # At least 1 + + if max_size is not None: + batch_size = min(batch_size, max_size) + + return batch_size + + +class NaFlexCollator: + """Custom collator for batching NaFlex-style variable-resolution images.""" + + def __init__( + self, + max_seq_len=None, + ): + self.max_seq_len = max_seq_len or 576 # Default ViT-B/16 sequence length (577 = 24*24) + + def __call__(self, batch): + """ + Args: + batch: List of tuples (patch_dict, target) + + Returns: + A tuple of (input_dict, targets) where input_dict contains: + - 
patches: Padded tensor of patches + - patch_coord: Coordinates for each patch (y, x) + - patch_valid: Valid indicators + """ + assert isinstance(batch[0], tuple) + batch_size = len(batch) + + # Extract targets + # FIXME need to handle dense (float) targets or always done downstream of this? + targets = torch.tensor([item[1] for item in batch], dtype=torch.int64) + + # Get patch dictionaries + patch_dicts = [item[0] for item in batch] + + # If we have a maximum sequence length constraint, ensure we don't exceed it + if self.max_seq_len is not None: + max_patches = self.max_seq_len + else: + # Find the maximum number of patches in this batch + max_patches = max(item['patches'].shape[0] for item in patch_dicts) + + # Get patch dimensionality + patch_dim = patch_dicts[0]['patches'].shape[1] + + # Prepare tensors for the batch + patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32) + patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64) # [B, N, 2] for (y, x) + patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool) + + # Fill in the tensors + for i, patch_dict in enumerate(patch_dicts): + num_patches = min(patch_dict['patches'].shape[0], max_patches) + + patches[i, :num_patches] = patch_dict['patches'][:num_patches] + patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches] + patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches] + + return { + 'patches': patches, + 'patch_coord': patch_coord, + 'patch_valid': patch_valid, + 'seq_len': max_patches, + }, targets + + +class VariableSeqMapWrapper(IterableDataset): + """ + IterableDataset wrapper for a map-style base dataset. + + Yields batches with variable sequence lengths. It calculates a canonical + batch schedule (sequence length, batch size pairs) once based on the + total dataset size (padded for distribution). Each epoch, it shuffles + the *order* of this canonical schedule and the dataset indices. + This ensures a consistent number of batches and samples per epoch + across all ranks. Handles distributed training and multiple workers. 
+ """ + + def __init__( + self, + base_dataset: Dataset, + patch_size: Union[int, Tuple[int, int]] = 16, + seq_lens: List[int] = (128, 256, 576, 784, 1024), + max_tokens_per_batch: int = 4096 * 4, # Example: 16k tokens + transform_factory: Optional[Callable] = None, + seed: int = 42, + shuffle: bool = True, + distributed: bool = False, + rank: int = 0, + world_size: int = 1, + epoch: int = 0, + batch_divisor: int = 8, # Ensure batch size is divisible by this + ): + super().__init__() + if not hasattr(base_dataset, '__len__') or not hasattr(base_dataset, '__getitem__'): + raise TypeError("base_dataset must be a map-style dataset (implement __len__ and __getitem__)") + + self.base_dataset = base_dataset + self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size) + self.seq_lens = sorted(list(set(seq_lens))) # Ensure unique and sorted + self.max_tokens_per_batch = max_tokens_per_batch + self.seed = seed + self.shuffle = shuffle + self.distributed = distributed + self.rank = rank if distributed else 0 + self.world_size = world_size if distributed else 1 + self.epoch = epoch + self.batch_divisor = batch_divisor + + # Pre-initialize transforms and collate fns for each sequence length + self.transforms: Dict[int, Optional[Callable]] = {} + self.collate_fns: Dict[int, Callable] = {} + for seq_len in self.seq_lens: + if transform_factory: + self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size) + else: + self.transforms[seq_len] = None # No transform + self.collate_fns[seq_len] = NaFlexCollator(seq_len) + self.patchifier = Patchify(self.patch_size) + + # --- Canonical Schedule Calculation (Done Once) --- + self._canonical_batch_schedule: List[Tuple[int, int]] = [] + self._num_batches_per_rank: int = 0 + self._padded_samples_per_rank: int = 0 + self._create_canonical_schedule() # Calculate schedule based on padded size + + # --- Per-Epoch State --- + # Stores (seq_len, list_of_indices) for the current epoch, specific to this rank + self._epoch_batches: List[Tuple[int, List[int]]] = [] + self._prepare_epoch_batches(self.epoch) # setup for initial epoch + + def _create_canonical_schedule(self): + """ + Calculates the canonical batch schedule (seq_len, batch_size pairs) + based on the dataset size, padded for distributed training. + This schedule is the *same* for all ranks and ensures consistent + epoch length. It is calculated once during initialization. 
+ """ + total_len = len(self.base_dataset) + padded_total_len = total_len + num_samples_per_rank = total_len + + if self.distributed and self.world_size > 1: + # Calculate padding needed for even distribution + if total_len % self.world_size != 0: + pad_size = self.world_size - (total_len % self.world_size) + padded_total_len += pad_size + print(f"Rank {self.rank}: Padding dataset with {pad_size} samples for distributed training (total size {padded_total_len}).") + else: + pad_size = 0 + + if padded_total_len % self.world_size != 0: + # This should not happen with the padding logic, but safeguard + raise RuntimeError(f"Internal Error: Padded total length {padded_total_len} not divisible by world size {self.world_size}") + + num_samples_per_rank = padded_total_len // self.world_size + elif self.distributed and self.world_size <= 1: + # Distributed flag set but world_size is 1, treat as non-distributed + pass # num_samples_per_rank remains total_len + + self._padded_samples_per_rank = num_samples_per_rank + + if num_samples_per_rank == 0: + self._canonical_batch_schedule = [] + self._num_batches_per_rank = 0 + return + + # Use a fixed seed for generating the canonical schedule structure + g = torch.Generator() + g.manual_seed(self.seed) # Use base seed, NOT epoch seed + + current_schedule: List[Tuple[int, int]] = [] + remaining_samples = num_samples_per_rank + total_scheduled_samples = 0 + + while remaining_samples > 0: + # Sample sequence length deterministically based on base seed + seq_idx = torch.randint(0, len(self.seq_lens), (1,), generator=g).item() + seq_len = self.seq_lens[seq_idx] + + # Calculate batch size + batch_size = calculate_batch_size( + tokens_per_batch=self.max_tokens_per_batch, + seq_len=seq_len, + # max_size should be remaining_samples to avoid overshooting + max_size=remaining_samples, + divisor=self.batch_divisor, + rounding='floor', + ) + # Ensure batch size is positive and doesn't exceed remaining samples + batch_size = max(1, batch_size) + batch_size = min(batch_size, remaining_samples) + + if batch_size <= 0: + warnings.warn(f"Calculated batch size <= 0 (seq_len={seq_len}, remaining={remaining_samples}). Stopping schedule generation early.") + break # Avoid infinite loop if something goes wrong + + current_schedule.append((seq_len, batch_size)) + remaining_samples -= batch_size + total_scheduled_samples += batch_size + + # Sanity check: Ensure the schedule covers all samples for the rank + if total_scheduled_samples != num_samples_per_rank: + warnings.warn( + f"Rank {self.rank}: Canonical schedule accounts for {total_scheduled_samples} samples, " + f"but expected {num_samples_per_rank} samples per rank. " + f"This might happen if min_batch_size or batch_divisor constraints prevent utilizing all samples. " + f"Check parameters. Remaining samples: {remaining_samples}" + ) + # Adjust if needed? Could add a final small batch, but might violate constraints. + # Current behavior: some samples might be dropped if schedule logic fails. + + self._canonical_batch_schedule = current_schedule + self._num_batches_per_rank = len(current_schedule) + print(f"Rank {self.rank}: Created canonical schedule with {self._num_batches_per_rank} batches for {self._padded_samples_per_rank} samples/rank.") + + + def _prepare_epoch_batches(self, epoch: int): + """ + Prepares the batches for the current epoch by: + 1. Shuffling the full dataset indices (using epoch seed). + 2. Applying padding if in distributed mode. + 3. Selecting indices for the current rank. + 4. 
Shuffling the *order* of the canonical batch schedule (using epoch seed). + 5. Assigning the rank's indices to the shuffled batches. + """ + g = torch.Generator() + g.manual_seed(self.seed + epoch) # Epoch-specific seed for shuffling + + # 1. Get shuffled global indices + total_len = len(self.base_dataset) + if self.shuffle: + all_indices_shuffled = torch.randperm(total_len, generator=g).tolist() + else: + all_indices_shuffled = list(range(total_len)) + + # 2. Apply padding for distributed mode + indices_for_ranks = all_indices_shuffled + if self.distributed and self.world_size > 1: + padded_total_len = self._padded_samples_per_rank * self.world_size + if padded_total_len > total_len: + pad_size = padded_total_len - total_len + # Repeat initial elements from the *shuffled* list for padding + indices_for_ranks = all_indices_shuffled + all_indices_shuffled[:pad_size] + # Ensure length matches expectation + if len(indices_for_ranks) != padded_total_len: + raise RuntimeError(f"Internal Error: Padded index list length {len(indices_for_ranks)} does not match expected {padded_total_len}") + + + # 3. Select indices for the current rank + if self.distributed and self.world_size > 1: + indices_this_rank = indices_for_ranks[self.rank::self.world_size] + else: # Non-distributed or world_size=1 + indices_this_rank = indices_for_ranks + + # Sanity check length + if len(indices_this_rank) != self._padded_samples_per_rank: + # This might happen if canonical schedule generation had warnings/issues + warnings.warn( + f"Rank {self.rank}: Number of indices for this rank ({len(indices_this_rank)}) " + f"does not match expected padded samples per rank ({self._padded_samples_per_rank}). " + f"Epoch generation might be inconsistent." + ) + # Adjust expected samples? Or truncate/pad indices? Let's proceed but warn. + # Using min() prevents IndexError later if indices are fewer than expected. + effective_samples_this_rank = min(len(indices_this_rank), self._padded_samples_per_rank) + indices_this_rank = indices_this_rank[:effective_samples_this_rank] + + else: + effective_samples_this_rank = self._padded_samples_per_rank + + # 4. Shuffle the order of the canonical batch schedule for this epoch + if self.shuffle: + schedule_perm = torch.randperm(self._num_batches_per_rank, generator=g).tolist() + shuffled_schedule = [self._canonical_batch_schedule[i] for i in schedule_perm] + else: + shuffled_schedule = list(self._canonical_batch_schedule) # Keep original order + + # 5. Assign indices to the shuffled batches + self._epoch_batches = [] + idx_pos = 0 + scheduled_samples_count = 0 + for seq_len, bs in shuffled_schedule: + # Ensure we don't try to grab more indices than available for the rank + actual_bs = min(bs, effective_samples_this_rank - idx_pos) + if actual_bs <= 0: + if scheduled_samples_count < effective_samples_this_rank: + # This indicates mismatch between schedule total and actual samples + warnings.warn(f"Rank {self.rank}: Ran out of samples ({idx_pos}/{effective_samples_this_rank}) before processing entire schedule. 
Check schedule generation.") + break # Stop if no more indices or batch size is zero + + batch_indices = indices_this_rank[idx_pos : idx_pos + actual_bs] + self._epoch_batches.append((seq_len, batch_indices)) + idx_pos += actual_bs + scheduled_samples_count += actual_bs + + # Final check + if scheduled_samples_count != effective_samples_this_rank: + warnings.warn( + f"Rank {self.rank}: Assigned {scheduled_samples_count} samples to batches, " + f"but expected {effective_samples_this_rank} effective samples this epoch. " + f"Indices remaining: {effective_samples_this_rank - scheduled_samples_count}." + ) + + def set_epoch(self, epoch: int): + """Updates the epoch, regenerating the epoch-specific batches.""" + # Only regenerate if the epoch actually changes + if epoch != self.epoch: + self.epoch = epoch + self._prepare_epoch_batches(epoch) + + def __len__(self) -> int: + """ + Returns the number of batches per **worker** for the current epoch. + Calculated based on the fixed number of batches per rank divided by + the number of workers. + """ + return self._num_batches_per_rank + + def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]: + """ + Iterates through the pre-calculated batches for the current epoch, + distributing them among workers. + """ + worker_info = torch.utils.data.get_worker_info() + num_workers = worker_info.num_workers if worker_info else 1 + worker_id = worker_info.id if worker_info else 0 + + # Distribute pre-calculated batches among workers for this rank + # Each worker processes a slice of the batches prepared in _prepare_epoch_batches + batches_for_worker = self._epoch_batches[worker_id::num_workers] + for seq_len, indices in batches_for_worker: + if not indices: # Skip if a batch ended up with no indices (shouldn't happen often) + continue + + # Get the pre-initialized transform for this sequence length + transform = self.transforms.get(seq_len) + + batch_samples = [] + for idx in indices: + try: + # Get original image and label from map-style dataset + img, label = self.base_dataset[idx] + + # Apply transform if available + # Handle cases where transform might return None or fail + processed_img = transform(img) if transform else img + if processed_img is None: + warnings.warn(f"Transform returned None for index {idx}. Skipping sample.") + continue + + # Apply patching + patch_data = self.patchifier(processed_img) + batch_samples.append((patch_data, label)) + + except IndexError: + warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.") + continue + except Exception as e: + # Log other potential errors during data loading/processing + warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.") + continue # Skip problematic sample + + # Collate the processed samples into a batch + if batch_samples: # Only yield if we successfully processed samples + yield self.collate_fns[seq_len](batch_samples) + + # If batch_samples is empty after processing 'indices', an empty batch is skipped. 
diff --git a/timm/data/naflex_loader.py b/timm/data/naflex_loader.py new file mode 100644 index 0000000000..bb96d07d1a --- /dev/null +++ b/timm/data/naflex_loader.py @@ -0,0 +1,269 @@ +import math +from contextlib import suppress +from functools import partial +from typing import Callable, List, Optional, Tuple, Union + +import torch + +from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .loader import _worker_init +from .naflex_dataset import VariableSeqMapWrapper, NaFlexCollator +from .transforms_factory import create_transform + + +class NaFlexPrefetchLoader: + """Data prefetcher for NaFlex format which normalizes patches.""" + + def __init__( + self, + loader, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + img_dtype=torch.float32, + device=torch.device('cuda') + ): + self.loader = loader + self.device = device + self.img_dtype = img_dtype or torch.float32 + + # Create mean/std tensors for normalization (will be applied to patches) + self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=self.img_dtype).view(1, 1, 3) + self.std = torch.tensor([x * 255 for x in std], device=device, dtype=self.img_dtype).view(1, 1, 3) + + # Check for CUDA/NPU availability + self.is_cuda = device.type == 'cuda' and torch.cuda.is_available() + self.is_npu = device.type == 'npu' and torch.npu.is_available() + + def __iter__(self): + first = True + if self.is_cuda: + stream = torch.cuda.Stream() + stream_context = partial(torch.cuda.stream, stream=stream) + elif self.is_npu: + stream = torch.npu.Stream() + stream_context = partial(torch.npu.stream, stream=stream) + else: + stream = None + stream_context = suppress + + for next_input_dict, next_target in self.loader: + with stream_context(): + # Move all tensors in input_dict to device + for k, v in next_input_dict.items(): + if isinstance(v, torch.Tensor): + dtype = self.img_dtype if k == 'patches' else None + next_input_dict[k] = next_input_dict[k].to( + device=self.device, + non_blocking=True, + dtype=dtype, + ) + + next_target = next_target.to(device=self.device, non_blocking=True) + + # Normalize patch values (assuming patches are in format [B, N, P*P*C]) + batch_size, num_patches, patch_pixels = next_input_dict['patches'].shape + patches = next_input_dict['patches'].view(batch_size, -1, 3) # to [B*N, P*P, C] for normalization + patches = patches.sub(self.mean).div(self.std) + + # Reshape back + next_input_dict['patches'] = patches.reshape(batch_size, num_patches, patch_pixels) + + if not first: + yield input_dict, target + else: + first = False + + if stream is not None: + if self.is_cuda: + torch.cuda.current_stream().wait_stream(stream) + elif self.is_npu: + torch.npu.current_stream().wait_stream(stream) + + input_dict = next_input_dict + target = next_target + + yield input_dict, target + + def __len__(self): + return len(self.loader) + + @property + def sampler(self): + return self.loader.sampler + + @property + def dataset(self): + return self.loader.dataset + + +def create_naflex_loader( + dataset, + patch_size: Union[Tuple[int, int], int] = 16, + train_seq_lens: List[int] = (128, 256, 576, 784, 1024), # Training sequence lengths + max_seq_len: int = 576, # Fixed sequence length for validation + batch_size: int = 32, # Used for max_seq_len and max(train_seq_lens) + is_training: bool = False, + + no_aug: bool = False, + re_prob: float = 0., + re_mode: str = 'const', + re_count: int = 1, + re_split: bool = False, + train_crop_mode: Optional[str] = None, + scale: Optional[Tuple[float, float]] = None, 
+ ratio: Optional[Tuple[float, float]] = None, + hflip: float = 0.5, + vflip: float = 0., + color_jitter: float = 0.4, + color_jitter_prob: Optional[float] = None, + grayscale_prob: float = 0., + gaussian_blur_prob: float = 0., + auto_augment: Optional[str] = None, + num_aug_repeats: int = 0, + num_aug_splits: int = 0, + interpolation: str = 'bilinear', + mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN, + std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, + crop_pct: Optional[float] = None, + crop_mode: Optional[str] = None, + crop_border_pixels: Optional[int] = None, + + num_workers: int = 4, + distributed: bool = False, + rank: int = 0, + world_size: int = 1, + seed: int = 42, + epoch: int = 0, + use_prefetcher: bool = True, + pin_memory: bool = True, + img_dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = torch.device('cuda'), + persistent_workers: bool = True, + worker_seeding: str = 'all', + ): + """Create a data loader with dynamic sequence length sampling for training.""" + + if is_training: + # For training, use the dynamic sequence length mechanism + assert num_aug_repeats == 0, 'Augmentation repeats not currently supported in NaFlex loader' + + transform_factory = partial( + create_transform, + is_training=True, + no_aug=no_aug, + train_crop_mode=train_crop_mode, + scale=scale, + ratio=ratio, + hflip=hflip, + vflip=vflip, + color_jitter=color_jitter, + color_jitter_prob=color_jitter_prob, + grayscale_prob=grayscale_prob, + gaussian_blur_prob=gaussian_blur_prob, + auto_augment=auto_augment, + interpolation=interpolation, + mean=mean, + std=std, + crop_pct=crop_pct, + crop_mode=crop_mode, + crop_border_pixels=crop_border_pixels, + re_prob=re_prob, + re_mode=re_mode, + re_count=re_count, + use_prefetcher=use_prefetcher, + naflex=True, + ) + + max_train_seq_len = max(train_seq_lens) + max_tokens_per_batch = batch_size * max_train_seq_len + + if isinstance(dataset, torch.utils.data.IterableDataset): + assert False, "IterableDataset Wrapper is a WIP" + + naflex_dataset = VariableSeqMapWrapper( + dataset, + transform_factory=transform_factory, + patch_size=patch_size, + seq_lens=train_seq_lens, + max_tokens_per_batch=max_tokens_per_batch, + seed=seed, + distributed=distributed, + rank=rank, + world_size=world_size, + shuffle=True, + epoch=epoch, + ) + + # NOTE: Collation is handled by the dataset wrapper for training + # Create the collator (handles fixed-size collation) + # collate_fn = NaFlexCollator( + # max_seq_len=max(seq_lens) + 1, # +1 for class token + # ) + + loader = torch.utils.data.DataLoader( + naflex_dataset, + batch_size=None, + shuffle=False, + num_workers=num_workers, + sampler=None, + #collate_fn=collate_fn, + pin_memory=pin_memory, + worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding), + persistent_workers=persistent_workers + ) + + if use_prefetcher: + loader = NaFlexPrefetchLoader( + loader, + mean=mean, + std=std, + img_dtype=img_dtype, + device=device, + ) + + else: + # For validation, use fixed sequence length (unchanged) + dataset.transform = create_transform( + is_training=False, + interpolation=interpolation, + mean=mean, + std=std, + # FIXME add crop args when sequence transforms support crop modes + use_prefetcher=use_prefetcher, + naflex=True, + patch_size=patch_size, + max_seq_len=max_seq_len, + patchify=True, + ) + + # Create the collator + collate_fn = NaFlexCollator(max_seq_len=max_seq_len) + + # Handle distributed training + sampler = None + if distributed and not isinstance(dataset, torch.utils.data.IterableDataset): 
+ # For validation, use OrderedDistributedSampler + from timm.data.distributed_sampler import OrderedDistributedSampler + sampler = OrderedDistributedSampler(dataset) + + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + collate_fn=collate_fn, + pin_memory=pin_memory, + drop_last=False, + ) + + if use_prefetcher: + loader = NaFlexPrefetchLoader( + loader, + mean=mean, + std=std, + img_dtype=img_dtype, + device=device, + ) + + return loader diff --git a/timm/data/naflex_transforms.py b/timm/data/naflex_transforms.py new file mode 100644 index 0000000000..81653e78be --- /dev/null +++ b/timm/data/naflex_transforms.py @@ -0,0 +1,803 @@ +""" NaFlex (NaViT + FlexiViT) Transforms and Collation + +Implements PyTorch versions of the transforms described in the NaViT and FlexiViT papers: +- NaViT: https://arxiv.org/abs/2307.14995 +- FlexiViT: https://arxiv.org/abs/2212.08013 + +Enables variable resolution/aspect ratio image handling with efficient patching. +""" + +import math +import random +import warnings +from typing import List, Optional, Sequence, Tuple, Union + +import torch +from PIL import Image +from torchvision import transforms +from torchvision.transforms import functional as F +from torchvision.transforms.functional import InterpolationMode + +from .transforms import str_to_interp_mode, crop_or_pad, center_crop_or_pad + + +def get_image_size_for_seq( + image_hw, + patch_size=16, + max_seq_len=1024, + divisible_by_patch=True, + max_ratio=None, + eps = 1e-5, +): + """ + Determine scaling ratio and image size so that when `image_hw` is scaled + by 'ratio', the total number of resulting patches does not exceed + 'max_seq_len'. + + - Patch size can be an integer (square patch) or a tuple (patch_h, patch_w). + - Optionally cap the ratio at `max_ratio` to prevent upsampling beyond + a certain multiple of the original size. + + Args: + image_hw (tuple or list of int): (height, width) of the original image. + patch_size (int or tuple[int, int]): If int, patch is square. If tuple, + patch is rectangular (patch_h, patch_w). + max_seq_len (int): Maximum allowed sequence length for the resulting image. + divisible_by_patch (bool): If True, the resulting image height and width + must be multiples of patch_size. + eps (float): Small number for binary search convergence. + max_ratio (float or None): If provided, the scaling ratio found by the + binary search will be clamped to min(found_ratio, max_ratio). Set + max_ratio=1.0 to ensure no upsampling beyond original size. + + Returns: + ratio (float): Found scaling ratio (capped by `max_ratio` if provided). + target_hw (tuple of int): Target (height, width) after scaling. 
+ """ + + # Handle patch size input, extract patch_h, patch_w + if isinstance(patch_size, int): + patch_h, patch_w = patch_size, patch_size + else: + # Assume it's a tuple/list: (patch_h, patch_w) + if len(patch_size) != 2: + raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).") + patch_h, patch_w = patch_size + + # Safety checks + if patch_h <= 0 or patch_w <= 0: + raise ValueError("patch_size dimensions must be positive.") + + def prepare_target_hw(ratio): + """Scale image_hw by ratio and optionally round dimensions to multiples of patch_h, patch_w.""" + scaled_h = image_hw[0] * ratio + scaled_w = image_hw[1] * ratio + + # If we need the result to be divisible by patch_size + if divisible_by_patch: + scaled_h = patch_h * math.ceil(scaled_h / patch_h) + scaled_w = patch_w * math.ceil(scaled_w / patch_w) + + # Ensure at least one patch in each dimension + scaled_h = int(max(scaled_h, patch_h)) + scaled_w = int(max(scaled_w, patch_w)) + + return scaled_h, scaled_w + + def is_feasible(ratio): + """Check if scaling by 'ratio' keeps patch count within max_seq_len.""" + t_h, t_w = prepare_target_hw(ratio) + + # Each dimension is already a multiple of patch_h, patch_w if divisible_by_patch=True. + # Use integer division to count patches. + num_patches_h = t_h // patch_h + num_patches_w = t_w // patch_w + seq_len = num_patches_h * num_patches_w + + return seq_len <= max_seq_len + + # Binary search boundaries + lb = eps / 10.0 + rb = 100.0 + + # Standard binary search loop + while (rb - lb) >= eps: + mid = (lb + rb) / 2.0 + if is_feasible(mid): + lb = mid + else: + rb = mid + + # The final ratio from the binary search + ratio = lb + + # If max_ratio is provided, clamp it to prevent upsampling beyond that threshold + if max_ratio is not None: + ratio = min(ratio, max_ratio) + + # Final checks + if ratio <= eps: + raise ValueError("Binary search failed - image might be too large?") + if ratio >= 100.0: + raise ValueError("Binary search failed - image might be too small?") + + # Prepare the final target dimensions with the possibly clamped ratio + target_hw = prepare_target_hw(ratio) + return ratio, target_hw + + +_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic')) + + +class ResizeToSequence(torch.nn.Module): + """Resize image to fit within a maximum sequence length constraint when patchified. + + This maintains aspect ratio while ensuring the resulting image, when divided into patches, + will not exceed the specified maximum sequence length. 
+ """ + def __init__( + self, + patch_size: int, + max_seq_len: int = 1024, + divisible_by_patch: bool = True, + max_ratio: Optional[float] = None, + interpolation='bicubic', + ): + super().__init__() + self.patch_size = patch_size + self.max_seq_len = max_seq_len + self.divisible_by_patch = divisible_by_patch + self.max_ratio = max_ratio + if isinstance(interpolation, str): + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = str_to_interp_mode(interpolation) + else: + self.interpolation = interpolation + + + def forward(self, img): + """Resize image to maintain aspect ratio and fit sequence constraint.""" + _, h, w = transforms.functional.get_dimensions(img) + + _, target_hw = get_image_size_for_seq( + (h, w), + self.patch_size, + self.max_seq_len, + divisible_by_patch=self.divisible_by_patch, + max_ratio=self.max_ratio, + ) + + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + + resized_img = transforms.functional.resize(img, target_hw, interpolation=interpolation, antialias=True) + + return resized_img + + +class ResizeKeepRatioToSequence(torch.nn.Module): + """ + Resize and Keep Aspect Ratio, adapted to fit sequence length constraints. + """ + + def __init__( + self, + patch_size=16, + max_sequence_len=1024, + divisible_by_patch=True, + longest=0., + interpolation='bilinear', + random_scale_prob=0., + random_scale_range=(0.85, 1.05), + random_scale_area=False, + random_aspect_prob=0., + random_aspect_range=(0.9, 1.11), + max_ratio=None, + ): + """ + Args: + patch_size: Size of patches (int or tuple of (patch_h, patch_w)) + max_sequence_len: Maximum allowed sequence length for the resulting image + divisible_by_patch: If True, ensure dimensions are divisible by patch_size + longest: Float between 0-1 where 0=shortest side, 1=longest side determines scale + interpolation: Interpolation method for resizing + random_scale_prob: Probability of applying random scaling + random_scale_range: Range for random scaling factor (min, max) + random_scale_area: If True, scale factors affect area (√ factor) + random_aspect_prob: Probability of applying random aspect ratio jittering + random_aspect_range: Range for random aspect ratio (min, max) + max_ratio: Maximum allowed scaling ratio + """ + super().__init__() + self.patch_size = patch_size + self.max_sequence_len = max_sequence_len + self.divisible_by_patch = divisible_by_patch + self.longest = float(longest) + + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = str_to_interp_mode(interpolation) + + self.random_scale_prob = random_scale_prob + self.random_scale_range = random_scale_range + self.random_scale_area = random_scale_area + self.random_aspect_prob = random_aspect_prob + self.random_aspect_range = random_aspect_range + self.max_ratio = max_ratio + + @staticmethod + def get_params( + img, + patch_size, + max_sequence_len, + divisible_by_patch, + longest, + random_scale_prob=0., + random_scale_range=(1.0, 1.33), + random_scale_area=False, + random_aspect_prob=0., + random_aspect_range=(0.9, 1.11), + max_ratio=None, + ): + """Get parameters for resizing.""" + # Get image dimensions + img_h, img_w = F.get_dimensions(img)[1:] + + # Step 1: Get the maximum allowed dimensions from sequence length constraint + _, target_hw = get_image_size_for_seq( + (img_h, img_w), + patch_size, + max_sequence_len, + divisible_by_patch, + max_ratio, + ) + 
target_h, target_w = target_hw + + # Calculate ratio based on sequence constraint + ratio_h = target_h / img_h + ratio_w = target_w / img_w + # Apply longest blending + ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest) + + # Apply random scaling + if random_scale_prob > 0 and random.random() < random_scale_prob: + ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1]) + if random_scale_area: + # Make ratio factor equivalent to area change + ratio_factor = 1. / math.sqrt(ratio_factor) + ratio_factor = (ratio_factor, ratio_factor) + else: + ratio_factor = (1., 1.) + + # Apply random aspect + if random_aspect_prob > 0 and random.random() < random_aspect_prob: + log_aspect = (math.log(random_aspect_range[0]), math.log(random_aspect_range[1])) + aspect_factor = math.exp(random.uniform(*log_aspect)) + aspect_factor = math.sqrt(aspect_factor) + # Apply aspect ratio jittering + ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor) + + # Calculate final dimensions + size = [round(dim * ratio * f) for dim, f in zip((img_h, img_w), ratio_factor)] + + # Ensure dimensions satisfy sequence constraint and are divisible by patch size + if isinstance(patch_size, int): + ph, pw = patch_size, patch_size + else: + ph, pw = patch_size + + # Ensure dimensions are at least one patch + size[0] = max(size[0], ph) + size[1] = max(size[1], pw) + + # Make divisible by patch size if needed + if divisible_by_patch: + size[0] = ph * math.ceil(size[0] / ph) + size[1] = pw * math.ceil(size[1] / pw) + + # Verify we haven't exceeded sequence length + num_patches_h = size[0] // ph + num_patches_w = size[1] // pw + seq_len = num_patches_h * num_patches_w + + if seq_len > max_sequence_len: + # Scale back down to fit sequence constraint + scale_back = math.sqrt(max_sequence_len / seq_len) + size[0] = int(size[0] * scale_back) + size[1] = int(size[1] * scale_back) + + # Ensure divisible by patch size after scaling back + if divisible_by_patch: + size[0] = ph * math.ceil(size[0] / ph) + size[1] = pw * math.ceil(size[1] / pw) + + return size + + def forward(self, img): + """ + Resize the image with aspect ratio preservation and sequence length constraints. 
+ """ + size = self.get_params( + img, + self.patch_size, + self.max_sequence_len, + self.divisible_by_patch, + self.longest, + self.random_scale_prob, + self.random_scale_range, + self.random_scale_area, + self.random_aspect_prob, + self.random_aspect_range, + self.max_ratio, + ) + + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + + return F.resize(img, size, interpolation) + + def __repr__(self): + interpolate_str = "random" if isinstance(self.interpolation, (tuple, list)) else str(self.interpolation) + return (f"{self.__class__.__name__}(patch_size={self.patch_size}, " + f"max_sequence_len={self.max_sequence_len}, " + f"longest={self.longest:.3f}, " + f"random_scale_prob={self.random_scale_prob:.3f}, " + f"random_aspect_prob={self.random_aspect_prob:.3f})") + + +class CenterCropToSequence(torch.nn.Module): + """Center crop the image such that the resulting patch sequence length meets constraints.""" + def __init__( + self, + patch_size: int, + max_seq_len: int, + divisible_by_patch: bool = True, + fill: Union[int, Tuple[int, int, int]] = 0, + padding_mode: str = 'constant' + ): + super().__init__() + self.patch_size = patch_size + self.max_seq_len = max_seq_len + self.divisible_by_patch = divisible_by_patch + self.fill = fill + self.padding_mode = padding_mode + + + def forward(self, img): + """Center crop the image to maintain aspect ratio and fit sequence constraint.""" + _, h, w = transforms.functional.get_dimensions(img) + _, target_hw = get_image_size_for_seq( + (h, w), + self.patch_size, + self.max_seq_len, + self.divisible_by_patch + ) + + # Use center crop + return center_crop_or_pad(img, target_hw, fill=self.fill, padding_mode=self.padding_mode) + + +class RandomCropToSequence(torch.nn.Module): + """Randomly crop and/or pad the image to fit sequence length constraints. + + This maintains aspect ratio while ensuring the resulting image, when divided into patches, + will not exceed the specified maximum sequence length. Similar to CentralCropToSequence + but with randomized positioning. 
+ """ + + def __init__( + self, + patch_size: int, + max_sequence_len: int, + divisible_by_patch: bool = True, + fill: Union[int, Tuple[int, int, int]] = 0, + padding_mode: str = 'constant' + ): + """ + Args: + patch_size: Size of patches (int or tuple of (patch_h, patch_w)) + max_sequence_len: Maximum allowed sequence length for the resulting image + divisible_by_patch: If True, resulting image dimensions will be multiples of patch_size + fill: Fill value for padding + padding_mode: Padding mode ('constant', 'edge', 'reflect', 'symmetric') + """ + super().__init__() + self.patch_size = patch_size + self.max_sequence_len = max_sequence_len + self.divisible_by_patch = divisible_by_patch + self.fill = fill + self.padding_mode = padding_mode + + @staticmethod + def get_params(img, target_size): + """Get random position for crop/pad.""" + _, image_height, image_width = transforms.functional.get_dimensions(img) + delta_height = image_height - target_size[0] + delta_width = image_width - target_size[1] + + # Handle both positive (crop) and negative (pad) deltas + if delta_height == 0: + top = 0 + else: + top = int(math.copysign(random.randint(0, abs(delta_height)), delta_height)) + + if delta_width == 0: + left = 0 + else: + left = int(math.copysign(random.randint(0, abs(delta_width)), delta_width)) + + return top, left + + def forward(self, img): + """Randomly crop or pad the image to maintain aspect ratio and fit sequence constraint.""" + # Get current dimensions + _, img_h, img_w = transforms.functional.get_dimensions(img) + + # Calculate target dimensions that satisfy sequence length + # We use max_ratio=1.0 to prevent upscaling - we only want to crop or maintain current size + _, target_hw = get_image_size_for_seq( + (img_h, img_w), + self.patch_size, + self.max_sequence_len, + self.divisible_by_patch, + max_ratio=1.0 # Prevent upscaling + ) + + # Get random position for crop/pad + top, left = self.get_params(img, target_hw) + + # Apply crop or pad + return crop_or_pad( + img, + top=top, + left=left, + height=target_hw[0], + width=target_hw[1], + fill=self.fill, + padding_mode=self.padding_mode, + ) + + def __repr__(self) -> str: + return (f"{self.__class__.__name__}(patch_size={self.patch_size}, " + f"max_sequence_len={self.max_sequence_len}, " + f"divisible_by_patch={self.divisible_by_patch})") + + +def _validate_range(value, name, length=2): + # Validate type and length + if not isinstance(value, Sequence) or len(value) != length: + raise ValueError(f"{name} should be a sequence of length {length}.") + + # Validate order + if value[0] > value[1]: + warnings.warn(f"{name.capitalize()} range reversed. Swapping.") + return value[1], value[0] + + return value + + +class RandomResizedCropToSequence(torch.nn.Module): + """ + Randomly crop the input image to a subregion with varying area and aspect ratio + (relative to the original), then resize that crop to a target size. The target size + is determined such that patchifying the resized image (with `patch_size`) + does not exceed `max_seq_len` patches, while maintaining the aspect ratio of the crop. + + This combines aspects of torchvision's RandomResizedCrop with sequence length constraints. + + Args: + patch_size (int or tuple[int, int]): + Patch dimensions (patch_h, patch_w) for sequence length calculation. + max_seq_len (int): + Maximum number of patches allowed in the final image. + scale (tuple[float, float]): + Range (min, max) of area fraction of the original image to crop. 
+ ratio (tuple[float, float]): + Range (min, max) of aspect ratio *multipliers* for the crop, relative + to the original image's aspect ratio. E.g., (0.75, 1.333) means the + crop's aspect ratio will be sampled between 0.75*orig_ar and 1.333*orig_ar. + Uses log-uniform sampling. + interpolation (str or InterpolationMode): + Interpolation mode for resizing. Can be 'bilinear', 'bicubic', 'nearest', + or 'random' (chooses between bilinear and bicubic). + Defaults to 'bicubic'. + divisible_by_patch (bool): + If True, the final image height and width will be multiples of the + respective patch dimensions. Defaults to True. + max_ratio (float, optional): + An optional upper limit on the scaling ratio applied during resizing. + Prevents excessive upsampling of the initial crop. `max_ratio=1.0` + prevents any upsampling beyond the cropped size. Defaults to None (no limit). + final_scale_range (tuple[float, float], optional): + If provided, applies an *additional* random scaling factor to the + final target size. The factor is sampled uniformly from this range, + and multiplied by the size determined by `get_image_size_for_seq`. + E.g., (0.8, 1.0) means the final size will be between 80% and 100% + of the maximum feasible size. Defaults to None (use maximum feasible size). + attempts (int): + Number of attempts to sample a valid crop geometry before falling back + to a center crop strategy. Defaults to 10. + """ + + def __init__( + self, + patch_size: Union[int, Tuple[int, int]] = 16, + max_seq_len: int = 1024, + scale: Tuple[float, float] = (0.08, 1.0), + ratio: Tuple[float, float] = (.8, 1.25), + interpolation: Union[str, InterpolationMode] = 'bicubic', + divisible_by_patch: bool = True, + max_ratio: Optional[float] = None, + final_scale_range: Optional[Tuple[float, float]] = None, + attempts: int = 10, + ): + super().__init__() + if isinstance(patch_size, int): + self.patch_h, self.patch_w = patch_size, patch_size + else: + # Assume it's a tuple/list: (patch_h, patch_w) + if len(patch_size) != 2: + raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).") + self.patch_h, self.patch_w = patch_size + self.max_seq_len = max_seq_len + self.scale = scale + self.ratio = ratio + self.divisible_by_patch = divisible_by_patch + self.max_ratio = max_ratio + self.final_scale_range = final_scale_range + self.attempts = attempts + if isinstance(interpolation, str): + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = str_to_interp_mode(interpolation) + else: + self.interpolation = interpolation + + # Validate scale and ratio + self.scale = _validate_range(self.scale, "scale") + self.ratio = _validate_range(self.ratio, "ratio") + + # Validate final_scale_range if provided + if self.final_scale_range is not None: + self.final_scale_range = _validate_range(self.final_scale_range, "final_scale_range") + + # Additional validation for final_scale_range values + if not (0.0 <= self.final_scale_range[0] <= self.final_scale_range[1] <= 1.0): + warnings.warn("final_scale_range values should ideally be between 0.0 and 1.0.") + + @staticmethod + def get_params( + img: torch.Tensor, + scale: Tuple[float, float], + ratio: Tuple[float, float], + crop_attempts: int = 10, + patch_h: int = 16, + patch_w: int = 16, + max_seq_len: int = 1024, + divisible_by_patch: bool = True, + max_ratio: Optional[float] = None, + final_scale_range: Optional[Tuple[float, float]] = None, + interpolation: Union[List[InterpolationMode], InterpolationMode] = 
_RANDOM_INTERPOLATION, + ) -> Tuple[Tuple[int, int, int, int], Tuple[int, int], InterpolationMode]: + """ Get parameters for a random sized crop relative to image aspect ratio. + """ + _, height, width = F.get_dimensions(img) + if height <= 0 or width <= 0: + raise ValueError(f"Input image must have positive dimensions, got H={height}, W={width}") + + area = height * width + orig_aspect = width / height + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + + for _ in range(crop_attempts): + target_area = area * random.uniform(scale[0], scale[1]) + aspect_ratio_factor = math.exp(random.uniform(log_ratio[0], log_ratio[1])) + aspect_ratio = orig_aspect * aspect_ratio_factor + + # Calculate target dimensions for the crop + # target_area = crop_w * crop_h, aspect_ratio = crop_w / crop_h + # => crop_h = sqrt(target_area / aspect_ratio) + # => crop_w = sqrt(target_area * aspect_ratio) + crop_h = int(round(math.sqrt(target_area / aspect_ratio))) + crop_w = int(round(math.sqrt(target_area * aspect_ratio))) + + if 0 < crop_w <= width and 0 < crop_h <= height: + top = random.randint(0, height - crop_h) + left = random.randint(0, width - crop_w) + break + else: + # Fallback strategy, use center crop trying to respect ratio range + min_aspect_ratio = orig_aspect * ratio[0] + max_aspect_ratio = orig_aspect * ratio[1] + + if orig_aspect < min_aspect_ratio: + # Original is narrower than target min, clamp width + crop_w = width + crop_h = min(int(round(crop_w / min_aspect_ratio)), height) + elif orig_aspect > max_aspect_ratio: + # Original is wider than target max, clamp height + crop_h = height + crop_w = min(int(round(crop_h * max_aspect_ratio)), width) + else: + # Aspect ratio is within range, take the largest possible crop (full image) + crop_w = width + crop_h = height + + # Ensure valid dimensions after fallback calculation + crop_h = max(1, crop_h) + crop_w = max(1, crop_w) + + top = (height - crop_h) // 2 + left = (width - crop_w) // 2 + + # Determine max feasible size for scaling of the *cropped* region + feasible_ratio, feasible_size = get_image_size_for_seq( + (crop_h, crop_w), + patch_size=(patch_h, patch_w), # Pass as tuple + max_seq_len=max_seq_len, + divisible_by_patch=divisible_by_patch, + max_ratio=max_ratio, + ) + + # Optionally apply final scale randomization + final_size = feasible_size + if final_scale_range is not None: + min_sc, max_sc = final_scale_range + scale_factor = random.uniform(min_sc, max_sc) + scale_factor = min(max(scale_factor, 0.0), 1.0) # Clamp factor just in case + + # Calculate raw scaled size + # Note: feasible_ratio already accounts for max_ratio clamp if any + raw_h = crop_h * feasible_ratio * scale_factor + raw_w = crop_w * feasible_ratio * scale_factor + + # Re-apply divisibility constraint if needed + if divisible_by_patch: + # Use ceil to avoid going under minimum patch size + target_h = patch_h * math.ceil(raw_h / patch_h) + target_w = patch_w * math.ceil(raw_w / patch_w) + else: + target_h = int(round(raw_h)) + target_w = int(round(raw_w)) + + # Ensure final size is at least one patch dimension + target_h = max(target_h, patch_h) + target_w = max(target_w, patch_w) + final_size = (target_h, target_w) + + # Final check: Ensure this randomized size still fits max_seq_len + # (It should, as we scaled down, but rounding might theoretically push it over) + num_patches_h = final_size[0] // patch_h + num_patches_w = final_size[1] // patch_w + if (num_patches_h * num_patches_w) > max_seq_len: + # If it exceeds, revert to the original feasible_size (safest) + 
final_size = feasible_size + warnings.warn(f"Final scale randomization ({scale_factor:.2f}) resulted in size {final_size} exceeding max_seq_len={max_seq_len} after rounding. Reverting to feasible size {feasible_size}.") + + # Select interpolation mode + if isinstance(interpolation, (tuple, list)): + interpolation = random.choice(interpolation) + else: + interpolation = interpolation + + return (top, left, crop_h, crop_w), final_size, interpolation + + def forward(self, img: torch.Tensor) -> torch.Tensor: + # Sample crop, resize, and interpolation parameters + crop_params, final_size, interpolation = self.get_params( + img, + scale=self.scale, + ratio=self.ratio, + crop_attempts=self.attempts, + patch_h=self.patch_h, + patch_w=self.patch_w, + divisible_by_patch=self.divisible_by_patch, + max_seq_len=self.max_seq_len, + final_scale_range=self.final_scale_range, + interpolation=self.interpolation, + ) + top, left, crop_h, crop_w = crop_params + + output = F.resized_crop( + img, + top=top, + left=left, + height=crop_h, + width=crop_w, + size=final_size, + interpolation=interpolation, + antialias=True, + ) + + return output + + def __repr__(self) -> str: + if isinstance(self.interpolation, (tuple, list)): + interpolate_str = ', '.join(str(m).split('.')[-1] for m in self.interpolation) + else: + interpolate_str = str(self.interpolation) + format_string = self.__class__.__name__ + '(' + format_string += f"patch_size=({self.patch_h}, {self.patch_w})" + format_string += f", max_seq_len={self.max_seq_len}" + format_string += f", scale={self.scale}" + format_string += f", ratio={self.ratio}" + format_string += f", interpolation=[{interpolate_str}]" + format_string += f", divisible_by_patch={self.divisible_by_patch}" + format_string += f", max_ratio={self.max_ratio}" + format_string += f", final_scale_range={self.final_scale_range}" + format_string += f", attempts={self.attempts}" + format_string += ')' + return format_string + + +def patchify( + img: torch.Tensor, + patch_size: Tuple[int, int], + pad: bool = True, + include_info: bool = True, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + c, h, w = img.shape + ph, pw = patch_size + + # Ensure the image is divisible by patch size + if pad and (h % ph != 0 or w % pw != 0): + new_h = math.ceil(h / ph) * ph + new_w = math.ceil(w / pw) * pw + padded_img = torch.zeros(c, new_h, new_w, dtype=img.dtype) + padded_img[:, :h, :w] = img + img = padded_img + c, h, w = img.shape + + # Calculate number of patches in each dimension + nh, nw = h // ph, w // pw + # Reshape image to patches [nh, nw, ph, pw, c] + patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0).reshape(nh * nw, ph * pw * c) + if include_info: + # Create coordinate indices + y_idx, x_idx = torch.meshgrid(torch.arange(nh), torch.arange(nw), indexing='ij') + # Stack into a single coords tensor [N, 2] with (y, x) order + coord = torch.stack([y_idx.reshape(-1), x_idx.reshape(-1)], dim=1) + # Create type indicators (all 1s for regular patches) + valid = torch.ones(nh * nw, dtype=torch.bool) + return patches, coord, valid + + return patches + + +class Patchify(torch.nn.Module): + """Transform an image into patches with corresponding coordinates and type indicators.""" + + def __init__(self, patch_size): + super().__init__() + self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size) + + def forward(self, img): + """ + Args: + img: A PIL Image or tensor of shape [C, H, W] + + Returns: + A dictionary containing: + - patches: Tensor of 
shape [N, P*P*C] where N is the number of patches + - patch_coord: Tensor of shape [N, 2] with (y, x) coordinates + - patch_valid: Valid indicator (all 1s for non-padding patches) + """ + if isinstance(img, Image.Image): + # Convert PIL Image to tensor [C, H, W] + img = transforms.functional.to_tensor(img) + + patches, coord, valid = patchify(img, self.patch_size) + + return { + 'patches': patches, + 'patch_coord': coord, + 'patch_valid': valid, + } diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index 9be0e3bf3c..ed42745688 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -12,7 +12,8 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \ - ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor + ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, MaybeToTensor, MaybePILToTensor +from timm.data.naflex_transforms import RandomResizedCropToSequence, ResizeToSequence, Patchify from timm.data.random_erasing import RandomErasing @@ -46,7 +47,7 @@ def transforms_noaug_train( ] if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm - tfl += [ToNumpy()] + tfl += [MaybePILToTensor()] elif not normalize: # when normalize disabled, converted to tensor without scaling, keep original dtype tfl += [MaybePILToTensor()] @@ -84,6 +85,10 @@ def transforms_imagenet_train( use_prefetcher: bool = False, normalize: bool = True, separate: bool = False, + naflex: bool = False, + patch_size: Union[int, Tuple[int, int]] = 16, + max_seq_len: int = 576, # 24x24 for 16x16 patch + patchify: bool = False, ): """ ImageNet-oriented image transforms for training. @@ -111,6 +116,9 @@ def transforms_imagenet_train( use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used). separate: Output transforms in 3-stage tuple. + naflex: Enable NaFlex mode, sequence constrained patch output + patch_size: Patch size for NaFlex mode. + max_seq_len: Max sequence length for NaFlex mode. Returns: If separate==True, the transforms are returned as a tuple of 3 separate transforms @@ -121,35 +129,45 @@ def transforms_imagenet_train( """ train_crop_mode = train_crop_mode or 'rrc' assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'} - if train_crop_mode in ('rkrc', 'rkrr'): - # FIXME integration of RKR is a WIP - scale = tuple(scale or (0.8, 1.00)) - ratio = tuple(ratio or (0.9, 1/.9)) - primary_tfl = [ - ResizeKeepRatio( - img_size, - interpolation=interpolation, - random_scale_prob=0.5, - random_scale_range=scale, - random_scale_area=True, # scale compatible with RRC - random_aspect_prob=0.5, - random_aspect_range=ratio, - ), - CenterCropOrPad(img_size, padding_mode='reflect') - if train_crop_mode == 'rkrc' else - RandomCropOrPad(img_size, padding_mode='reflect') - ] + + primary_tfl = [] + if naflex: + primary_tfl += [RandomResizedCropToSequence( + patch_size=patch_size, + max_seq_len=max_seq_len, + interpolation=interpolation + )] else: - scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range - ratio = tuple(ratio or (3. / 4., 4. 
/ 3.)) # default imagenet ratio range - primary_tfl = [ - RandomResizedCropAndInterpolation( - img_size, - scale=scale, - ratio=ratio, - interpolation=interpolation, - ) - ] + if train_crop_mode in ('rkrc', 'rkrr'): + # FIXME integration of RKR is a WIP + scale = tuple(scale or (0.8, 1.00)) + ratio = tuple(ratio or (0.9, 1/.9)) + primary_tfl += [ + ResizeKeepRatio( + img_size, + interpolation=interpolation, + random_scale_prob=0.5, + random_scale_range=scale, + random_scale_area=True, # scale compatible with RRC + random_aspect_prob=0.5, + random_aspect_range=ratio, + ), + CenterCropOrPad(img_size, padding_mode='reflect') + if train_crop_mode == 'rkrc' else + RandomCropOrPad(img_size, padding_mode='reflect') + ] + else: + scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range + ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range + primary_tfl += [ + RandomResizedCropAndInterpolation( + img_size, + scale=scale, + ratio=ratio, + interpolation=interpolation, + ) + ] + if hflip > 0.: primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)] if vflip > 0.: @@ -215,7 +233,7 @@ def transforms_imagenet_train( final_tfl = [] if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm - final_tfl += [ToNumpy()] + final_tfl += [MaybePILToTensor()] elif not normalize: # when normalize disable, converted to tensor without scaling, keeps original dtype final_tfl += [MaybePILToTensor()] @@ -238,6 +256,9 @@ def transforms_imagenet_train( ) ] + if patchify: + final_tfl += [Patchify(patch_size=patch_size)] + if separate: return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) else: @@ -254,6 +275,10 @@ def transforms_imagenet_eval( std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, use_prefetcher: bool = False, normalize: bool = True, + naflex: bool = False, + patch_size: Union[int, Tuple[int, int]] = 16, + max_seq_len: int = 576, # 24x24 for 16x16 patch + patchify: bool = False, ): """ ImageNet-oriented image transform for evaluation and inference. @@ -267,6 +292,10 @@ def transforms_imagenet_eval( std: Image normalization standard deviation. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used). + naflex: Enable NaFlex mode, sequence constrained patch output + patch_size: Patch size for NaFlex mode. + max_seq_len: Max sequence length for NaFlex mode. 
+ patchify: Patchify the output instead of relying on prefetcher Returns: Composed transform pipeline @@ -285,37 +314,44 @@ def transforms_imagenet_eval( if crop_border_pixels: tfl += [TrimBorder(crop_border_pixels)] - if crop_mode == 'squash': - # squash mode scales each edge to 1/pct of target, then crops - # aspect ratio is not preserved, no img lost if crop_pct == 1.0 - tfl += [ - transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)), - transforms.CenterCrop(img_size), - ] - elif crop_mode == 'border': - # scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop - # no image lost if crop_pct == 1.0 - fill = [round(255 * v) for v in mean] - tfl += [ - ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0), - CenterCropOrPad(img_size, fill=fill), - ] + if naflex: + tfl += [ResizeToSequence( + patch_size=patch_size, + max_seq_len=max_seq_len, + interpolation=interpolation, + )] else: - # default crop model is center - # aspect ratio is preserved, crops center within image, no borders are added, image is lost - if scale_size[0] == scale_size[1]: - # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg) + if crop_mode == 'squash': + # squash mode scales each edge to 1/pct of target, then crops + # aspect ratio is not preserved, no img lost if crop_pct == 1.0 tfl += [ - transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation)) + transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)), + transforms.CenterCrop(img_size), + ] + elif crop_mode == 'border': + # scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop + # no image lost if crop_pct == 1.0 + fill = [round(255 * v) for v in mean] + tfl += [ + ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0), + CenterCropOrPad(img_size, fill=fill), ] else: - # resize the shortest edge to matching target dim for non-square target - tfl += [ResizeKeepRatio(scale_size)] - tfl += [transforms.CenterCrop(img_size)] + # default crop model is center + # aspect ratio is preserved, crops center within image, no borders are added, image is lost + if scale_size[0] == scale_size[1]: + # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg) + tfl += [ + transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation)) + ] + else: + # resize the shortest edge to matching target dim for non-square target + tfl += [ResizeKeepRatio(scale_size)] + tfl += [transforms.CenterCrop(img_size)] if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm - tfl += [ToNumpy()] + tfl += [MaybePILToTensor()] elif not normalize: # when normalize disabled, converted to tensor without scaling, keeps original dtype tfl += [MaybePILToTensor()] @@ -328,6 +364,9 @@ def transforms_imagenet_eval( ), ] + if patchify: + tfl += [Patchify(patch_size=patch_size)] + return transforms.Compose(tfl) @@ -359,6 +398,10 @@ def create_transform( use_prefetcher: bool = False, normalize: bool = True, separate: bool = False, + naflex: bool = False, + patch_size: Union[int, Tuple[int, int]] = 16, + max_seq_len: int = 576, # 24x24 for 16x16 patch + patchify: bool = False ): """ @@ -442,6 +485,10 @@ def create_transform( use_prefetcher=use_prefetcher, normalize=normalize, separate=separate, + naflex=naflex, + patch_size=patch_size, + max_seq_len=max_seq_len, + patchify=patchify, ) else: assert not separate, "Separate transforms not supported for 
validation preprocessing" @@ -455,6 +502,10 @@ def create_transform( crop_border_pixels=crop_border_pixels, use_prefetcher=use_prefetcher, normalize=normalize, + naflex=naflex, + patch_size=patch_size, + max_seq_len=max_seq_len, + patchify=patchify, ) return transform diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index c71ff30c82..ada5b6656a 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -1,6 +1,7 @@ from .activations import * from .adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .attention import Attention from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2 from .attention_pool import AttentionPoolLatent from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding diff --git a/timm/layers/attention.py b/timm/layers/attention.py new file mode 100644 index 0000000000..18113357d1 --- /dev/null +++ b/timm/layers/attention.py @@ -0,0 +1,66 @@ +from typing import Final, Type, Optional + +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .config import use_fused_attn + + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_bias: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: Type[nn.Module] = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if attn_mask is not None: + attn = attn + attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/timm/layers/attention_pool.py b/timm/layers/attention_pool.py index 2e87566ad4..d050e7ca2c 100644 --- a/timm/layers/attention_pool.py +++ b/timm/layers/attention_pool.py @@ -75,7 +75,7 @@ def init_weights(self): trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5) - def forward(self, x): + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): B, N, C = x.shape if self.pos_embed is not None: @@ -91,10 +91,12 @@ def forward(self, x): q, k = self.q_norm(q), self.k_norm(k) if self.fused_attn: - x = F.scaled_dot_product_attention(q, k, v) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) else: q = q * self.scale attn = q @ k.transpose(-2, -1) + if attn_mask 
is not None: + attn = attn + attn_mask attn = attn.softmax(dim=-1) x = attn @ v x = x.transpose(1, 2).reshape(B, self.latent_len, C) diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py index c739291b32..519bb30caa 100644 --- a/timm/layers/patch_embed.py +++ b/timm/layers/patch_embed.py @@ -10,7 +10,7 @@ """ import logging import math -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torch import nn as nn @@ -180,7 +180,8 @@ def forward(self, x) -> Tuple[torch.Tensor, List[int]]: return x, feat_size -def resample_patch_embed( +# FIXME to remove, keeping for comparison for now +def resample_patch_embed_old( patch_embed, new_size: List[int], interpolation: str = 'bicubic', @@ -250,6 +251,191 @@ def resample_kernel(kernel): return patch_embed +DTYPE_INTERMEDIATE = torch.float32 + + +def _compute_resize_matrix( + old_size: Tuple[int, int], + new_size: Tuple[int, int], + interpolation: str, + antialias: bool, + device: torch.device, + dtype: torch.dtype = DTYPE_INTERMEDIATE +) -> torch.Tensor: + """Computes the resize matrix basis vectors and interpolates them to new_size.""" + old_h, old_w = old_size + new_h, new_w = new_size + old_total = old_h * old_w + new_total = new_h * new_w + + eye_matrix = torch.eye(old_total, device=device, dtype=dtype) + basis_vectors_batch = eye_matrix.reshape(old_total, 1, old_h, old_w) + + resized_basis_vectors_batch = F.interpolate( + basis_vectors_batch, + size=new_size, + mode=interpolation, + antialias=antialias, + align_corners=False + ) # Output shape: (old_total, 1, new_h, new_w) + + resize_matrix = resized_basis_vectors_batch.squeeze(1).reshape(old_total, new_total).T + return resize_matrix # Shape: (new_total, old_total) + + +def _compute_pinv_for_resampling(resize_matrix: torch.Tensor) -> torch.Tensor: + """Calculates the pseudoinverse matrix used for the resampling operation.""" + pinv_matrix = torch.linalg.pinv(resize_matrix.T) # Shape: (new_total, old_total) + return pinv_matrix + + +def _apply_resampling( + patch_embed: torch.Tensor, + pinv_matrix: torch.Tensor, + new_size_tuple: Tuple[int, int], + orig_dtype: torch.dtype, + intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE +) -> torch.Tensor: + """Applies the precomputed pinv_matrix to resample the patch_embed tensor.""" + try: + from torch import vmap + except ImportError: + from functorch import vmap + + def resample_kernel(kernel: torch.Tensor) -> torch.Tensor: + kernel_flat = kernel.reshape(-1).to(intermediate_dtype) + resampled_kernel_flat = pinv_matrix @ kernel_flat + return resampled_kernel_flat.reshape(new_size_tuple) + + resample_kernel_vmap = vmap(vmap(resample_kernel, in_dims=0, out_dims=0), in_dims=0, out_dims=0) + patch_embed_float = patch_embed.to(intermediate_dtype) + resampled_patch_embed = resample_kernel_vmap(patch_embed_float) + return resampled_patch_embed.to(orig_dtype) + + +def resample_patch_embed( + patch_embed: torch.Tensor, + new_size: List[int], + interpolation: str = 'bicubic', + antialias: bool = True, + verbose: bool = False, +): + """ Standalone function (computes matrix on each call). 
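+
+ Resamples a 4D patch embedding kernel (out_ch, in_ch, h, w) to a new spatial size by applying the
+ pseudoinverse of the interpolation's resize matrix to each flattened kernel (see _compute_resize_matrix
+ and _apply_resampling above).
+
+ Example (illustrative only, assuming ViT-B/16 shaped weights):
+     w16 = torch.randn(768, 3, 16, 16)
+     w8 = resample_patch_embed(w16, [8, 8])  # -> (768, 3, 8, 8), original dtype preserved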
""" + assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_ch, in_ch, h, w)" + assert len(new_size) == 2, "New shape should only be hw (height, width)" + + old_size_tuple: Tuple[int, int] = tuple(patch_embed.shape[-2:]) + new_size_tuple: Tuple[int, int] = tuple(new_size) + + if old_size_tuple == new_size_tuple: + return patch_embed + + device = patch_embed.device + orig_dtype = patch_embed.dtype + + resize_mat = _compute_resize_matrix( + old_size_tuple, new_size_tuple, interpolation, antialias, device, DTYPE_INTERMEDIATE + ) + pinv_matrix = _compute_pinv_for_resampling(resize_mat) + resampled_patch_embed = _apply_resampling( + patch_embed, pinv_matrix, new_size_tuple, orig_dtype, DTYPE_INTERMEDIATE + ) + return resampled_patch_embed + + +class PatchEmbedResamplerFixedOrigSize(nn.Module): + """ + Resample patch embedding weights from a fixed original size, + caching the pseudoinverse matrix based on the target size. + """ + def __init__( + self, + orig_size: Tuple[int, int], + interpolation: str = 'bicubic', + antialias: bool = True + ): + """ + Args: + orig_size (Tuple[int, int]): The expected original (height, width) of input patch_embed tensors. + interpolation (str): Interpolation mode. + antialias (bool): Use anti-aliasing filter in resize. + """ + super().__init__() + assert isinstance(orig_size, tuple) and len(orig_size) == 2, \ + "`orig_size` must be a tuple of (height, width)" + self.orig_size = orig_size # expected original size + self.interpolation = interpolation + self.antialias = antialias + # Cache map key is the target new_size tuple + self._pinv_cache_map: Dict[Tuple[int, int], str] = {} + + def _get_or_create_pinv_matrix( + self, + new_size: Tuple[int, int], + device: torch.device, + dtype: torch.dtype = DTYPE_INTERMEDIATE + ) -> torch.Tensor: + """Retrieves the cached pinv matrix or computes and caches it for the given new_size.""" + cache_key = new_size + buffer_name = self._pinv_cache_map.get(cache_key) + + if buffer_name and hasattr(self, buffer_name): + pinv_matrix = getattr(self, buffer_name) + if pinv_matrix.device == device and pinv_matrix.dtype == dtype: + return pinv_matrix + + # Calculate the matrix if not cached or needs update + resize_mat = _compute_resize_matrix( + self.orig_size, new_size, self.interpolation, self.antialias, device, dtype + ) + pinv_matrix = _compute_pinv_for_resampling(resize_mat) + + # Cache using register_buffer + buffer_name = f"pinv_{new_size[0]}x{new_size[1]}" + if hasattr(self, buffer_name): + delattr(self, buffer_name) + self.register_buffer(buffer_name, pinv_matrix) + self._pinv_cache_map[cache_key] = buffer_name # Map new_size key to buffer name + + return pinv_matrix + + def forward(self, patch_embed: torch.Tensor, new_size: List[int]) -> torch.Tensor: + """ Resamples the patch embedding weights to new_size. + + Args: + patch_embed (torch.Tensor): Original weights (out_ch, in_ch, H_orig, W_orig). + new_size (List[int]): Target [height, width]. + + Returns: + torch.Tensor: Resampled weights. 
+ """ + assert len(patch_embed.shape) == 4 + assert len(new_size) == 2 + + # Input Validation + input_size = tuple(patch_embed.shape[-2:]) + assert input_size == self.orig_size, \ + f"Input patch_embed spatial size {input_size} does not match " \ + f"module's expected original size {self.orig_size}" + + new_size_tuple: Tuple[int, int] = tuple(new_size) + + # Check no-op case against self.orig_size + if self.orig_size == new_size_tuple: + return patch_embed + + device = patch_embed.device + orig_dtype = patch_embed.dtype + + # Get or compute the required pseudoinverse matrix + pinv_matrix = self._get_or_create_pinv_matrix(new_size_tuple, device) + + # Apply the resampling + resampled_patch_embed = _apply_resampling(patch_embed, pinv_matrix, new_size_tuple, orig_dtype) + + return resampled_patch_embed + + # def divs(n, m=None): # m = m or n // 2 # if m == 1: diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 238c1ccca5..61e2ac7b7e 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -72,6 +72,7 @@ from .visformer import * from .vision_transformer import * from .vision_transformer_hybrid import * +from .vision_transformer_flex import * from .vision_transformer_relpos import * from .vision_transformer_sam import * from .vitamin import * diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 3c7b9a2277..c870e6c222 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -41,8 +41,8 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \ - trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ +from timm.layers import Attention, PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, \ + SwiGLUPacked, SwiGLU, trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ get_act_layer, get_norm_layer, LayerType from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -55,58 +55,6 @@ _logger = logging.getLogger(__name__) -class Attention(nn.Module): - fused_attn: Final[bool] - - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = False, - proj_bias: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: Type[nn.Module] = nn.LayerNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - self.fused_attn = use_fused_attn() - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim, bias=proj_bias) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) - q, k = self.q_norm(q), self.k_norm(k) - - if self.fused_attn: - x = F.scaled_dot_product_attention( - q, k, v, - dropout_p=self.attn_drop.p if self.training else 0., - ) - else: - q = q * self.scale - attn = q @ 
k.transpose(-2, -1) - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = attn @ v - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - class LayerScale(nn.Module): def __init__( self, @@ -165,8 +113,8 @@ def __init__( self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x @@ -222,8 +170,8 @@ def init_weights(self) -> None: nn.init.constant_(self.norm1.weight, self.init_values) nn.init.constant_(self.norm2.weight, self.init_values) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.drop_path1(self.norm1(self.attn(x))) + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + x = x + self.drop_path1(self.norm1(self.attn(x, attn_mask=attn_mask))) x = x + self.drop_path2(self.norm2(self.mlp(x))) return x @@ -282,7 +230,7 @@ def __init__( self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: B, N, C = x.shape # Combined MLP fc1 & qkv projections @@ -302,14 +250,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.fused_attn: x_attn = F.scaled_dot_product_attention( q, k, v, + attn_mask=attn_mask, dropout_p=self.attn_drop.p if self.training else 0., ) else: q = q * self.scale attn = q @ k.transpose(-2, -1) + if attn_mask is not None: + attn = attn + attn_mask attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x_attn = attn @ v + x_attn = x_attn.transpose(1, 2).reshape(B, N, C) x_attn = self.attn_out_proj(x_attn) @@ -379,23 +331,21 @@ def __init__( ('drop_path', DropPath(drop_path) if drop_path > 0. 
else nn.Identity()) ]))) - def _forward_jit(self, x: torch.Tensor) -> torch.Tensor: - x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if attn_mask is not None: + attn_out = [] + for attn in self.attns: + x_attn = attn.norm(x) + x_attn = attn.attn(x_attn, attn_mask=attn_mask) + x_attn = attn.ls(x_attn) + x_attn = attn.drop_path(x_attn) + attn_out.append(x_attn) + x = x + torch.stack(attn_out).sum(dim=0) + else: + x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) return x - @torch.jit.ignore - def _forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + sum(attn(x) for attn in self.attns) - x = x + sum(ffn(x) for ffn in self.ffns) - return x - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return self._forward_jit(x) - else: - return self._forward(x) - def global_pool_nlc( x: torch.Tensor, @@ -728,7 +678,9 @@ def forward_intermediates( stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, - ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + output_dict: bool = False, + attn_mask: Optional[torch.Tensor] = None, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]: """ Forward features that returns intermediates. Args: @@ -739,8 +691,11 @@ def forward_intermediates( stop_early: Stop iterating over blocks when last desired intermediate hit output_fmt: Shape of intermediate feature outputs intermediates_only: Only return intermediate features + output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys + attn_mask: Optional attention mask for masked attention (e.g., for NaFlex) Returns: - + A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing + 'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix') """ assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' 
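+ # Illustrative note: with output_dict=True the return value is a dictionary of the form
+ #   {'image_intermediates': [Tensor, ...],           # always present
+ #    'image_intermediates_prefix': [Tensor, ...],    # only when return_prefix_tokens and the model has prefix tokens
+ #    'image_features': Tensor}                       # only when intermediates_only is False
+ # attn_mask, when provided, is forwarded to every block (e.g. the additive padding mask used by the NaFlex models).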
reshape = output_fmt == 'NCHW' @@ -759,7 +714,7 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] for i, blk in enumerate(blocks): - x = blk(x) + x = blk(x, attn_mask=attn_mask) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) @@ -776,6 +731,23 @@ def forward_intermediates( # reshape to BCHW output format H, W = self.patch_embed.dynamic_feat_size((height, width)) intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + + # For dictionary output, handle prefix tokens separately + if output_dict: + result_dict = {} + # Intermediates are always included + result_dict['image_intermediates'] = intermediates + if prefix_tokens is not None and return_prefix_tokens: + result_dict['image_intermediates_prefix'] = prefix_tokens + + # Only include features if not intermediates_only + if not intermediates_only: + x_final = self.norm(x) + result_dict['image_features'] = x_final + + return result_dict + + # For non-dictionary output, maintain the original behavior if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None: # return_prefix not support in torchscript due to poor type handling intermediates = list(zip(intermediates, prefix_tokens)) @@ -811,6 +783,7 @@ def get_intermediate_layers( reshape: bool = False, return_prefix_tokens: bool = False, norm: bool = False, + attn_mask: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: """ Intermediate layer accessor inspired by DINO / DINOv2 interface. NOTE: This API is for backwards compat, favour using forward_intermediates() directly. @@ -821,17 +794,24 @@ def get_intermediate_layers( norm=norm, output_fmt='NCHW' if reshape else 'NLC', intermediates_only=True, + attn_mask=attn_mask, ) - def forward_features(self, x: torch.Tensor) -> torch.Tensor: + def forward_features(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: x = self.patch_embed(x) x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) - if self.grad_checkpointing and not torch.jit.is_scripting(): + + if attn_mask is not None: + # If mask provided, we need to apply blocks one by one + for blk in self.blocks: + x = blk(x, attn_mask=attn_mask) + elif self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) + x = self.norm(x) return x @@ -849,8 +829,8 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tenso x = self.head_drop(x) return x if pre_logits else self.head(x) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.forward_features(x) + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + x = self.forward_features(x, attn_mask=attn_mask) x = self.forward_head(x) return x diff --git a/timm/models/vision_transformer_flex.py b/timm/models/vision_transformer_flex.py new file mode 100644 index 0000000000..e3a8d8efc3 --- /dev/null +++ b/timm/models/vision_transformer_flex.py @@ -0,0 +1,1112 @@ +""" Vision Transformer (New) + +An improved version of the Vision Transformer with: +1. Encapsulated embedding and position encoding in a single module +2. Support for linear patch embedding on pre-patchified inputs +3. 
Support for NaFlex functionality (NaViT + FlexiViT) + +Based on: +- Original Vision Transformer: https://arxiv.org/abs/2010.11929 +- FlexiViT: https://arxiv.org/abs/2212.08013 +- NaViT: https://arxiv.org/abs/2307.06304 + +Copyright 2025 +""" + +import logging +import math +from collections import OrderedDict +from functools import partial +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Final, Any, Literal + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.layers import AttentionPoolLatent, Mlp, to_2tuple, get_act_layer, get_norm_layer, LayerType, _assert +from timm.models._builder import build_model_with_cfg +from timm.models._features import feature_take_indices +from timm.models._features_fx import register_notrace_function, register_notrace_module +from timm.models._registry import register_model, generate_default_cfgs +from timm.models._manipulate import checkpoint_seq, named_apply + +from .vision_transformer import Block, global_pool_nlc + +_logger = logging.getLogger(__name__) + + +def batch_patchify( + x: torch.Tensor, + patch_size: Tuple[int, int], + pad: bool = True, +) -> Tuple[torch.Tensor, Tuple[int, int]]: + B, C, H, W = x.shape + ph, pw = patch_size + + # Ensure the image is divisible by patch size + if pad and (H % ph != 0 or W % pw != 0): + pad_h = (ph - H % ph) % ph + pad_w = (pw - W % pw) % pw + x = F.pad(x, (0, pad_w, 0, pad_h)) + + nh, nw = H // ph, W // pw + patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C) + # FIXME confirm we want 'channels last' in the patch channel layout, egg ph, ph, C instead of C, ph, hw + + return patches, (nh, nw) + + +@register_notrace_module +class FlexEmbeds(nn.Module): + """ Na(Flex) Embedding module for Vision Transformers + + This module encapsulates the complete embedding process for Vision Transformers: + 1. Patch embedding (via Conv2d or Linear) + 2. Class and register token preparation + 3. Position embedding addition + 4. Pre-normalization (if requested) + 5. Dropout application + + Also supports NaFlex functionality (NaViT + FlexiViT): + - Variable aspect ratio and resolution via patch coordinates + - Patch type indicators for handling padding tokens in attention + - Flexible position embedding interpolation for arbitrary grid sizes + + Note: Only supports non-overlapping position and register tokens + (i.e., position embeddings do not include class/register tokens) + + The patch embedding can be one of two types: + 1. Conv2d-based (default): For standard image inputs [B, C, H, W] + 2. 
Linear-based: For pre-patchified inputs [B, N, P*P*C] + + """ + + def __init__( + self, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + embed_layer: Optional[str] = None, # 'conv' or 'linear', default is 'linear' + input_norm_layer: Optional[Type[nn.Module]] = None, + proj_norm_layer: Optional[Type[nn.Module]] = None, + final_norm_layer: Optional[Type[nn.Module]] = None, + pos_embed: str = 'learned', + pos_drop_rate: float = 0., + patch_drop_rate: float = 0., + class_token: bool = True, + reg_tokens: int = 0, + bias: bool = True, + dynamic_img_pad: bool = False, + pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14), + pos_embed_interp_mode: str = 'bicubic', + default_img_size: Union[int, Tuple[int, int]] = None, + ): + super().__init__() + self.has_class_token = class_token + self.num_reg_tokens = reg_tokens + self.pos_embed_interp_mode = pos_embed_interp_mode + self.patch_size = to_2tuple(patch_size) + self.in_chans = in_chans + self.embed_dim = embed_dim + self.dynamic_img_pad = dynamic_img_pad + + # Calculate number of prefix tokens + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + + # Create class and register tokens + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + + # Calculate grid size and number of patches + self.default_img_size: Optional[Tuple[int, int]] = None + self.pos_embed_grid_size: Optional[ + Tuple[int, int]] = None # Stores the grid size used for learned pos embed init + if pos_embed_grid_size is None and default_img_size is not None: + self.default_img_size = to_2tuple(default_img_size) + self.pos_embed_grid_size = tuple([s // p for s, p in zip(self.default_img_size, self.patch_size)]) + elif pos_embed_grid_size is not None: + # Use provided pos_embed_grid_size for NaFlex mode + self.pos_embed_grid_size = pos_embed_grid_size + + # Determine patch embedding type (linear or conv2d) + if embed_layer == 'linear': + # Create linear projection for pre-patchified inputs + # Input dimension is patch_size^2 * in_chans + patch_dim = self.patch_size[0] * self.patch_size[1] * in_chans + self.norm_input = proj_norm_layer(patch_dim) if input_norm_layer else None + self.proj = nn.Linear(patch_dim, embed_dim, bias=bias) + self.flatten = False + self.is_linear = True + else: + # Default to convolutional patch embedding for image inputs + assert input_norm_layer is None + self.norm_input = None + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias + ) + self.flatten = True + self.is_linear = False + + # Create normalization layer after the projection + self.norm_proj = proj_norm_layer(embed_dim) if proj_norm_layer else nn.Identity() + + # Create position embedding if needed - only for patches, never for prefix tokens + if not pos_embed or pos_embed == 'none': + self.pos_embed = None + self.pos_embed_type = 'none' + elif pos_embed == 'rope': + self.pos_embed = None + self.pos_embed_type = 'rope' + # Rotary embeddings will be computed on-the-fly in the forward pass + else: + # Store position embedding in (1, H, W, dim) format for easier resizing + if self.pos_embed_grid_size is not None: + h, w = self.pos_embed_grid_size + self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02) + self.pos_embed_type = 'learned' + else: + raise ValueError("Cannot initialize position embeddings without grid_size. 
" + "Please provide img_size or pos_embed_grid_size") + + # Pre-normalization layer (separate from the main norm layer) + self.norm_final = final_norm_layer(embed_dim) if final_norm_layer is not None else nn.Identity() + + # Dropout layers + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + from timm.layers.patch_dropout import PatchDropout + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + + def feature_info(self, location): + """Feature info utility method for feature extraction.""" + return dict(num_chs=self.embed_dim, reduction=self.patch_size) + + def feat_ratio(self, as_scalar=True): + """Return the feature reduction ratio (stride).""" + if as_scalar: + return max(self.patch_size) + else: + return self.patch_size + + def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: + """ Get grid (feature) size for given image size taking account of dynamic padding. + """ + if self.dynamic_img_pad: + return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1]) + else: + return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1] + + def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None): + """Forward pass for combined embedding + + Args: + x: Input tensor [B, C, H, W] or pre-patchified [B, N, P*P*C] + patch_coord: Optional patch coordinates [B, N, 2] for NaFlex + + Returns: + Embedded tensor with position encoding and class/register tokens applied + If patch_type is provided, also returns attention mask + """ + # Apply patch embedding + naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None + grid_size: Optional[List[int]] = None + + B = x.shape[0] + if self.is_linear: + # Linear embedding path, works with NaFlex mode or standard 2D mode + if patch_coord is not None: + _assert(x.ndim == 3, 'Expecting patchified input with ndim == 3') + # Pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches + # Calculate the appropriate grid size from coords + max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1 + max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1 + naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)] + else: + _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4') + x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad) + + if self.norm_input is not None: + x = self.norm_input(x) + + x = self.proj(x) + else: + assert x.ndim == 4, 'Convolutional input must be 4D' + if self.dynamic_img_pad: + H, W = x.shape[-2:] + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) + + x = self.proj(x) + + grid_size = x.shape[-2:] + if self.flatten: + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC + + # Apply normalization after flattening + x = self.norm_proj(x) + + if self.pos_embed_type == 'learned': + if naflex_grid_sizes is not None: + self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes) + else: + assert grid_size is not None + self._apply_learned_pos_embed(x, grid_size=grid_size) + elif self.pos_embed_type == 'rope': + assert False, "ROPE not yet implemented" + + # Prepare and add class and register tokens + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(B, -1, -1)) + if self.reg_token is not None: + 
to_cat.append(self.reg_token.expand(B, -1, -1)) + # Add tokens to the beginning + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + + # Apply final pre-transformer normalization if specified + x = self.norm_final(x) + + # Apply dropouts + x = self.patch_drop(self.pos_drop(x)) + + return x + + #@torch.compiler.disable() + def _apply_learned_naflex_pos_embed( + self, + x: torch.Tensor, + naflex_grid_sizes: List[Tuple[int, int]], + ): + # Handle each batch element separately with its own grid size + orig_h, orig_w = self.pos_embed.shape[1:3] + pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float() # B,C,H,W + + def _interp(_size): + if (_size[0] == orig_h) and (_size[1] == orig_w): + pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1) + else: + pos_embed_flat = F.interpolate( + pos_embed_nchw, + size=_size, + mode=self.pos_embed_interp_mode, + align_corners=False, + antialias=True, + ).flatten(2).transpose(1, 2) + return pos_embed_flat.to(dtype=x.dtype) + + # FIXME leaving alternative code commented here for now for comparisons + # pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {} + # for i, s in enumerate(naflex_grid_sizes): + # if s in pos_embed_cache: + # pos_embed_flat = pos_embed_cache[s] + # else: + # pos_embed_flat = _interp(s) + # pos_embed_cache[s] = pos_embed_flat + # + # seq_len = min(x.shape[1], pos_embed_flat.shape[1]) + # x[i, :seq_len] += pos_embed_flat[0, :seq_len] + + # Determine unique grid sizes + size_to_indices: Dict[Tuple[int, int], List[int]] = {} + for bi, k in enumerate(naflex_grid_sizes): + # k = h << 16 | w # FIXME can get jit compat with this + size_to_indices.setdefault(k, []).append(bi) + + for k, batch_indices in size_to_indices.items(): + # h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this + # Interpolate only once for this (h, w) + pos_embed_flat = _interp(k) + seq_len = min(x.shape[1], pos_embed_flat.shape[1]) + x[:, :seq_len].index_add_( + 0, + torch.as_tensor(batch_indices, device=x.device), + pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1) + ) + + def _apply_learned_pos_embed( + self, + x: torch.Tensor, + grid_size: List[int], + ): + orig_h, orig_w = self.pos_embed.shape[1:3] + if grid_size[0] == orig_h or grid_size[1] == orig_w: + # No resize needed, just flatten + pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1) + else: + # Resize if needed - directly using F.interpolate + pos_embed_flat = F.interpolate( + self.pos_embed.permute(0, 3, 1, 2).float(), # B,C,H,W + size=grid_size, + mode=self.pos_embed_interp_mode, + align_corners=False, + antialias=True, + ).flatten(2).transpose(1, 2) + pos_embed_flat = pos_embed_flat.to(dtype=x.dtype) + + x.add_(pos_embed_flat) + + +@register_notrace_function +def create_attention_mask( + patch_valid: torch.Tensor, + num_prefix_tokens: int = 0, + symmetric: bool = True, + q_len: Optional[int] = None, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Creates an attention mask from patch validity information. + + Supports two modes controlled by `symmetric`: + 1. `symmetric=True` (default): Creates a symmetric mask of shape + [B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if + both token i and token j are valid. Suitable for standard self-attention. + 2. `symmetric=False`: Creates a potentially non-square mask of shape + [B, 1, q_len, kv_len]. An attention pair (q, k) is allowed only if + the key/value token k is valid. Query token validity is not checked + in the mask itself. 
Useful for cross-attention or specific self-attention + implementations `q_len` can be specified. + + Used for NaFlex mode to handle variable token counts and padding tokens. + + Args: + patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding. + num_prefix_tokens: Number of prefix tokens (class token, register tokens) + to prepend, which are always considered valid. + symmetric: If True, create a symmetric mask. + If False, create an expanded mask based only on key/value validity. + q_len: Query sequence length override. Only used when `symmetric` is False. + Defaults to the key/value sequence length (`kv_len`) if None. + dtype: Dtype of the output attention mask (e.g., torch.float32). + + Returns: + Attention mask tensor. Additive mask (-inf for masked, 0 for unmasked). + Shape is [B, 1, seq_len, seq_len] if symmetric=True, + or [B, 1, q_len, kv_len] if symmetric=False. + """ + patch_valid = patch_valid.bool() # Ensure boolean type + B, N = patch_valid.shape + kv_len = N # Initial key/value length is the number of patches + + # Prepend prefix tokens if any + if num_prefix_tokens > 0: + # Create prefix validity tensor on the same device/dtype base as patch_valid + prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool) + # Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N] + patch_valid = torch.cat([prefix_valid, patch_valid], dim=1) + kv_len += num_prefix_tokens # Update total key/value sequence length + + if symmetric: + # Symmetric mask is True where BOTH query and key are valid + mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1) + mask_bool = mask_bool.unsqueeze(1) # Add head dimension: [B, 1, seq_len, seq_len] + else: + # Expanded mask + q_len = q_len or kv_len + mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len) + + # Create the float mask and apply masking using additive mask convention + mask_float = torch.zeros_like(mask_bool, dtype=dtype) + # Fill with negative infinity where mask_bool is False (masked positions) + mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min) + + return mask_float + + +class VisionTransformerFlex(nn.Module): + """ Vision Transformer (Na)Flex + + A flexible implementation of Vision Transformer with: + 1. Encapsulated embedding and position encoding in a single module + 2. Support for linear patch embedding on pre-patchified inputs + 3. 
Support for variable sequence length / aspect ratio images (NaFlex) + """ + + def __init__( + self, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_norm: bool = False, + proj_bias: bool = True, + init_values: Optional[float] = None, + class_token: bool = False, + reg_tokens: int = 0, + pos_embed: str = 'learn', + pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16), + pos_embed_interp_mode: str = 'bicubic', + default_img_size: Union[int, Tuple[int, int]] = 256, + dynamic_img_pad: bool = False, + pre_norm: bool = False, + final_norm: bool = True, + fc_norm: Optional[bool] = None, + num_classes: int = 1000, + global_pool: str = 'map', + drop_rate: float = 0., + pos_drop_rate: float = 0., + patch_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + weight_init: str = '', + fix_init: bool = True, + embed_layer_type: str = 'linear', + embed_norm_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, + ) -> None: + """ + Args: + patch_size: Patch size. + in_chans: Number of image input channels. + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token. + reg_tokens: Number of register tokens. + pos_embed: Type of position embedding. + pos_embed_grid_size: Size of position embedding grid. + pos_embed_interp_mode: Interpolation mode for position embedding. + default_img_size: Input image size. + pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT). + final_norm: Enable norm after transformer blocks, before head (standard in most ViT). + fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'. + num_classes: Number of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + fix_init: Apply weight initialization fix (scaling w/ layer index). + embed_layer_type: Patch embedding implementation (e.g., 'linear', 'conv'). + embed_norm_layer: Normalization layer to use / override in patch embed module. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer. 
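+
+ Example (illustrative, using the vit_naflex_base_patch16_gap entrypoint defined below; p, c, v are
+ the 'patches', 'patch_coord' and 'patch_valid' tensors produced by NaFlexCollator):
+     model = vit_naflex_base_patch16_gap()
+     logits = model(torch.randn(2, 3, 256, 256))                         # plain image input, patchified internally
+     logits = model({'patches': p, 'patch_coord': c, 'patch_valid': v})  # NaFlex dict input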
+ """ + super().__init__() + assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') + assert class_token or global_pool != 'token' + assert pos_embed in ('', 'none', 'learn') + norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + embed_norm_layer = get_norm_layer(embed_norm_layer) + act_layer = get_act_layer(act_layer) or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.grad_checkpointing = False + + # Initialize embedding module (includes patch, position embedding, and class/reg tokens) + # VisionTransformerEmbeds is always used - handles both linear and conv embedding + self.embeds = FlexEmbeds( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + embed_layer=embed_layer_type, + proj_norm_layer=embed_norm_layer, + final_norm_layer=norm_layer if pre_norm else None, + pos_embed=pos_embed, + pos_embed_grid_size=pos_embed_grid_size, + pos_embed_interp_mode=pos_embed_interp_mode, + pos_drop_rate=pos_drop_rate, + patch_drop_rate=patch_drop_rate, + class_token=class_token, + reg_tokens=reg_tokens, + default_img_size=default_img_size, + dynamic_img_pad=dynamic_img_pad, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + ) + + # Transformer blocks + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + proj_bias=proj_bias, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(depth)]) + + # Feature info for downstream tasks + patch_reduction = to_2tuple(patch_size) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_reduction) for i in range(depth)] + + self.norm = norm_layer(embed_dim) if final_norm and not fc_norm else nn.Identity() + + # Classifier Head + if global_pool == 'map': + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + act_layer=act_layer, + ) + else: + self.attn_pool = None + + self.fc_norm = norm_layer(embed_dim) if final_norm and fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init != 'skip': + self.init_weights(weight_init) + if fix_init: + self.fix_init_weight() + + def fix_init_weight(self): + def rescale(param, _layer_id): + param.div_(math.sqrt(2.0 * _layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def init_weights(self, mode: str = '') -> None: + assert mode in ('jax', 'jax_nlhb', 'moco', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
+ named_apply(get_init_weights_vit(mode, head_bias), self) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None: + # Custom loading for the new model structure + from .vision_transformer import _load_weights as _orig_load_weights + + def _load_weights_adapter(model, checkpoint_path, prefix=''): + """Adapter function to handle the different model structure""" + state_dict = torch.load(checkpoint_path, map_location='cpu') + if isinstance(state_dict, dict) and 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + + # Map original keys to new structure + for k in list(state_dict.keys()): + if k.startswith('cls_token'): + state_dict['embeds.' + k] = state_dict.pop(k) + elif k.startswith('reg_token'): + state_dict['embeds.' + k] = state_dict.pop(k) + elif k.startswith('pos_embed'): + state_dict['embeds.' + k] = state_dict.pop(k) + elif k.startswith('patch_embed'): + state_dict['embeds.' + k[12:]] = state_dict.pop(k) + + return _orig_load_weights(model, state_dict, prefix) + + _load_weights_adapter(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + skip_list = {'embeds.pos_embed', 'embeds.cls_token', 'embeds.reg_token'} + return skip_list + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r'^embeds', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + if hasattr(self.embeds, 'patch_embed') and hasattr(self.embeds.patch_embed, 'set_grad_checkpointing'): + self.embeds.patch_embed.set_grad_checkpointing(enable) + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') + if global_pool == 'map' and self.attn_pool is None: + assert False, "Cannot currently add attention pooling in reset_classifier()." + elif global_pool != 'map' and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + return_prefix_tokens: bool = False, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + output_dict: bool = False, + patch_coord: Optional[torch.Tensor] = None, + patch_valid: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]: + """ Forward features that returns intermediates. 
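+
+ Accepts either a standard image tensor or pre-patchified NaFlex input; when patch_valid is provided
+ and mask is None, an additive attention mask is built internally so padding tokens are excluded
+ from attention.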
+ + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + return_prefix_tokens: Return both prefix and spatial intermediate tokens + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys + patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode + patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex + mask: Optional attention mask + Returns: + A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing + 'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix') + """ + + # FIXME unfinished / untested + + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + + # Create attention mask if patch_type is provided and mask is not + if mask is None and patch_valid is not None: + mask = create_attention_mask(patch_valid, self.num_prefix_tokens, x.dtype) + + # Forward pass through embedding + x = self.embeds(x, patch_coord=patch_coord, patch_valid=patch_valid) + + # Forward pass through blocks + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + + for i, blk in enumerate(blocks): + x = blk(x, attn_mask=mask) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + + # Process intermediates + if self.num_prefix_tokens: + # split prefix (e.g. 
class, distill) and spatial feature tokens + prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] + intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] + else: + prefix_tokens = None + + if reshape: + # reshape to BCHW output format + grid_size = self.embeds.pos_embed_grid_size + if hasattr(self.embeds, 'dynamic_feat_size') and len(x.shape) >= 4: + _, height, width, _ = x.shape if len(x.shape) == 4 else (None, *x.shape[-3:-1], None) + H, W = self.embeds.dynamic_feat_size((height, width)) + else: + H, W = grid_size + intermediates = [y.reshape(y.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous() + for y in intermediates] + + # For dictionary output + if output_dict: + result_dict = {} + # Intermediates are always included + result_dict['image_intermediates'] = intermediates + if prefix_tokens is not None and return_prefix_tokens: + result_dict['image_intermediates_prefix'] = prefix_tokens + + # Only include features if not intermediates_only + if not intermediates_only: + x_final = self.norm(x) + result_dict['image_features'] = x_final + + return result_dict + + # For non-dictionary output, maintain the original behavior + if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None: + # return_prefix not support in torchscript due to poor type handling + intermediates = list(zip(intermediates, prefix_tokens)) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def forward_features( + self, + x: torch.Tensor, + patch_coord: Optional[torch.Tensor] = None, + patch_valid: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if attn_mask is None and patch_valid is not None: + attn_mask = create_attention_mask( + patch_valid, + num_prefix_tokens=self.num_prefix_tokens, + dtype=x.dtype + ) + + # Pass through embedding module with patch coordinate/type support + x = self.embeds(x, patch_coord=patch_coord) + + # Apply transformer blocks with masked attention if mask provided + if attn_mask is not None: + # We need to apply blocks one by one with mask + for blk in self.blocks: + x = blk(x, attn_mask=attn_mask) + elif self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + + x = self.norm(x) + return x + + def _pool( + self, + x: torch.Tensor, + pool_type: Optional[str] = None, + patch_valid: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.attn_pool is not None: + # For attention pooling, we need to pass the mask for NaFlex models + attn_mask = create_attention_mask( + patch_valid, + symmetric=False, + q_len=1, + dtype=x.dtype, + ) + x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask) + return x + + pool_type = self.global_pool if pool_type is None else pool_type + + # Handle padding mask for average pooling + if patch_valid is not None and pool_type in ('avg', 'avgmax'): + # For NaFlex mode, we need to apply masked pooling to exclude padding tokens + # Extract only the patch part of the mask (excluding prefix tokens) + if self.num_prefix_tokens > 0: + # Apply the mask to extract only valid tokens + x = x[:, self.num_prefix_tokens:] # prefix tokens not included in pooling + + patch_valid_float = patch_valid.to(x.dtype) + if pool_type == 'avg': + # Compute masked average pooling, sum valid tokens and divide by count of valid tokens + masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1) + valid_counts = 
patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1) + pooled = masked_sums / valid_counts + return pooled + elif pool_type == 'avgmax': + # For avgmax, compute masked average and masked max + masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1) + valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1) + masked_avg = masked_sums / valid_counts + + # For max pooling we set masked positions to large negative value + masked_x = x.clone() + masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min + masked_max = masked_x.max(dim=1)[0] + + # Combine average and max + return 0.5 * (masked_avg + masked_max) + + # Fall back to standard pooling + x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens) + return x + + def forward_head( + self, + x: torch.Tensor, + pre_logits: bool = False, + patch_valid: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = self._pool(x, patch_valid=patch_valid) + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward( + self, + x: Union[torch.Tensor, Dict[str, torch.Tensor]], + patch_coord: Optional[torch.Tensor] = None, + patch_valid: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with optional NaFlex support. + + Args: + x: Input tensor [B, C, H, W] or pre-patchified tensor [B, N, P*P*C] or NaFlex dict + patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode + patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex + + Returns: + Model output tensor + """ + if isinstance(x, Dict): + # Handle dictionary input from NaFlex collator + patch_coord = x['patch_coord'] + patch_valid = x['patch_valid'] + patches = x['patches'] + + # DEBUG, reconstruct patches + # for i in range(len(patches)): + # patch = patches[i][patch_valid[i]] + # h = (patch_coord[i, :, 0].max() + 1).item() + # w = (patch_coord[i, :, 1].max() + 1).item() + # patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3) + # patch = patch.reshape(3, h*16, w*16) + # from torchvision.utils import save_image + # save_image(patch, f'patch_{i}.jpg', normalize=True) + else: + patches = x + + # Create attention mask if patch_type is provided + if patch_valid is not None: + attn_mask = create_attention_mask( + patch_valid, + num_prefix_tokens=self.num_prefix_tokens, + dtype=patches.dtype + ) + else: + attn_mask = None + + # Forward features with mask + x = self.forward_features( + patches, + patch_coord=patch_coord, + patch_valid=patch_valid, + attn_mask=attn_mask, + ) + + # Pass mask to forward_head for masked pooling + x = self.forward_head( + x, + patch_valid=patch_valid, + ) + return x + + +def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable: + """Function imported from vision_transformer.py to maintain compatibility""" + from .vision_transformer import init_weights_vit_jax, init_weights_vit_moco, init_weights_vit_timm + + if 'jax' in mode: + return partial(init_weights_vit_jax, head_bias=head_bias) + elif 'moco' in mode: + return init_weights_vit_moco + else: + return init_weights_vit_timm + + +def checkpoint_filter_fn(state_dict, model): + """Handle state dict conversion from original ViT to the new version with combined embedding.""" + from .vision_transformer import checkpoint_filter_fn as orig_filter_fn + + # Handle CombinedEmbed module pattern + out_dict = {} + for k, v in state_dict.items(): + # Convert tokens and embeddings to combined_embed structure + if k == 'pos_embed': + # Handle position embedding format 
conversion - from (1, N, C) to (1, H, W, C) + if hasattr(model.embeds, 'pos_embed') and v.ndim == 3: + num_cls_token = 0 + num_reg_token = 0 + if 'reg_token' in state_dict: + num_reg_token = state_dict['reg_token'].shape[1] + if 'cls_token' in state_dict: + num_cls_token = state_dict['cls_token'].shape[1] + num_prefix_tokens = num_cls_token + num_reg_token + + # Original format is (1, N, C), need to reshape to (1, H, W, C) + num_patches = v.shape[1] + num_patches_no_prefix = num_patches - num_prefix_tokens + grid_size_no_prefix = math.sqrt(num_patches_no_prefix) + grid_size = math.sqrt(num_patches) + if (grid_size_no_prefix != grid_size and ( + grid_size_no_prefix.is_integer() and not grid_size.is_integer())): + # decide whether the original pos_embed included the prefix tokens + num_patches = num_patches_no_prefix + cls_token_emb = v[:, 0:num_cls_token] + if cls_token_emb.numel(): + state_dict['cls_token'] += cls_token_emb + reg_token_emb = v[:, num_cls_token:num_cls_token + num_reg_token] + if reg_token_emb.numel(): + state_dict['reg_token'] += reg_token_emb + v = v[:, num_prefix_tokens:] + grid_size = grid_size_no_prefix + grid_size = int(grid_size) + + # Check if it's a perfect square for a standard grid + if grid_size * grid_size == num_patches: + # Reshape from (1, N, C) to (1, H, W, C) + v = v.reshape(1, grid_size, grid_size, v.shape[2]) + else: + # Not a square grid, we need to get the actual dimensions + if hasattr(model.embeds.patch_embed, 'grid_size'): + h, w = model.embeds.patch_embed.grid_size + if h * w == num_patches: + # We have the right dimensions + v = v.reshape(1, h, w, v.shape[2]) + else: + # Dimensions don't match, use interpolation + _logger.warning( + f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. " + f"Using default initialization and will resize in forward pass." + ) + # Keep v as is, the forward pass will handle resizing + + out_dict['embeds.pos_embed'] = v + elif k == 'cls_token': + out_dict['embeds.cls_token'] = v + elif k == 'reg_token': + out_dict['embeds.reg_token'] = v + # Convert patch_embed.X to embeds.X + elif k.startswith('patch_embed.'): + suffix = k[12:] + if suffix == 'proj.weight': + # FIXME confirm patchify memory layout across use cases + v = v.permute(0, 2, 3, 1).flatten(1) + new_key = 'embeds.' 
+ suffix + out_dict[new_key] = v + else: + out_dict[k] = v + + return out_dict + + +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + return { + 'url': url, + 'num_classes': 1000, + 'input_size': (3, 256, 256), + 'pool_size': None, + 'crop_pct': 0.95, + 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, + 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'embeds.proj', + 'classifier': 'head', + 'license': 'apache-2.0', + **kwargs, + } + + +default_cfgs = generate_default_cfgs({ + 'vit_naflex_base_patch16_gap': _cfg(), + 'vit_naflex_base_patch16_map': _cfg(), + + # sbb model testing + 'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_naflex_so150m2_patch16_reg1_gap.sbb_r256_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k', + input_size=(3, 256, 256), crop_pct=1.0), + 'vit_naflex_so150m2_patch16_reg1_gap.sbb_r384_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_naflex_so150m2_patch16_reg1_gap.sbb_r448_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k', + input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'), + + # traditional vit testing + 'vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k': _cfg( + hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'), + 'vit_naflex_base_patch8.augreg2_r224_in21k_ft_in1k': _cfg( + hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'), +}) + + +def _create_vision_transformer_flex(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + VisionTransformerFlex, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs, + ) + return model + + +@register_model +def vit_naflex_base_patch16_gap(pretrained=False, **kwargs): + """ViT-New with NaFlex functionality for variable aspect ratios and resolutions. + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, + global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True) + model = _create_vision_transformer_flex( + 'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_naflex_base_patch16_map(pretrained=False, **kwargs): + """ViT-New with NaFlex functionality for variable aspect ratios and resolutions. + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, + global_pool='map', reg_tokens=1) + model = _create_vision_transformer_flex( + 'vit_naflex_base_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs): + """ViT-New with NaFlex functionality for variable aspect ratios and resolutions. + + This model supports: + 1. Variable aspect ratios and resolutions via patch coordinates + 2. Position embedding interpolation for arbitrary grid sizes + 3. 
Explicit patch coordinates and valid token masking + """ + model_args = dict( + patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5, + qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True) + model = _create_vision_transformer_flex( + 'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_naflex_base_patch16(pretrained: bool = False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, + global_pool='token', class_token=True, pos_embed_grid_size=(14, 14)) + model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs)) + return model diff --git a/train.py b/train.py index fda3c18413..efbe28929d 100755 --- a/train.py +++ b/train.py @@ -33,7 +33,8 @@ from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm import utils -from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset +from timm.data import create_dataset, create_loader, create_naflex_loader, resolve_data_config, \ + Mixup, FastCollateMixup, AugMixDataset from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters @@ -396,6 +397,16 @@ group.add_argument('--wandb-resume-id', default='', type=str, metavar='ID', help='If resuming a run, the id of the run in wandb') +# NaFlex scheduled loader arguments +group.add_argument('--naflex-loader', action='store_true', default=False, + help='Use NaFlex loader (Requires NaFlex compatible model)') +group.add_argument('--naflex-train-seq-lens', type=int, nargs='+', default=[128, 256, 576, 784, 1024], + help='Sequence lengths to use for NaFlex loader') +group.add_argument('--naflex-max-seq-len', type=int, default=576, + help='Fixed maximum sequence length for NaFlex loader (validation)') +group.add_argument('--naflex-loss-scale', default='linear', type=str, + help='Scale loss (gradient) by batch_size ("none", "sqrt", or "linear")') + def _parse_args(): # Do we have a config file to parse? @@ -669,6 +680,7 @@ def main(): trust_remote_code=args.dataset_trust_remote_code, ) + dataset_eval = None if args.val_split: dataset_eval = create_dataset( args.dataset, @@ -690,6 +702,7 @@ def main(): mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: + assert not args.naflex_loader, "Mixup/Cutmix not currently supported for NaFlex loading." 
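Usage sketch (not part of the diff): the dictionary format that the NaFlex collator produces and that `forward()` above unpacks. The grid size, class count and random data below are assumptions for illustration; only the `patches` / `patch_coord` / `patch_valid` keys consumed by the model are required.

```python
import torch
import timm  # assumes this branch, where the vit_naflex_* variants are registered

model = timm.create_model('vit_naflex_base_patch16_gap', pretrained=False, num_classes=10)

B, grid, patch_dim = 2, 16, 16 * 16 * 3   # batch, 16x16 patch grid, flattened 16px RGB patches
N = grid * grid
ys, xs = torch.meshgrid(torch.arange(grid), torch.arange(grid), indexing='ij')

batch = {
    'patches': torch.randn(B, N, patch_dim),                                       # padded patch pixels
    'patch_coord': torch.stack([ys.flatten(), xs.flatten()], -1).expand(B, N, 2),  # (y, x) per patch
    'patch_valid': torch.ones(B, N, dtype=torch.bool),                             # False where padded
}

logits = model(batch)  # forward() pulls patches / patch_coord / patch_valid out of the dict
print(logits.shape)    # torch.Size([2, 10])
```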
mixup_args = dict( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, @@ -714,9 +727,19 @@ def main(): train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] - loader_train = create_loader( - dataset_train, - input_size=data_config['input_size'], + + # Check if we should use the NaFlex scheduled loader + common_loader_kwargs = dict( + mean=data_config['mean'], + std=data_config['std'], + pin_memory=args.pin_mem, + img_dtype=model_dtype or torch.float32, + device=device, + distributed=args.distributed, + use_prefetcher=args.prefetcher, + ) + + train_loader_kwargs = dict( batch_size=args.batch_size, is_training=True, no_aug=args.no_aug, @@ -737,42 +760,70 @@ def main(): num_aug_repeats=args.aug_repeats, num_aug_splits=num_aug_splits, interpolation=train_interpolation, - mean=data_config['mean'], - std=data_config['std'], num_workers=args.workers, - distributed=args.distributed, - collate_fn=collate_fn, - pin_memory=args.pin_mem, - img_dtype=model_dtype or torch.float32, - device=device, - use_prefetcher=args.prefetcher, - use_multi_epochs_loader=args.use_multi_epochs_loader, worker_seeding=args.worker_seeding, ) + naflex_mode = False + if args.naflex_loader: + if utils.is_primary(args): + _logger.info('Using NaFlex loader') + + naflex_mode = True + loader_train = create_naflex_loader( + dataset=dataset_train, + patch_size=16, # Could be derived from model config + train_seq_lens=args.naflex_train_seq_lens, + rank=args.rank, + world_size=args.world_size, + **common_loader_kwargs, + **train_loader_kwargs, + ) + else: + # Use standard loader + loader_train = create_loader( + dataset_train, + input_size=data_config['input_size'], + collate_fn=collate_fn, + use_multi_epochs_loader=args.use_multi_epochs_loader, + **common_loader_kwargs, + **train_loader_kwargs, + ) + loader_eval = None if args.val_split: + assert dataset_eval is not None eval_workers = args.workers if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset): # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training eval_workers = min(2, args.workers) - loader_eval = create_loader( - dataset_eval, - input_size=data_config['input_size'], + + eval_loader_kwargs = dict( batch_size=args.validation_batch_size or args.batch_size, is_training=False, interpolation=data_config['interpolation'], - mean=data_config['mean'], - std=data_config['std'], num_workers=eval_workers, - distributed=args.distributed, crop_pct=data_config['crop_pct'], - pin_memory=args.pin_mem, - img_dtype=model_dtype or torch.float32, - device=device, - use_prefetcher=args.prefetcher, ) + if args.naflex_loader: + # Use largest sequence length for validation + loader_eval = create_naflex_loader( + dataset=dataset_eval, + patch_size=16, # Could be derived from model config + max_seq_len=args.naflex_max_seq_len, + **common_loader_kwargs, + **eval_loader_kwargs + ) + else: + # Use standard loader + loader_eval = create_loader( + dataset_eval, + input_size=data_config['input_size'], + **common_loader_kwargs, + **eval_loader_kwargs, + ) + # setup loss function if args.jsd_loss: assert num_aug_splits > 1 # JSD only valid with aug splits set @@ -901,6 +952,7 @@ def main(): model_ema=model_ema, mixup_fn=mixup_fn, num_updates_total=num_epochs * updates_per_epoch, + naflex_mode=naflex_mode, ) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): @@ -1003,6 +1055,7 @@ def train_one_epoch( model_ema=None, mixup_fn=None, 
num_updates_total=None, + naflex_mode=False, ): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: @@ -1048,10 +1101,10 @@ def train_one_epoch( def _forward(): with amp_autocast(): output = model(input) - loss = loss_fn(output, target) + _loss = loss_fn(output, target) if accum_steps > 1: - loss /= accum_steps - return loss + _loss /= accum_steps + return _loss def _backward(_loss): if loss_scaler is not None: @@ -1075,16 +1128,53 @@ def _backward(_loss): ) optimizer.step() - if has_no_sync and not need_update: - with model.no_sync(): + if naflex_mode: + assert isinstance(input, dict) + batch_size = input['patches'].shape[0] + + # scale gradient vs the nominal batch size (the minimum batch size, used at max seq len) + if not args.naflex_loss_scale or args.naflex_loss_scale == 'none': + local_scale = 1.0 + else: + local_scale = (batch_size / args.batch_size) + if args.naflex_loss_scale == 'sqrt': + local_scale = local_scale ** 0.5 + + if args.distributed: + # scale gradient between distributed ranks, each one can have a different batch size + global_batch_size = utils.reduce_tensor( + torch.tensor(batch_size, device=device, dtype=torch.float32), + 1 # SUM + ) + dist_scale = args.world_size * batch_size / global_batch_size + else: + dist_scale = None + + if has_no_sync and not need_update: + with model.no_sync(): + loss = _forward() + scaled_loss = local_scale * loss + if dist_scale is not None: + scaled_loss *= dist_scale + _backward(scaled_loss) + else: loss = _forward() - _backward(loss) + scaled_loss = local_scale * loss + if dist_scale is not None: + scaled_loss *= dist_scale + _backward(scaled_loss) else: - loss = _forward() - _backward(loss) + batch_size = input.shape[0] + if has_no_sync and not need_update: + with model.no_sync(): + loss = _forward() + _backward(loss) + else: + loss = _forward() + _backward(loss) - losses_m.update(loss.item() * accum_steps, input.size(0)) - update_sample_count += input.size(0) + losses_m.update(loss.item() * accum_steps, batch_size) + update_sample_count += batch_size if not need_update: data_start_time = time.time() @@ -1101,7 +1191,8 @@ def _backward(_loss): elif device.type == 'npu': torch.npu.synchronize() time_now = time.time() - update_time_m.update(time.time() - update_start_time) + + update_time_m.update((time.time() - update_start_time) / update_sample_count, update_sample_count) update_start_time = time_now if update_idx % args.log_interval == 0: @@ -1120,8 +1211,8 @@ def _backward(_loss): f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} ' f'({100. 
* (update_idx + 1) / updates_per_epoch:>3.0f}%)] ' f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) ' - f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s ' - f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) ' + f'Time: {update_time_m.val:.3f}s, {1 / update_time_m.val:>7.2f}/s ' + f'({update_time_m.avg:.3f}s, {1 / update_time_m.avg:>7.2f}/s) ' f'LR: {lr:.3e} ' f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})' ) @@ -1210,9 +1301,10 @@ def validate( elif device.type == "npu": torch.npu.synchronize() - losses_m.update(reduced_loss.item(), input.size(0)) - top1_m.update(acc1.item(), output.size(0)) - top5_m.update(acc5.item(), output.size(0)) + batch_size = output.shape[0] + losses_m.update(reduced_loss.item(), batch_size) + top1_m.update(acc1.item(), batch_size) + top5_m.update(acc5.item(), batch_size) batch_time_m.update(time.time() - end) end = time.time() diff --git a/validate.py b/validate.py index f757855e4e..59e78a91fd 100755 --- a/validate.py +++ b/validate.py @@ -158,6 +158,12 @@ parser.add_argument('--retry', default=False, action='store_true', help='Enable batch size decay & retry for single model validation') +# NaFlex loader arguments +parser.add_argument('--naflex-loader', action='store_true', default=False, + help='Use NaFlex loader (Requires NaFlex compatible model)') +parser.add_argument('--naflex-max-seq-len', type=int, default=576, + help='Fixed maximum sequence length for NaFlex loader (validation)') + def validate(args): # might as well try to validate something @@ -293,23 +299,43 @@ def validate(args): real_labels = None crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] - loader = create_loader( - dataset, - input_size=data_config['input_size'], - batch_size=args.batch_size, - use_prefetcher=args.prefetcher, - interpolation=data_config['interpolation'], - mean=data_config['mean'], - std=data_config['std'], - num_workers=args.workers, - crop_pct=crop_pct, - crop_mode=data_config['crop_mode'], - crop_border_pixels=args.crop_border_pixels, - pin_memory=args.pin_mem, - device=device, - img_dtype=model_dtype or torch.float32, - tf_preprocessing=args.tf_preprocessing, - ) + if args.naflex_loader: + from timm.data import create_naflex_loader + loader = create_naflex_loader( + dataset, + batch_size=args.batch_size, + use_prefetcher=args.prefetcher, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=crop_pct, + crop_mode=data_config['crop_mode'], + crop_border_pixels=args.crop_border_pixels, + pin_memory=args.pin_mem, + device=device, + img_dtype=model_dtype or torch.float32, + patch_size=16, # Could be derived from model config + max_seq_len=args.naflex_max_seq_len, + ) + else: + loader = create_loader( + dataset, + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=args.prefetcher, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=crop_pct, + crop_mode=data_config['crop_mode'], + crop_border_pixels=args.crop_border_pixels, + pin_memory=args.pin_mem, + device=device, + img_dtype=model_dtype or torch.float32, + tf_preprocessing=args.tf_preprocessing, + ) batch_time = AverageMeter() losses = AverageMeter() @@ -345,10 +371,11 @@ def validate(args): real_labels.add_result(output) # measure accuracy and record loss + batch_size = output.shape[0] acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5)) 
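For clarity, a numeric illustration (assumed numbers) of the `--naflex-loss-scale` options applied in `train_one_epoch` above, where `args.batch_size` is the nominal batch size used at the maximum sequence length:

```python
base_batch_size = 128     # args.batch_size, i.e. the batch size at the longest sequence length
actual_batch_size = 512   # a short-sequence NaFlex batch packs more samples

ratio = actual_batch_size / base_batch_size
local_scale = {
    'none': 1.0,           # no rescaling
    'linear': ratio,       # 4.0, larger batches contribute proportionally more gradient
    'sqrt': ratio ** 0.5,  # 2.0, a compromise between 'none' and 'linear'
}['linear']

# the loss is multiplied by local_scale (and, when distributed, by a per-rank dist_scale
# derived from the summed global batch size) before _backward() is called
```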
- losses.update(loss.item(), input.size(0)) - top1.update(acc1.item(), input.size(0)) - top5.update(acc5.item(), input.size(0)) + losses.update(loss.item(), batch_size) + top1.update(acc1.item(), batch_size) + top5.update(acc5.item(), batch_size) # measure elapsed time batch_time.update(time.time() - end) @@ -364,7 +391,7 @@ def validate(args): batch_idx, len(loader), batch_time=batch_time, - rate_avg=input.size(0) / batch_time.avg, + rate_avg=batch_size / batch_time.avg, loss=losses, top1=top1, top5=top5
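Finally, a standalone sketch of driving the new loader directly, mirroring the validate.py wiring above. The dataset path is a placeholder, unspecified arguments fall back to defaults that are not verified here, and only arguments that appear in this diff are used.

```python
from timm.data import create_dataset, create_naflex_loader

dataset = create_dataset('', root='/data/imagenet', split='validation')  # placeholder path

loader = create_naflex_loader(
    dataset,
    batch_size=256,     # upper bound, the effective per-batch size tracks the sequence length
    is_training=False,
    patch_size=16,      # matches the patch size hard-coded in train.py / validate.py above
    max_seq_len=576,    # fixed maximum sequence length for validation
    num_workers=4,
)

for batch, target in loader:
    # batch['patches']:     [B, N, 16*16*3] padded patch pixels
    # batch['patch_coord']: [B, N, 2] (y, x) grid coordinates
    # batch['patch_valid']: [B, N] bool, False where padded
    break
```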