MaskedRoPEAttention.py
import tensorflow as tf
from Parameters import n_embd, n_heads, block_size, batch_size
from RotaryPositionalEmbeddings import RotaryPositionalEmbeddings


class MaskedRoPEAttention(tf.keras.Model):
    def __init__(self):
        super(MaskedRoPEAttention, self).__init__()
        # Linear projections (no bias) for queries, keys and values.
        self.w_q = tf.keras.layers.Dense(n_embd, input_shape=(n_embd,), use_bias=False)
        self.w_k = tf.keras.layers.Dense(n_embd, input_shape=(n_embd,), use_bias=False)
        self.w_v = tf.keras.layers.Dense(n_embd, input_shape=(n_embd,), use_bias=False)
        self.multihead = tf.keras.layers.MultiHeadAttention(
            num_heads=n_heads, key_dim=n_embd // n_heads, dropout=0.1)
        print(f'key_dim is {n_embd // n_heads}')
        # Precomputed stack of per-position rotary matrices, shape (block_size, n_embd, n_embd).
        rpe = RotaryPositionalEmbeddings()
        self.R = rpe.rotary_matrix(block_size, n_embd)
    def call(self, x, return_attn_weights=False):
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        # print(f'Shape of x,R,q {tf.shape(x)}{tf.shape(self.R)}{tf.shape(q)}')
        # print(f'Shape of {tf.shape(tf.transpose(q, perm=[1, 0, 2]))}')
        # Apply the rotary matrices: transpose from (batch_size, seq_length, n_embd) to
        # (seq_length, batch_size, n_embd) so position t is multiplied by its own rotation
        # matrix R[t], then transpose back. Assumes seq_length == block_size.
        q_t = tf.transpose(q, perm=[1, 0, 2])
        q_out = tf.transpose(tf.matmul(q_t, self.R), perm=[1, 0, 2])
        k_t = tf.transpose(k, perm=[1, 0, 2])
        k_out = tf.transpose(tf.matmul(k_t, self.R), perm=[1, 0, 2])
        v_t = tf.transpose(v, perm=[1, 0, 2])
        v_out = tf.transpose(tf.matmul(v_t, self.R), perm=[1, 0, 2])
        # print(f'Shapes of input(x),key,query and value are {tf.shape(x)}{tf.shape(k_out)}{tf.shape(q_out)}{tf.shape(v_out)}')
        # Causal mask over the sequence dimension (axis 1 of x). Keras
        # MultiHeadAttention expects a boolean/0-1 mask where 1 means "query may
        # attend to key", so the lower-triangular band is passed directly.
        seq_len = tf.shape(x)[1]
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        activations, attn_weights = self.multihead(query=q_out,
                                                   value=v_out,
                                                   key=k_out,
                                                   return_attention_scores=True,
                                                   attention_mask=causal_mask)
        if return_attn_weights:
            return activations, attn_weights
        return activations
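
A minimal usage sketch (not part of the original file): it assumes Parameters defines n_embd, n_heads and block_size, and that RotaryPositionalEmbeddings.rotary_matrix(block_size, n_embd) returns a (block_size, n_embd, n_embd) stack of rotation matrices; the batch size of 4 is arbitrary.

import tensorflow as tf
from Parameters import n_embd, block_size
from MaskedRoPEAttention import MaskedRoPEAttention

attn = MaskedRoPEAttention()
x = tf.random.normal((4, block_size, n_embd))      # (batch, seq_length, n_embd)
out, weights = attn(x, return_attn_weights=True)   # out: (batch, seq_length, n_embd)
print(out.shape, weights.shape)                    # weights: (batch, n_heads, seq_length, seq_length)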