diff --git a/keras_hub/src/layers/modeling/position_embedding.py b/keras_hub/src/layers/modeling/position_embedding.py
index afa0c8f679..8e77207baf 100644
--- a/keras_hub/src/layers/modeling/position_embedding.py
+++ b/keras_hub/src/layers/modeling/position_embedding.py
@@ -31,6 +31,10 @@ class PositionEmbedding(keras.layers.Layer):
         start_index: An integer or integer tensor. The starting position to
             compute the position embedding from. This is useful during cached
            decoding, where each position is predicted separately in a loop.
+        hierarchical_alpha: Float. Hyperparameter for hierarchical positional
+            encoding, which allows models such as BERT and ViT to be extended
+            seamlessly to sequences of up to `sequence_length**2` positions.
+            Must be in (0, 1) and not equal to 0.5. Defaults to `0.4`.
 
     Example:
 
@@ -61,6 +65,7 @@ def __init__(
         self,
         sequence_length,
         initializer="glorot_uniform",
+        hierarchical_alpha=0.4,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -69,13 +74,23 @@ def __init__(
                 "`sequence_length` must be an Integer, received `None`."
             )
         self.sequence_length = int(sequence_length)
+        self.hierarchical_alpha = hierarchical_alpha
         self.initializer = keras.initializers.get(initializer)
+        if (
+            hierarchical_alpha <= 0
+            or hierarchical_alpha >= 1
+            or hierarchical_alpha == 0.5
+        ):
+            raise ValueError(
+                "`hierarchical_alpha` must be in (0, 1) and not equal to 0.5."
+            )
 
     def get_config(self):
         config = super().get_config()
         config.update(
             {
                 "sequence_length": self.sequence_length,
+                "hierarchical_alpha": self.hierarchical_alpha,
                 "initializer": keras.initializers.serialize(self.initializer),
             }
         )
@@ -98,11 +113,32 @@ def call(self, inputs, start_index=0):
         # trim to match the length of the input sequence, which might be less
         # than the sequence_length of the layer.
         position_embeddings = ops.convert_to_tensor(self.position_embeddings)
-        position_embeddings = ops.slice(
-            position_embeddings,
-            (start_index, 0),
-            (sequence_length, feature_length),
-        )
+        if sequence_length < self.sequence_length:
+            position_embeddings = ops.slice(
+                position_embeddings,
+                (start_index, 0),
+                (sequence_length, feature_length),
+            )
+
+        else:
+            embeddings = (
+                position_embeddings
+                - self.hierarchical_alpha * position_embeddings[:1]
+            )
+            embeddings = embeddings / (1 - self.hierarchical_alpha)
+            position_ids = (
+                ops.arange(sequence_length, dtype="int32") + start_index
+            )
+            embeddings_x = ops.take(
+                embeddings, position_ids // self.sequence_length, axis=0
+            )
+            embeddings_y = ops.take(
+                embeddings, position_ids % self.sequence_length, axis=0
+            )
+            position_embeddings = (
+                self.hierarchical_alpha * embeddings_x
+                + (1 - self.hierarchical_alpha) * embeddings_y
+            )
         return ops.broadcast_to(position_embeddings, shape)
 
     def compute_output_shape(self, input_shape):
diff --git a/keras_hub/src/layers/modeling/position_embedding_test.py b/keras_hub/src/layers/modeling/position_embedding_test.py
index d6e577c66e..a086942e29 100644
--- a/keras_hub/src/layers/modeling/position_embedding_test.py
+++ b/keras_hub/src/layers/modeling/position_embedding_test.py
@@ -37,6 +37,21 @@ def test_layer_behaviors_4d(self):
             expected_num_trainable_weights=1,
         )
 
+    def test_layer_behaviors_hierarchical(self):
+        self.run_layer_test(
+            cls=PositionEmbedding,
+            init_kwargs={
+                "sequence_length": 4,
+            },
+            input_data=random.uniform(shape=(4, 16, 30)),
+            expected_output_shape=(4, 16, 30),
+            expected_num_trainable_weights=1,
+        )
+        layer = PositionEmbedding(sequence_length=8)
+        outputs1 = layer(random.uniform(shape=(2, 4, 30)))
+        outputs2 = layer(random.uniform(shape=(2, 16, 30)))
+        self.assertAllClose(outputs1, outputs2[:, :4], rtol=1e-4, atol=1e-7)
+
     def test_float16_dtype(self):
         # Create a 3-dimensional input (the first dimension is implicit).
         sequence_length = 21
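
For review purposes, below is a minimal NumPy sketch (not part of the diff) of the hierarchical decomposition that the new `call` branch implements: a position id beyond `sequence_length` is split into a coarse index (`pos // sequence_length`) and a fine index (`pos % sequence_length`), and the two looked-up rows are blended with `hierarchical_alpha`. The names `pe`, `feature_dim`, and `hierarchical_embed` are made-up stand-ins for the layer's trained `position_embeddings` weights and its internals.

```python
import numpy as np

# Stand-in values; `pe` plays the role of the layer's trained
# `position_embeddings` table of shape (sequence_length, feature_dim).
sequence_length, feature_dim, alpha = 8, 4, 0.4
pe = np.random.rand(sequence_length, feature_dim)

def hierarchical_embed(pe, length, alpha, start_index=0):
    # Remove the alpha-weighted first row, mirroring the PR's `call` branch.
    emb = (pe - alpha * pe[:1]) / (1 - alpha)
    ids = np.arange(length) + start_index
    emb_x = emb[ids // len(pe)]  # coarse ("outer") position
    emb_y = emb[ids % len(pe)]   # fine ("inner") position
    return alpha * emb_x + (1 - alpha) * emb_y

# Expanding to sequence_length**2 positions reuses the same 8-row table.
expanded = hierarchical_embed(pe, sequence_length**2, alpha)
assert expanded.shape == (sequence_length**2, feature_dim)
# Positions below sequence_length reproduce the original table exactly.
assert np.allclose(expanded[:sequence_length], pe)
```

Because the first row is un-mixed before blending, positions below `sequence_length` reproduce the trained table exactly, which is what the new `assertAllClose` check in the test asserts.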