From df53d6125bba7bcf22fc1c79d5efcfe32106ab9d Mon Sep 17 00:00:00 2001
From: sravanneeli
Date: Sat, 1 Feb 2025 21:39:57 -0800
Subject: [PATCH 1/3] dinov2 base arch

---
 .../src/models/dinov2/dinov2_backbone.py      | 124 ++++++++++++++
 .../models/dinov2/dinov2_image_converter.py   |  73 ++++++++
 keras_hub/src/models/dinov2/dinov2_layers.py  | 162 ++++++++++++++++++
 3 files changed, 359 insertions(+)
 create mode 100644 keras_hub/src/models/dinov2/dinov2_backbone.py
 create mode 100644 keras_hub/src/models/dinov2/dinov2_image_converter.py
 create mode 100644 keras_hub/src/models/dinov2/dinov2_layers.py

diff --git a/keras_hub/src/models/dinov2/dinov2_backbone.py b/keras_hub/src/models/dinov2/dinov2_backbone.py
new file mode 100644
index 0000000000..7c8e8cbdba
--- /dev/null
+++ b/keras_hub/src/models/dinov2/dinov2_backbone.py
@@ -0,0 +1,124 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.dinov2.dinov2_layers import Dinov2PatchAndEmbeddings
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+@keras_hub_export("keras_hub.models.DinoV2Backbone")
+class DinoV2Backbone(Backbone):
+    """DinoV2 backbone.
+
+    This backbone implements the Vision Transformer architecture as described
+    in [An Image is Worth 16x16 Words: Transformers for Image Recognition at
+    Scale](https://arxiv.org/abs/2010.11929). It transforms the input image
+    into a sequence of patches, embeds them, and then processes them through
+    a series of Transformer encoder layers.
+
+    Args:
+        image_shape: A tuple or list of 3 integers representing the shape of
+            the input image `(height, width, channels)`; `height` and `width`
+            must be equal.
+        patch_size: (int, int). The size of each image patch; the input image
+            will be divided into patches of shape
+            `(patch_size_h, patch_size_w)`.
+        num_layers: int. The number of transformer encoder layers.
+        num_heads: int. The number of attention heads in each
+            Transformer encoder layer.
+        hidden_dim: int. The dimensionality of the hidden representations.
+        mlp_dim: int. The dimensionality of the intermediate MLP layer in
+            each Transformer encoder layer.
+        dropout_rate: float. The dropout rate for the Transformer encoder
+            layers.
+        attention_dropout: float. The dropout rate for the attention mechanism
+            in each Transformer encoder layer.
+        layer_norm_epsilon: float. Value used for numerical stability in
+            layer normalization.
+        use_mha_bias: bool. Whether to use bias in the multi-head
+            attention layers.
+        use_mlp_bias: bool. Whether to use bias in the MLP layers.
+        data_format: str. `"channels_last"` or `"channels_first"`, specifying
+            the data format for the input image. If `None`, defaults to
+            `"channels_last"`.
+        dtype: The dtype of the layer weights. Defaults to None.
+        **kwargs: Additional keyword arguments to be passed to the parent
+            `Backbone` class.
+    """
+
+    def __init__(
+        self,
+        image_shape,
+        patch_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        mlp_dim,
+        dropout_rate=0.0,
+        attention_dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        use_mha_bias=True,
+        use_mlp_bias=True,
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        data_format = standardize_data_format(data_format)
+        h_axis, w_axis, channels_axis = (
+            (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
+        )
+        # Check that the input image is well specified.
+        if image_shape[h_axis] is None or image_shape[w_axis] is None:
+            raise ValueError(
+                f"Image shape must have defined height and width. Found `None` "
+                f"at index {h_axis} (height) or {w_axis} (width). "
+                f"Image shape: {image_shape}"
+            )
+
+        if image_shape[h_axis] % patch_size[0] != 0:
+            raise ValueError(
+                f"Input height {image_shape[h_axis]} should be divisible by "
+                f"patch size {patch_size[0]}."
+            )
+
+        if image_shape[w_axis] % patch_size[1] != 0:
+            raise ValueError(
+                f"Input width {image_shape[w_axis]} should be divisible by "
+                f"patch size {patch_size[1]}."
+            )
+
+        num_channels = image_shape[channels_axis]
+
+        # === Functional Model ===
+        inputs = keras.layers.Input(shape=image_shape)
+
+        x = Dinov2PatchAndEmbeddings(
+            image_size=(image_shape[h_axis], image_shape[w_axis]),
+            patch_size=patch_size,
+            hidden_dim=hidden_dim,
+            num_channels=num_channels,
+            data_format=data_format,
+            dtype=dtype,
+            name="dinov2_patching_and_embedding",
+        )(inputs)
+
+        super().__init__(
+            inputs=inputs,
+            outputs=x,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.image_shape = image_shape
+        self.patch_size = patch_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.mlp_dim = mlp_dim
+        self.dropout_rate = dropout_rate
+        self.attention_dropout = attention_dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.use_mha_bias = use_mha_bias
+        self.use_mlp_bias = use_mlp_bias
+        self.data_format = data_format
diff --git a/keras_hub/src/models/dinov2/dinov2_image_converter.py b/keras_hub/src/models/dinov2/dinov2_image_converter.py
new file mode 100644
index 0000000000..143f81f4aa
--- /dev/null
+++ b/keras_hub/src/models/dinov2/dinov2_image_converter.py
@@ -0,0 +1,73 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.dinov2.dinov2_backbone import ViTBackbone
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+@keras_hub_export("keras_hub.layers.ViTImageConverter")
+class ViTImageConverter(ImageConverter):
+    """Converts images to the format expected by a ViT model.
+
+    This layer performs image normalization using mean and standard deviation
+    values. By default, it uses the same normalization as the
+    "google/vit-large-patch16-224" model on Hugging Face:
+    `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]`
+    ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)).
+    These defaults are suitable for models pretrained using this normalization.
+
+    Args:
+        norm_mean: list or tuple of floats. Mean values for image normalization.
+            Defaults to `[0.5, 0.5, 0.5]`.
+        norm_std: list or tuple of floats. Standard deviation values for
+            image normalization. Defaults to `[0.5, 0.5, 0.5]`.
+        **kwargs: Additional keyword arguments passed to
+            `keras_hub.layers.preprocessing.ImageConverter`.
+
+    Examples:
+    ```python
+    import keras
+    import numpy as np
+    from keras_hub.src.layers import ViTImageConverter
+
+    # Example image (replace with your actual image data)
+    image = np.random.rand(1, 224, 224, 3) # Example: (B, H, W, C)
+
+    # Create a ViTImageConverter instance
+    converter = ViTImageConverter(
+        image_size=(28,28),
+        scale=1/255.
+ ) + # Preprocess the image + preprocessed_image = converter(image) + ``` + """ + + backbone_cls = ViTBackbone + + def __init__( + self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs + ): + super().__init__(**kwargs) + self.norm_mean = norm_mean + self.norm_std = norm_std + + @preprocessing_function + def call(self, inputs): + x = super().call(inputs) + # By default normalize using imagenet mean and std + if self.norm_mean: + x = x - self._expand_non_channel_dims(self.norm_mean, x) + if self.norm_std: + x = x / self._expand_non_channel_dims(self.norm_std, x) + + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "norm_mean": self.norm_mean, + "norm_std": self.norm_std, + } + ) + return config diff --git a/keras_hub/src/models/dinov2/dinov2_layers.py b/keras_hub/src/models/dinov2/dinov2_layers.py new file mode 100644 index 0000000000..f24c909348 --- /dev/null +++ b/keras_hub/src/models/dinov2/dinov2_layers.py @@ -0,0 +1,162 @@ +import keras +from keras import ops + +from keras_hub.src.utils.keras_utils import standardize_data_format + + +class Dinov2PatchAndEmbeddings(keras.layers.Layer): + """Patches the image and embeds the patches. + + Args: + image_size: (int, int). Size of the input image. + patch_size: (int, int). Size of each image patch. + hidden_dim: int. Dimensionality of the patch embeddings. + num_channels: int. Number of channels in the input image. Defaults to + `3`. + use_class_token: bool. Whether to use class token to be part of + patch embedding. Defaults to `True`. + data_format: str. `"channels_last"` or `"channels_first"`. Defaults to + `None` (which uses `"channels_last"`). + **kwargs: Additional keyword arguments passed to `keras.layers.Layer` + """ + + def __init__( + self, + image_size, + patch_size, + hidden_dim, + num_channels=3, + data_format=None, + dropout_rate=0.0, + **kwargs, + ): + super().__init__(**kwargs) + grid_size = tuple([s // p for s, p in zip(image_size, patch_size)]) + num_patches = grid_size[0] * grid_size[1] + num_positions = num_patches + 1 + + # === Config === + self.image_size = image_size + self.patch_size = patch_size + self.hidden_dim = hidden_dim + self.num_channels = num_channels + self.num_patches = num_patches + self.num_positions = num_positions + self.dropout_rate = dropout_rate + self.data_format = standardize_data_format(data_format) + + def build(self, input_shape): + self.mask_token = self.add_weight( + shape=(1, self.hidden_dim), + initializer="zeros", + dtype=self.variable_dtype, + name="mask_token", + ) + self.class_token = self.add_weight( + shape=( + 1, + 1, + self.hidden_dim, + ), + initializer="random_normal", + dtype=self.variable_dtype, + name="class_token", + ) + self.patch_embedding = keras.layers.Conv2D( + filters=self.hidden_dim, + kernel_size=self.patch_size, + strides=self.patch_size, + padding="valid", + activation=None, + dtype=self.dtype_policy, + data_format=self.data_format, + name="patch_embedding", + ) + self.patch_embedding.build(input_shape) + self.position_embedding = keras.layers.Embedding( + self.num_positions, + self.hidden_dim, + dtype=self.dtype_policy, + embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02), + name="position_embedding", + ) + self.position_embedding.build((1, self.num_positions)) + self.dropout = keras.layers.Dropout(self.dropout_rate) + self.position_ids = keras.ops.expand_dims( + keras.ops.arange(self.num_positions), axis=0 + ) + self.built = True + + def interpolate_pos_encoding(self, embeddings, height, width): 
+ """Interpolates positional embeddings for different image sizes.""" + num_patches = ops.shape(embeddings)[1] - 1 + num_positions = ops.shape(self.position_embedding)[1] - 1 + + # If image size is unchanged, return as is + if num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, :1] # CLS token position + patch_pos_embed = self.position_embedding[:, 1:] # Patch positions + + # Compute new patch grid size + new_height = height // self.patch_size[0] + new_width = width // self.patch_size[1] + patch_pos_embed = ops.reshape( + patch_pos_embed, + (1, int(num_positions**0.5), int(num_positions**0.5), -1), + ) + + # Interpolate the position embeddings + patch_pos_embed = keras.layers.Resizing( + new_height, new_width, interpolation="bicubic" + )(patch_pos_embed) + + patch_pos_embed = ops.reshape(patch_pos_embed, (1, -1, self.hidden_dim)) + + return ops.concatenate([class_pos_embed, patch_pos_embed], axis=1) + + def call(self, inputs, bool_masked_pos=None): + patch_embeddings = self.patch_embedding(inputs) + if self.data_format == "channels_first": + patch_embeddings = ops.transpose( + patch_embeddings, axes=(0, 2, 3, 1) + ) + embeddings_shape = ops.shape(patch_embeddings) + patch_embeddings = ops.reshape( + patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]] + ) + position_embeddings = self.position_embedding(self.position_ids) + position_embeddings = self.interpolate_pos_encoding( + position_embeddings, embeddings_shape[1], embeddings_shape[2] + ) + + class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1)) + patch_embeddings = ops.concatenate( + [class_token, patch_embeddings], axis=1 + ) + embeddings = ops.add(patch_embeddings, position_embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def compute_output_shape(self, input_shape): + return ( + input_shape[0], + self.num_positions, + self.hidden_dim, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "image_size": self.image_size, + "patch_size": self.patch_size, + "hidden_dim": self.hidden_dim, + "num_channels": self.num_channels, + "num_patches": self.num_patches, + "num_positions": self.num_positions, + "dropout_rate": self.dropout_rate, + } + ) + return config From cebfd42429961e15311b7badd7fe2fa783152c4c Mon Sep 17 00:00:00 2001 From: sravanneeli Date: Sat, 1 Feb 2025 21:47:15 -0800 Subject: [PATCH 2/3] Correct api changes --- keras_hub/api/layers/__init__.py | 3 +++ keras_hub/api/models/__init__.py | 1 + keras_hub/src/models/dinov2/__init__.py | 0 .../models/dinov2/dinov2_image_converter.py | 22 ++++++++----------- 4 files changed, 13 insertions(+), 13 deletions(-) create mode 100644 keras_hub/src/models/dinov2/__init__.py diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 0e0d31d3a7..6dd11405ec 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -45,6 +45,9 @@ from keras_hub.src.models.densenet.densenet_image_converter import ( DenseNetImageConverter, ) +from keras_hub.src.models.dinov2.dinov2_image_converter import ( + DinoV2ImageConverter, +) from keras_hub.src.models.efficientnet.efficientnet_image_converter import ( EfficientNetImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 7c7adbf97c..d2699873ab 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -107,6 +107,7 @@ from 
keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( DenseNetImageClassifierPreprocessor, ) +from keras_hub.src.models.dinov2.dinov2_backbone import DinoV2Backbone from keras_hub.src.models.distil_bert.distil_bert_backbone import ( DistilBertBackbone, ) diff --git a/keras_hub/src/models/dinov2/__init__.py b/keras_hub/src/models/dinov2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/dinov2/dinov2_image_converter.py b/keras_hub/src/models/dinov2/dinov2_image_converter.py index 143f81f4aa..e2e916244f 100644 --- a/keras_hub/src/models/dinov2/dinov2_image_converter.py +++ b/keras_hub/src/models/dinov2/dinov2_image_converter.py @@ -1,19 +1,15 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.preprocessing.image_converter import ImageConverter -from keras_hub.src.models.dinov2.dinov2_backbone import ViTBackbone +from keras_hub.src.models.dinov2.dinov2_backbone import DinoV2Backbone from keras_hub.src.utils.tensor_utils import preprocessing_function -@keras_hub_export("keras_hub.layers.ViTImageConverter") -class ViTImageConverter(ImageConverter): - """Converts images to the format expected by a ViT model. +@keras_hub_export("keras_hub.layers.DinoV2ImageConverter") +class DinoV2ImageConverter(ImageConverter): + """Converts images to the format expected by a DinoV2 model. This layer performs image normalization using mean and standard deviation - values. By default, it uses the same normalization as the - "google/vit-large-patch16-224" model on Hugging Face: - `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]` - ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)). - These defaults are suitable for models pretrained using this normalization. + values. Args: norm_mean: list or tuple of floats. Mean values for image normalization. @@ -27,13 +23,13 @@ class ViTImageConverter(ImageConverter): ```python import keras import numpy as np - from keras_hub.src.layers import ViTImageConverter + from keras_hub.src.layers import DinoV2ImageConverter # Example image (replace with your actual image data) image = np.random.rand(1, 224, 224, 3) # Example: (B, H, W, C) - # Create a ViTImageConverter instance - converter = ViTImageConverter( + # Create a DinoV2ImageConverter instance + converter = DinoV2ImageConverter( image_size=(28,28), scale=1/255. 
) @@ -42,7 +38,7 @@ class ViTImageConverter(ImageConverter): ``` """ - backbone_cls = ViTBackbone + backbone_cls = DinoV2Backbone def __init__( self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs From 689c75a3505c3c0d84a1efc2377b8aabff28c90c Mon Sep 17 00:00:00 2001 From: sravanneeli Date: Sun, 2 Feb 2025 21:03:07 -0800 Subject: [PATCH 3/3] add dinov2 attention and mlp layers --- .../src/models/dinov2/dinov2_backbone.py | 38 +- keras_hub/src/models/dinov2/dinov2_layers.py | 398 ++++++++++++++++++ 2 files changed, 435 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/models/dinov2/dinov2_backbone.py b/keras_hub/src/models/dinov2/dinov2_backbone.py index 7c8e8cbdba..51f0393a59 100644 --- a/keras_hub/src/models/dinov2/dinov2_backbone.py +++ b/keras_hub/src/models/dinov2/dinov2_backbone.py @@ -2,6 +2,7 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.dinov2.dinov2_layers import DinoV2Encoder from keras_hub.src.models.dinov2.dinov2_layers import Dinov2PatchAndEmbeddings from keras_hub.src.utils.keras_utils import standardize_data_format @@ -54,8 +55,10 @@ def __init__( hidden_dim, mlp_dim, dropout_rate=0.0, + drop_path_rate=0.0, attention_dropout=0.0, layer_norm_epsilon=1e-6, + layer_scale_value=1.0, use_mha_bias=True, use_mlp_bias=True, data_format=None, @@ -98,13 +101,25 @@ def __init__( hidden_dim=hidden_dim, num_channels=num_channels, data_format=data_format, + dropout_rate=dropout_rate, dtype=dtype, name="dinov2_patching_and_embedding", )(inputs) + output, all_hidden_states, all_attention_scores = DinoV2Encoder( + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + use_mha_bias=use_mha_bias, + use_mlp_bias=use_mlp_bias, + dropout_rate=dropout_rate, + drop_path_rate=drop_path_rate, + )(x) + super().__init__( inputs=inputs, - outputs=x, + outputs=output, dtype=dtype, **kwargs, ) @@ -119,6 +134,27 @@ def __init__( self.dropout_rate = dropout_rate self.attention_dropout = attention_dropout self.layer_norm_epsilon = layer_norm_epsilon + self.layer_scale_value = layer_scale_value self.use_mha_bias = use_mha_bias self.use_mlp_bias = use_mlp_bias self.data_format = data_format + + def get_config(self): + config = super().get_config() + config.update( + { + "image_shape": self.image_shape, + "patch_size": self.patch_size, + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "mlp_dim": self.mlp_dim, + "dropout_rate": self.dropout_rate, + "attention_dropout": self.attention_dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + "layer_scale_value": self.layer_scale_value, + "use_mha_bias": self.use_mha_bias, + "use_mlp_bias": self.use_mlp_bias, + } + ) + return config diff --git a/keras_hub/src/models/dinov2/dinov2_layers.py b/keras_hub/src/models/dinov2/dinov2_layers.py index f24c909348..20d77c023b 100644 --- a/keras_hub/src/models/dinov2/dinov2_layers.py +++ b/keras_hub/src/models/dinov2/dinov2_layers.py @@ -160,3 +160,401 @@ def get_config(self): } ) return config + + +class DinoV2MLP(keras.layers.Layer): + """Multi-Layer Perceptron (MLP) block. + + Args: + hidden_dim: int. Dimensionality of the hidden representations. + mlp_dim: int. Dimensionality of the intermediate MLP layer. + use_bias: bool. Whether to use bias in the dense layers. Defaults to + `True`. + dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`. 
+        **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+    """
+
+    def __init__(
+        self,
+        hidden_dim,
+        mlp_dim,
+        use_bias=True,
+        dropout_rate=0.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # === Config ===
+        self.hidden_dim = hidden_dim
+        self.mlp_dim = mlp_dim
+        self.use_bias = use_bias
+        self.dropout_rate = dropout_rate
+
+    def build(self, input_shape):
+        self.dense_1 = keras.layers.Dense(
+            units=self.mlp_dim,
+            use_bias=self.use_bias,
+            activation="gelu",
+            bias_initializer=(
+                keras.initializers.RandomNormal(stddev=1e-6)
+                if self.use_bias
+                else None
+            ),
+            dtype=self.dtype_policy,
+            name="dense_1",
+        )
+        self.dense_1.build(input_shape)
+        self.dense_2 = keras.layers.Dense(
+            units=self.hidden_dim,
+            use_bias=self.use_bias,
+            bias_initializer=(
+                keras.initializers.RandomNormal(stddev=1e-6)
+                if self.use_bias
+                else None
+            ),
+            dtype=self.dtype_policy,
+            name="dense_2",
+        )
+        self.dense_2.build((None, None, self.mlp_dim))
+        self.dropout = keras.layers.Dropout(
+            self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+        )
+        self.built = True
+
+    def call(self, inputs):
+        x = self.dense_1(inputs)
+        x = self.dense_2(x)
+        out = self.dropout(x)
+        return out
+
+
+class DinoV2LayerScale(keras.layers.Layer):
+    """LayerScale layer introduced in
+    [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239v2).
+
+    Args:
+        init_value: float. Value to initialize the diagonal matrix of
+            LayerScale.
+        hidden_dim: int. Dimensionality of the hidden representations.
+    """
+
+    def __init__(self, init_value: float, hidden_dim: int, **kwargs):
+        super().__init__(**kwargs)
+        self.init_value = init_value
+        self.hidden_dim = hidden_dim
+
+    def build(self, input_shape):
+        self.lambda1 = self.add_weight(
+            shape=(self.hidden_dim,),
+            initializer=keras.initializers.Constant(self.init_value),
+            dtype=self.variable_dtype,
+            name="lambda1",
+        )
+        self.built = True
+
+    def call(self, x):
+        return x * self.lambda1
+
+
+class DinoV2DropPath(keras.layers.Layer):
+    """Drop path (stochastic depth) applied per sample in the main path of
+    residual blocks.
+
+    Args:
+        drop_prob: float. Probability of dropping the residual branch for a
+            given sample. Between 0 and 1.
+        seed: int. Optional random seed for the drop mask.
+    """
+
+    def __init__(self, drop_prob, seed=None, **kwargs):
+        super().__init__(**kwargs)
+        self.drop_prob = drop_prob
+        self.seed_generator = keras.random.SeedGenerator(seed)
+
+    def call(self, x, training=False):
+        if self.drop_prob == 0.0 or not training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        input_shape = ops.shape(x)
+        shape = (input_shape[0],) + (1,) * (len(input_shape) - 1)
+        # Bernoulli mask: floor(keep_prob + U[0, 1)) is 1 with
+        # probability `keep_prob`, 0 otherwise.
+        random_tensor = keep_prob + keras.random.uniform(
+            shape, dtype=self.compute_dtype, seed=self.seed_generator
+        )
+        random_tensor = ops.floor(random_tensor)
+
+        output = x / keep_prob * random_tensor
+        return output
+
+
+class DinoV2EncoderBlock(keras.layers.Layer):
+    """DinoV2 encoder block.
+
+    Args:
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. Dimensionality of the hidden representations.
+        mlp_dim: int. Dimensionality of the intermediate MLP layer.
+        use_mha_bias: bool. Whether to use bias in the multi-head attention
+            layer. Defaults to `True`.
+        use_mlp_bias: bool. Whether to use bias in the MLP layer. Defaults to
+            `True`.
+        dropout_rate: float. Dropout rate for the MLP block. Between 0 and 1.
+            Defaults to `0.0`.
+        drop_path_rate: float. Stochastic depth rate. Between 0 and 1.
+            Defaults to `0.0`.
+        attention_dropout: float. Dropout rate for the attention mechanism.
+            Between 0 and 1. Defaults to `0.0`.
+        layer_norm_epsilon: float. Small float value for layer normalization
+            stability. Defaults to `1e-6`.
+        layer_scale_value: float. Initial value for LayerScale. Defaults to
+            `1.0`.
+        **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+    """
+
+    def __init__(
+        self,
+        num_heads,
+        hidden_dim,
+        mlp_dim,
+        use_mha_bias=True,
+        use_mlp_bias=True,
+        dropout_rate=0.0,
+        drop_path_rate=0.0,
+        attention_dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        layer_scale_value=1.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        key_dim = hidden_dim // num_heads
+
+        # === Config ===
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.key_dim = key_dim
+        self.mlp_dim = mlp_dim
+        self.use_mha_bias = use_mha_bias
+        self.use_mlp_bias = use_mlp_bias
+        self.dropout_rate = dropout_rate
+        self.drop_path_rate = drop_path_rate
+        self.attention_dropout = attention_dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.layer_scale_value = layer_scale_value
+
+    def build(self, input_shape):
+        # Attention block
+        self.layer_norm_1 = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            name="ln_1",
+            dtype=self.dtype_policy,
+        )
+        self.layer_norm_1.build(input_shape)
+        self.mha = keras.layers.MultiHeadAttention(
+            num_heads=self.num_heads,
+            key_dim=self.key_dim,
+            use_bias=self.use_mha_bias,
+            dropout=self.attention_dropout,
+            name="mha",
+            dtype=self.dtype_policy,
+        )
+        self.mha.build(input_shape, input_shape)
+        self.drop_path = DinoV2DropPath(drop_prob=self.drop_path_rate)
+
+        # LayerScale layers
+        self.layer_scale_1 = DinoV2LayerScale(
+            init_value=self.layer_scale_value,
+            hidden_dim=self.hidden_dim,
+            name="ls_1",
+            dtype=self.dtype_policy,
+        )
+        self.layer_scale_1.build(input_shape)
+        self.layer_scale_2 = DinoV2LayerScale(
+            init_value=self.layer_scale_value,
+            hidden_dim=self.hidden_dim,
+            name="ls_2",
+            dtype=self.dtype_policy,
+        )
+        self.layer_scale_2.build(input_shape)
+
+        # MLP block
+        self.layer_norm_2 = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            name="ln_2",
+            dtype=self.dtype_policy,
+        )
+        self.layer_norm_2.build((None, None, self.hidden_dim))
+        self.mlp = DinoV2MLP(
+            hidden_dim=self.hidden_dim,
+            mlp_dim=self.mlp_dim,
+            use_bias=self.use_mlp_bias,
+            dropout_rate=self.dropout_rate,
+            name="mlp",
+            dtype=self.dtype_policy,
+        )
+        self.mlp.build((None, None, self.hidden_dim))
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        return_attention_scores=False,
+    ):
+        attention_scores = None
+        x = self.layer_norm_1(hidden_states)
+        if return_attention_scores:
+            x, attention_scores = self.mha(
+                x,
+                x,
+                attention_mask=attention_mask,
+                return_attention_scores=return_attention_scores,
+            )
+        else:
+            x = self.mha(
+                x,
+                x,
+                attention_mask=attention_mask,
+            )
+
+        x = self.layer_scale_1(x)
+        x = self.drop_path(x) + hidden_states
+
+        y = self.layer_norm_2(x)
+        y = self.mlp(y)
+        y = self.layer_scale_2(y)
+        y = self.drop_path(y)
+
+        return x + y, attention_scores
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "key_dim": self.key_dim,
+                "mlp_dim": self.mlp_dim,
+                "use_mha_bias": self.use_mha_bias,
+                "use_mlp_bias": self.use_mlp_bias,
+                "dropout_rate": self.dropout_rate,
+                "drop_path_rate": self.drop_path_rate,
+                "attention_dropout": self.attention_dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "layer_scale_value": self.layer_scale_value,
+            }
+        )
+        return config
+
+
+class DinoV2Encoder(keras.layers.Layer):
+    """DinoV2 encoder.
+
+    Args:
+        num_layers: int. Number of Transformer encoder blocks.
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. Dimensionality of the hidden representations.
+        mlp_dim: int. Dimensionality of the intermediate MLP layer.
+        use_mha_bias: bool. Whether to use bias in the multi-head attention
+            layers. Defaults to `True`.
+        use_mlp_bias: bool. Whether to use bias in the MLP layers. Defaults to
+            `True`.
+        dropout_rate: float. Dropout rate for the MLP blocks. Between 0 and 1.
+            Defaults to `0.0`.
+        drop_path_rate: float. Stochastic depth rate. Between 0 and 1.
+            Defaults to `0.0`.
+        attention_dropout: float. Dropout rate for the attention mechanism.
+            Between 0 and 1. Defaults to `0.0`.
+        layer_norm_epsilon: float. Small float value for layer normalization
+            stability. Defaults to `1e-6`.
+        layer_scale_value: float. Initial value for LayerScale. Defaults to
+            `1.0`.
+        **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+    """
+
+    def __init__(
+        self,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        mlp_dim,
+        use_mha_bias=True,
+        use_mlp_bias=True,
+        dropout_rate=0.0,
+        drop_path_rate=0.0,
+        attention_dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        layer_scale_value=1.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # === Config ===
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.mlp_dim = mlp_dim
+        self.use_mha_bias = use_mha_bias
+        self.use_mlp_bias = use_mlp_bias
+        self.dropout_rate = dropout_rate
+        self.drop_path_rate = drop_path_rate
+        self.attention_dropout = attention_dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.layer_scale_value = layer_scale_value
+
+    def build(self, input_shape):
+        self.encoder_layers = []
+        for i in range(self.num_layers):
+            encoder_block = DinoV2EncoderBlock(
+                num_heads=self.num_heads,
+                hidden_dim=self.hidden_dim,
+                mlp_dim=self.mlp_dim,
+                dropout_rate=self.dropout_rate,
+                drop_path_rate=self.drop_path_rate,
+                use_mha_bias=self.use_mha_bias,
+                use_mlp_bias=self.use_mlp_bias,
+                attention_dropout=self.attention_dropout,
+                layer_norm_epsilon=self.layer_norm_epsilon,
+                layer_scale_value=self.layer_scale_value,
+                dtype=self.dtype_policy,
+                name=f"transformer_block_{i + 1}",
+            )
+            encoder_block.build((None, None, self.hidden_dim))
+            self.encoder_layers.append(encoder_block)
+        self.layer_norm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="ln",
+        )
+        self.layer_norm.build((None, None, self.hidden_dim))
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_masks=None,
+        output_hidden_states=False,
+        return_attention_scores=False,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions_scores = () if return_attention_scores else None
+
+        for i in range(self.num_layers):
+            attention_mask = (
+                attention_masks[i] if attention_masks is not None else None
+            )
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            hidden_states, scores = self.encoder_layers[i](
+                hidden_states,
+                attention_mask=attention_mask,
+                return_attention_scores=return_attention_scores,
+            )
+            if return_attention_scores:
+                all_self_attentions_scores = all_self_attentions_scores + (
+                    scores,
+                )
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        x = self.layer_norm(hidden_states)
+        return x, all_hidden_states, all_self_attentions_scores
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "mlp_dim": self.mlp_dim,
+                "use_mha_bias": self.use_mha_bias,
+                "use_mlp_bias": self.use_mlp_bias,
+                "dropout_rate": self.dropout_rate,
+                "drop_path_rate": self.drop_path_rate,
+                "attention_dropout": self.attention_dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "layer_scale_value": self.layer_scale_value,
+            }
+        )
+        return config
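
A minimal end-to-end sketch of the API added by this series, not part of the patches themselves. It assumes the three patches above apply as shown, with `DinoV2Backbone` and `DinoV2ImageConverter` exported under `keras_hub.models` / `keras_hub.layers` as registered in PATCH 2/3; the hyperparameters below are illustrative only, not a pretrained preset.

```python
import numpy as np

import keras_hub

# Small illustrative configuration (hypothetical, not a released preset).
backbone = keras_hub.models.DinoV2Backbone(
    image_shape=(224, 224, 3),
    patch_size=(14, 14),
    num_layers=4,
    num_heads=4,
    hidden_dim=64,
    mlp_dim=256,
)

# Resize and rescale a dummy batch, mirroring the converter's docstring
# example (`image_size` and `scale` come from the base `ImageConverter`).
converter = keras_hub.layers.DinoV2ImageConverter(
    image_size=(224, 224),
    scale=1 / 255.0,
)
images = np.random.rand(2, 224, 224, 3).astype("float32")

features = backbone(converter(images))

# With a 224x224 input and 14x14 patches there are 16 * 16 = 256 patches,
# plus the class token, so the output should have shape (2, 257, 64).
print(features.shape)
```

The backbone returns the layer-normalized hidden states from `DinoV2Encoder`; the encoder's optional per-layer hidden states and attention scores are only exposed when the layer is called directly with `output_hidden_states=True` or `return_attention_scores=True`.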