-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtestRMS.py
38 lines (31 loc) · 1.48 KB
/
testRMS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import tensorflow as tf
from Parameters import batch_size,block_size,n_embd
from RMSNorm import RMSNorm
class TestRMSNorm(tf.test.TestCase):
def setUp(self):
super(TestRMSNorm, self).setUp()
self.batch = tf.random.normal((batch_size, block_size, n_embd))
def test_RMSNormTest(self):
normalized_mat, norm = tf.linalg.normalize(self.batch, axis=(1, 2))
ff_rms = tf.multiply(norm,
tf.pow(tf.cast(tf.size(self.batch[0]), tf.float32), -0.5))
ffx = tf.Variable(tf.zeros_like(self.batch))
print(tf.shape(ffx))
for i in range(self.batch.shape[0]):
ffx[i, :, : ].assign(tf.divide(self.batch[i] , ff_rms[i]))
normalized_mat, norm = tf.linalg.normalize(self.batch, axis=(1, 2))
print(tf.pow(norm,2))
# The values are close to 1024 but close enough for default
# tolerance levels to pass the test. So it will fail unless
# I pass a different tolerance level. I believe this is a temporary
# fix until I understand the issue.
self.assertAllClose(tf.pow(norm,2),
tf.reshape(
tf.repeat([tf.constant(1024,tf.float32)], repeats=[4], axis=0),
(4,1,1)),50,50)
def test_RMSNorm(self):
batch = tf.random.normal((batch_size, block_size, n_embd))
rms = RMSNorm([block_size, n_embd])
g = rms(batch)
print(g.shape)
tf.test.main()