Skip to content

Commit c22ceeb

Browse files
committed
fixing var order
1 parent 120c784 commit c22ceeb

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

MaxText/tests/multi_token_prediction_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,10 @@ class MTPBlockTestModel(nn.Module):
125125

126126
def setup(self):
127127

128-
self.output_head = models.OutputHead(config=self.config, shared_embedding=self.shared_embedding)
129-
130128
self.shared_embedding = embeddings.Embed(
131129
num_embeddings=self.config.vocab_size, features=self.config.base_emb_dim, name="token_embedder", config=self.config
132130
)
131+
self.output_head = models.OutputHead(config=self.config, shared_embedding=self.shared_embedding)
133132
self.mtp_block = multi_token_prediction.MultiTokenPredictionBlock(
134133
config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=blocks.DecoderLayer
135134
)

0 commit comments

Comments
 (0)