Train.py
import datetime

import tensorflow as tf

from Dataset import draw_random_sample_batches
from Parameters import block_size, learning_rate
from RoPEModel import RoPEModel
# model = InitialModel(vocab_size)


def write_summary(loss, epoch):
    # Log the scalar training loss to TensorBoard.
    with train_summary_writer.as_default():
        tf.summary.scalar('llama loss', loss, step=epoch)


model = RoPEModel()

# One warm-up forward pass so the model's weights are built before training.
x, y = draw_random_sample_batches(block_size)
logits, loss = model(x, y)

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# Write TensorBoard scalars to a timestamped log directory.
logdir = "/Users/anu/PycharmProjects/TensorFlow2/logs/scalars/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_summary_writer = tf.summary.create_file_writer(logdir)

epochs = 1
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    for step in range(10000):
        with tf.GradientTape() as tape:
            x, y = draw_random_sample_batches(block_size)
            logits, loss = model(x, y)
        grads = tape.gradient(loss, model.trainable_weights)
        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss at step %d: %.4f"
                % (step, float(loss))
            )
            write_summary(loss=float(loss), epoch=step)
            print("Seen so far: %s batches" % (step + 1))