@@ -142,6 +142,67 @@ def __call__(
     return inputs


+class OutputHead(nn.Module):
+  """
+  The final logit projection pipeline.
+  This module encapsulates Normalization, Dropout, and the final Logit Head
+  to ensure architectural consistency between the main model and auxiliary heads.
+  """
+
+  config: Config
+  shared_embedding: nn.Module
+
+  @nn.compact
+  def __call__(self, hidden_states: jnp.ndarray, deterministic: bool, model_mode: str) -> jnp.ndarray:
+    cfg = self.config
+
+    # 1. Final Normalization
+    y = RMSNorm(
+        dtype=cfg.dtype,
+        weight_dtype=cfg.weight_dtype,
+        name="decoder_norm",
+        epsilon=cfg.normalization_layer_epsilon,
+        kernel_axes=("norm",),
+    )(hidden_states)
+
+    # 2. Final Dropout
+    y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
+
+    # 3. Logit Projection (handles both methods)
+    if cfg.logits_via_embedding:
+      logits = self.shared_embedding.attend(y)
+      if cfg.normalize_embedding_logits:
+        logits = logits / jnp.sqrt(y.shape[-1])
+      if cfg.final_logits_soft_cap:
+        logits = jnp.tanh(logits / cfg.final_logits_soft_cap) * cfg.final_logits_soft_cap
+    else:
+      dense_layer = linears.dense_general(
+          inputs_shape=y.shape,
+          features=cfg.vocab_size,
+          weight_dtype=cfg.weight_dtype,
+          dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,
+          kernel_axes=("embed", "vocab"),
+          name="logits_dense",
+          matmul_precision=self.config.matmul_precision,
+      )
+      # Then, call the instance with the input tensor.
+      logits = dense_layer(y)
+
+    # 4. Final Casting
+    if cfg.cast_logits_to_fp32:
+      logits = logits.astype(jnp.float32)
+
+    # 5. Logical Constraints
+    if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
+      logits = nn.with_logical_constraint(logits, (None, None, "activation_vocab"))
+    else:
+      logits = nn.with_logical_constraint(
+          logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab")
+      )
+
+    return logits
+
+
 class Decoder(nn.Module):
   """A stack of decoder layers as a part of an encoder-decoder architecture."""

@@ -540,53 +601,7 @@ def __call__(
           **layer_call_kwargs,
       )
     # After the final transformer layer, `y` holds the raw, un-normalized hidden state.
-    final_hidden_state = y
-    y = self.get_norm_layer()(
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="decoder_norm",
-        epsilon=cfg.normalization_layer_epsilon,
-        kernel_axes=("norm",),
-        parameter_memory_host_offload=cfg.parameter_memory_host_offload,
-    )(y)
-    y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
-
-    # [batch, length, emb_dim] -> [batch, length, vocab_size]
-    if cfg.logits_via_embedding:
-      # Use the transpose of embedding matrix for logit transform.
-      logits = self.shared_embedding.attend(y)
-      if self.config.normalize_embedding_logits:
-        # Correctly normalize pre-softmax logits for this shared case.
-        logits = logits / jnp.sqrt(y.shape[-1])
-      if cfg.final_logits_soft_cap:
-        logits = logits / cfg.final_logits_soft_cap
-        logits = jnp.tanh(logits) * cfg.final_logits_soft_cap
-    else:
-      logits = linears.dense_general(
-          inputs_shape=y.shape,
-          features=cfg.vocab_size,
-          weight_dtype=cfg.weight_dtype,
-          dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,  # for logit training stability
-          kernel_axes=("embed", "vocab"),
-          name="logits_dense",
-          matmul_precision=self.config.matmul_precision,
-          parameter_memory_host_offload=cfg.parameter_memory_host_offload,
-      )(
-          y
-      )  # We do not quantize the logits matmul.
-
-    if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
-      logits = nn.with_logical_constraint(logits, (None, None, "activation_vocab"))
-    else:
-      logits = nn.with_logical_constraint(
-          logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab")
-      )
-
-    if self.config.cast_logits_to_fp32:
-      logits = logits.astype(jnp.float32)
-    # The API of the Decoder is now a tuple, providing both the main output
-    # and the raw hidden state needed for auxiliary tasks.
-    return logits, final_hidden_state
+    return y


 class VisionEncoder(nn.Module):
@@ -662,6 +677,8 @@ def setup(self):
           config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=mtp_layer
       )
     self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh) if cfg.use_multimodal else None
+    # Instantiate ONE OutputHead, which will be shared by the main path and MTP.
+    self.output_head = OutputHead(config=cfg, shared_embedding=self.shared_embedding)
     self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant)

   def __call__(
@@ -702,7 +719,7 @@ def __call__(
     if self.config.decoder_block == DecoderBlockType.GEMMA3:
       bidirectional_mask = decoder_input_tokens == multimodal_utils.GEMMA_TOKEN_PLACEHOLDER

-    logits, final_hidden_state = self.decoder(
+    final_hidden_state = self.decoder(
         decoder_input_tokens=decoder_input_tokens,
         decoder_positions=decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
@@ -715,6 +732,9 @@ def __call__(
         image_embeddings=image_embeddings,
     )

+    # The main logits are now computed by calling the dedicated OutputHead.
+    logits = self.output_head(hidden_states=final_hidden_state, deterministic=not enable_dropout, model_mode=model_mode)
+
     # If we are initializing the model AND MTP is enabled, we must create
     # dummy target tensors. This allows Flax to trace the MTPBlock and create
     # all its necessary parameters, without requiring the main training pipeline
@@ -736,8 +756,9 @@ def __call__(
     if self.config.mtp_num_layers > 0 and model_mode == MODEL_MODE_TRAIN:
       self.mtp_block(
           main_hidden_state=final_hidden_state,
-          input_ids=decoder_input_tokens,
           shared_embedding=self.shared_embedding,
+          output_head=self.output_head,
+          input_ids=decoder_input_tokens,
           target_ids=decoder_target_tokens,
           target_mask=decoder_target_mask,
           position_ids=decoder_positions,
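
For readers skimming the diff, here is a minimal standalone sketch of the tied-embedding logit path that `OutputHead` follows when `logits_via_embedding` is set: project hidden states onto the transposed token-embedding table, normalize by `sqrt(d_model)`, and apply an optional tanh soft cap. The names `TiedLogitHead`, `token_embed`, and `soft_cap` are illustrative placeholders, not MaxText APIs, and the real head also applies RMSNorm, dropout, fp32 casting, and sharding constraints as shown in the hunks above.

```python
# Illustrative sketch only -- hypothetical names, not repository code.
import jax
import jax.numpy as jnp
from flax import linen as nn


class TiedLogitHead(nn.Module):
  """Projects hidden states to vocab logits by reusing an embedding table."""

  vocab_size: int
  d_model: int
  soft_cap: float = 30.0  # assumed value; plays the role of final_logits_soft_cap

  @nn.compact
  def __call__(self, hidden_states: jnp.ndarray) -> jnp.ndarray:
    # Token-embedding table of shape (vocab_size, d_model); in the PR this is the
    # `shared_embedding` passed into OutputHead rather than created locally.
    embedding = self.param(
        "token_embed", nn.initializers.normal(stddev=1.0), (self.vocab_size, self.d_model)
    )
    # Same math as `shared_embedding.attend(y)`: multiply by the transposed table.
    logits = jnp.einsum("...d,vd->...v", hidden_states, embedding)
    # Mirrors `normalize_embedding_logits`.
    logits = logits / jnp.sqrt(hidden_states.shape[-1])
    # Mirrors `final_logits_soft_cap`: squashes logits into (-cap, cap).
    logits = jnp.tanh(logits / self.soft_cap) * self.soft_cap
    return logits


if __name__ == "__main__":
  head = TiedLogitHead(vocab_size=256, d_model=64)
  x = jnp.ones((2, 8, 64))  # [batch, length, emb_dim]
  params = head.init(jax.random.PRNGKey(0), x)
  print(head.apply(params, x).shape)  # (2, 8, 256)
```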