
Commit 120c784

Adding output HEAD
1 parent 1a2066a commit 120c784

3 files changed: 77 additions & 51 deletions


MaxText/layers/models.py

Lines changed: 70 additions & 49 deletions
@@ -142,6 +142,67 @@ def __call__(
    return inputs


+class OutputHead(nn.Module):
+  """
+  The final logit projection pipeline.
+  This module encapsulates Normalization, Dropout, and the final Logit Head
+  to ensure architectural consistency between the main model and auxiliary heads.
+  """
+
+  config: Config
+  shared_embedding: nn.Module
+
+  @nn.compact
+  def __call__(self, hidden_states: jnp.ndarray, deterministic: bool, model_mode: str) -> jnp.ndarray:
+    cfg = self.config
+
+    # 1. Final Normalization
+    y = RMSNorm(
+        dtype=cfg.dtype,
+        weight_dtype=cfg.weight_dtype,
+        name="decoder_norm",
+        epsilon=cfg.normalization_layer_epsilon,
+        kernel_axes=("norm",),
+    )(hidden_states)
+
+    # 2. Final Dropout
+    y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
+
+    # 3. Logit Projection (handles both methods)
+    if cfg.logits_via_embedding:
+      logits = self.shared_embedding.attend(y)
+      if cfg.normalize_embedding_logits:
+        logits = logits / jnp.sqrt(y.shape[-1])
+      if cfg.final_logits_soft_cap:
+        logits = jnp.tanh(logits / cfg.final_logits_soft_cap) * cfg.final_logits_soft_cap
+    else:
+      dense_layer = linears.dense_general(
+          inputs_shape=y.shape,
+          features=cfg.vocab_size,
+          weight_dtype=cfg.weight_dtype,
+          dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,
+          kernel_axes=("embed", "vocab"),
+          name="logits_dense",
+          matmul_precision=self.config.matmul_precision,
+      )
+      # Then, call the instance with the input tensor.
+      logits = dense_layer(y)
+
+    # 4. Final Casting
+    if cfg.cast_logits_to_fp32:
+      logits = logits.astype(jnp.float32)
+
+    # 5. Logical Constraints
+    if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
+      logits = nn.with_logical_constraint(logits, (None, None, "activation_vocab"))
+    else:
+      logits = nn.with_logical_constraint(
+          logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab")
+      )
+
+    return logits
+
+
 class Decoder(nn.Module):
   """A stack of decoder layers as a part of an encoder-decoder architecture."""
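For orientation, here is a minimal, self-contained sketch of the tied-embedding projection pattern that OutputHead wraps, written against plain Flax (nn.RMSNorm, nn.Dropout, nn.Embed.attend) rather than MaxText's Config, RMSNorm, and linears.dense_general. The names TiedOutputHead, TinyLM, and soft_cap are illustrative stand-ins and do not appear in the diff.

# Illustrative stand-ins for MaxText's Config/RMSNorm/dense_general; not the repo code.
import jax
import jax.numpy as jnp
from flax import linen as nn


class TiedOutputHead(nn.Module):
  """Norm -> dropout -> logits via the transposed token embedding, with an optional soft cap."""

  shared_embedding: nn.Embed
  dropout_rate: float = 0.0
  soft_cap: float = 30.0

  @nn.compact
  def __call__(self, hidden_states, deterministic: bool = True):
    y = nn.RMSNorm(name="decoder_norm")(hidden_states)
    y = nn.Dropout(rate=self.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
    logits = self.shared_embedding.attend(y)      # [batch, length, vocab]
    logits = logits / jnp.sqrt(y.shape[-1])       # normalize tied-embedding logits
    if self.soft_cap:
      logits = jnp.tanh(logits / self.soft_cap) * self.soft_cap
    return logits.astype(jnp.float32)


class TinyLM(nn.Module):
  vocab_size: int = 128
  emb_dim: int = 16

  def setup(self):
    self.token_embedder = nn.Embed(num_embeddings=self.vocab_size, features=self.emb_dim)
    self.output_head = TiedOutputHead(shared_embedding=self.token_embedder)

  def __call__(self, tokens):
    hidden = self.token_embedder(tokens)  # stand-in for the decoder stack
    return self.output_head(hidden)


tokens = jnp.zeros((2, 8), dtype=jnp.int32)
variables = TinyLM().init(jax.random.PRNGKey(0), tokens)
logits = TinyLM().apply(variables, tokens)  # shape (2, 8, 128)

When logits_via_embedding is disabled, the else branch in the diff swaps the attend call for a dense [embed, vocab] projection, but everything around it (normalization, dropout, the optional soft cap, the fp32 cast, and the logical constraints) stays identical, which is the consistency the module is meant to guarantee.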

@@ -540,53 +601,7 @@ def __call__(
           **layer_call_kwargs,
       )
     # After the final transformer layer, `y` holds the raw, un-normalized hidden state.
-    final_hidden_state = y
-    y = self.get_norm_layer()(
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="decoder_norm",
-        epsilon=cfg.normalization_layer_epsilon,
-        kernel_axes=("norm",),
-        parameter_memory_host_offload=cfg.parameter_memory_host_offload,
-    )(y)
-    y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
-
-    # [batch, length, emb_dim] -> [batch, length, vocab_size]
-    if cfg.logits_via_embedding:
-      # Use the transpose of embedding matrix for logit transform.
-      logits = self.shared_embedding.attend(y)
-      if self.config.normalize_embedding_logits:
-        # Correctly normalize pre-softmax logits for this shared case.
-        logits = logits / jnp.sqrt(y.shape[-1])
-      if cfg.final_logits_soft_cap:
-        logits = logits / cfg.final_logits_soft_cap
-        logits = jnp.tanh(logits) * cfg.final_logits_soft_cap
-    else:
-      logits = linears.dense_general(
-          inputs_shape=y.shape,
-          features=cfg.vocab_size,
-          weight_dtype=cfg.weight_dtype,
-          dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,  # for logit training stability
-          kernel_axes=("embed", "vocab"),
-          name="logits_dense",
-          matmul_precision=self.config.matmul_precision,
-          parameter_memory_host_offload=cfg.parameter_memory_host_offload,
-      )(
-          y
-      )  # We do not quantize the logits matmul.
-
-    if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
-      logits = nn.with_logical_constraint(logits, (None, None, "activation_vocab"))
-    else:
-      logits = nn.with_logical_constraint(
-          logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab")
-      )
-
-    if self.config.cast_logits_to_fp32:
-      logits = logits.astype(jnp.float32)
-    # The API of the Decoder is now a tuple, providing both the main output
-    # and the raw hidden state needed for auxiliary tasks.
-    return logits, final_hidden_state
+    return y


 class VisionEncoder(nn.Module):
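One behavioral detail: the removed code applied the soft cap in two statements (divide, then tanh times the cap), while OutputHead folds it into a single expression. The two forms are algebraically identical, as the quick check below confirms (the cap value and logits are arbitrary examples):

import jax.numpy as jnp

cap = 30.0
logits = jnp.array([-75.0, -3.2, 0.0, 12.5, 90.0])

scaled = logits / cap                      # removed code: first statement
old_style = jnp.tanh(scaled) * cap         # removed code: second statement
new_style = jnp.tanh(logits / cap) * cap   # OutputHead's single expression

print(bool(jnp.allclose(old_style, new_style)))  # True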
@@ -662,6 +677,8 @@ def setup(self):
         config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=mtp_layer
     )
     self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh) if cfg.use_multimodal else None
+    # Instantiate ONE OutputHead, which will be shared by the main path and MTP.
+    self.output_head = OutputHead(config=cfg, shared_embedding=self.shared_embedding)
     self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant)

   def __call__(
@@ -702,7 +719,7 @@ def __call__(
     if self.config.decoder_block == DecoderBlockType.GEMMA3:
       bidirectional_mask = decoder_input_tokens == multimodal_utils.GEMMA_TOKEN_PLACEHOLDER

-    logits, final_hidden_state = self.decoder(
+    final_hidden_state = self.decoder(
         decoder_input_tokens=decoder_input_tokens,
         decoder_positions=decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
@@ -715,6 +732,9 @@ def __call__(
         image_embeddings=image_embeddings,
     )

+    # The main logits are now computed by calling the dedicated OutputHead.
+    logits = self.output_head(hidden_states=final_hidden_state, deterministic=not enable_dropout, model_mode=model_mode)
+
     # If we are initializing the model AND MTP is enabled, we must create
     # dummy target tensors. This allows Flax to trace the MTPBlock and create
     # all its necessary parameters, without requiring the main training pipeline
@@ -736,8 +756,9 @@ def __call__(
     if self.config.mtp_num_layers > 0 and model_mode == MODEL_MODE_TRAIN:
       self.mtp_block(
           main_hidden_state=final_hidden_state,
-          input_ids=decoder_input_tokens,
           shared_embedding=self.shared_embedding,
+          output_head=self.output_head,
+          input_ids=decoder_input_tokens,
           target_ids=decoder_target_tokens,
           target_mask=decoder_target_mask,
           position_ids=decoder_positions,
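The setup() change above is the crux of the commit: one OutputHead instance is created once and then called from both the main logits path and the MTP block, so both paths read the same head parameters. Below is a standalone sketch of that sharing pattern in plain Flax; TinyHead and TwoPathModel are made-up illustrative names, not MaxText code.

# One head module, two call sites, one parameter subtree -- illustrative names only.
import jax
import jax.numpy as jnp
from flax import linen as nn


class TinyHead(nn.Module):
  vocab_size: int

  @nn.compact
  def __call__(self, hidden_states):
    y = nn.RMSNorm(name="decoder_norm")(hidden_states)
    return nn.Dense(features=self.vocab_size, name="logits_dense")(y)


class TwoPathModel(nn.Module):
  vocab_size: int = 64
  emb_dim: int = 8

  def setup(self):
    self.embed = nn.Embed(num_embeddings=self.vocab_size, features=self.emb_dim)
    self.head = TinyHead(vocab_size=self.vocab_size)  # instantiated once, like self.output_head

  def __call__(self, tokens):
    hidden = self.embed(tokens)               # stand-in for the decoder stack
    main_logits = self.head(hidden)           # main path
    mtp_logits = self.head(hidden[:, 1:, :])  # auxiliary path reuses the same parameters
    return main_logits, mtp_logits


tokens = jnp.zeros((1, 6), dtype=jnp.int32)
variables = TwoPathModel().init(jax.random.PRNGKey(0), tokens)
print(sorted(variables["params"].keys()))  # ['embed', 'head'] -- no duplicate head parameters

Because the head is shared, any change to the logit pipeline (normalization, soft cap, casting, sharding constraints) applies to the MTP logits automatically, which is the consistency the OutputHead docstring calls out.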

MaxText/layers/multi_token_prediction.py

Lines changed: 3 additions & 2 deletions
@@ -153,6 +153,7 @@ def __call__(
       self,
       main_hidden_state,
       shared_embedding,
+      output_head,
       input_ids,
       target_ids,
       target_mask,
@@ -193,8 +194,8 @@ def __call__(
           mtp_hidden_state, target_token_embedding, position_ids, decoder_segment_ids, deterministic
       )

-      # Project to logits using the shared embedding transpose
-      mtp_logits = shared_embedding.attend(next_mtp_hidden_state)
+      # Project to logits using the shared output head
+      mtp_logits = output_head(hidden_states=next_mtp_hidden_state, deterministic=deterministic, model_mode=MODEL_MODE_TRAIN)

       # Calculate cross-entropy loss for this specific layer's prediction
       mtp_xent, _ = max_utils.cross_entropy_with_logits(mtp_logits, jax.nn.one_hot(rolled_target_ids, cfg.vocab_size), 0.0)
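For reference, here is a self-contained sketch of the loss step that consumes these logits. It uses optax.softmax_cross_entropy in place of MaxText's max_utils.cross_entropy_with_logits, and the shapes, the pad id of 0, and the masking are illustrative assumptions rather than the repo's exact implementation.

# Toy shapes only; stands in for max_utils.cross_entropy_with_logits plus masking.
import jax
import jax.numpy as jnp
import optax

vocab_size = 32
mtp_logits = jax.random.normal(jax.random.PRNGKey(0), (2, 5, vocab_size))   # [batch, length, vocab]
rolled_target_ids = jnp.array([[3, 7, 1, 0, 0], [4, 2, 9, 9, 0]])           # shifted targets, 0 = pad
target_mask = (rolled_target_ids != 0).astype(jnp.float32)                  # ignore padding positions

xent = optax.softmax_cross_entropy(mtp_logits, jax.nn.one_hot(rolled_target_ids, vocab_size))
mtp_loss = (xent * target_mask).sum() / jnp.maximum(target_mask.sum(), 1.0)
print(float(mtp_loss))  # scalar auxiliary loss for this MTP layer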

MaxText/tests/multi_token_prediction_test.py

Lines changed: 4 additions & 0 deletions
@@ -124,6 +124,9 @@ class MTPBlockTestModel(nn.Module):
   mesh: Mesh

   def setup(self):
     self.shared_embedding = embeddings.Embed(
         num_embeddings=self.config.vocab_size, features=self.config.base_emb_dim, name="token_embedder", config=self.config
     )
+
+    # The shared embedding must be assigned before it is handed to the OutputHead.
+    self.output_head = models.OutputHead(config=self.config, shared_embedding=self.shared_embedding)
@@ -137,6 +140,7 @@ def __call__(
     return self.mtp_block(
         main_hidden_state,
         self.shared_embedding,
+        self.output_head,
         input_ids,
         target_ids,
         target_mask,
