
Commit 7892744: Add Qwen3
Parent: 4d4b6b0

File tree: 10 files changed, +359 −0 lines

MaxText/common_types.py

Lines changed: 1 addition & 0 deletions
@@ -81,3 +81,4 @@ class DecoderBlockType(enum.Enum):
   SIMPLE = "simple"
   SIMPLE_MLP = "simple_mlp"
   LLAMA4 = "llama4"
+  QWEN3 = "qwen3"
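For context, DecoderBlockType is a string-valued enum, which is what lets the decoder_block: "qwen3" field in the configs below resolve to the new member. A minimal sketch (enum abbreviated here for illustration only):

import enum

class DecoderBlockType(enum.Enum):
  # Abbreviated for illustration; MaxText/common_types.py defines many more members.
  SIMPLE = "simple"
  SIMPLE_MLP = "simple_mlp"
  LLAMA4 = "llama4"
  QWEN3 = "qwen3"

# A value-based lookup turns the raw config string into the enum member:
assert DecoderBlockType("qwen3") is DecoderBlockType.QWEN3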

MaxText/configs/models/qwen3-0.6b.yml

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-0.6b

base_emb_dim: 1024
base_num_query_heads: 16
base_num_kv_heads: 8
base_num_decoder_layers: 28
base_mlp_dim: 3072
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 151936
enable_dropout: False
logits_via_embedding: True
normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1_000_000
decoder_block: "qwen3"
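As a sanity check on the model name, a rough parameter count follows from these fields. This is a back-of-the-envelope sketch, assuming the gated-SiLU MLP implied by mlp_activations: ["silu","linear"] (gate, up, and down projections) and ignoring norm weights; it is an estimate, not MaxText code:

def approx_param_count(emb, q_heads, kv_heads, layers, mlp, head_dim, vocab, tied):
  # Q, K, V projections plus the attention output projection (grouped-query attention).
  attn = emb * head_dim * (q_heads + 2 * kv_heads) + q_heads * head_dim * emb
  # Gated-SiLU MLP: gate, up, and down projections.
  ffn = 3 * emb * mlp
  # Token embedding, shared with the output head when logits_via_embedding is True.
  embed = vocab * emb * (1 if tied else 2)
  return layers * (attn + ffn) + embed

# Values from the qwen3-0.6b config above:
print(approx_param_count(1024, 16, 8, 28, 3072, 128, 151936, tied=True))  # ~5.96e8

which lands close enough to 0.6B to confirm the config is self-consistent.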

MaxText/configs/models/qwen3-1.7b.yml

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-1.7b

base_emb_dim: 2048
base_num_query_heads: 16
base_num_kv_heads: 8
base_num_decoder_layers: 28
base_mlp_dim: 6144
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 151936
enable_dropout: False
logits_via_embedding: True
normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1_000_000
decoder_block: "qwen3"

MaxText/configs/models/qwen3-14b.yml

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-14b

base_emb_dim: 5120
base_num_query_heads: 40
base_num_kv_heads: 8
base_num_decoder_layers: 40
base_mlp_dim: 17408
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 151936
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1_000_000
decoder_block: "qwen3"

MaxText/configs/models/qwen3-32b.yml

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-32b

base_emb_dim: 5120
base_num_query_heads: 64
base_num_kv_heads: 8
base_num_decoder_layers: 64
base_mlp_dim: 25600
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 151936
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1_000_000
decoder_block: "qwen3"

MaxText/configs/models/qwen3-4b.yml

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-4b

base_emb_dim: 2560
base_num_query_heads: 32
base_num_kv_heads: 8
base_num_decoder_layers: 36
base_mlp_dim: 9728
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 151936
enable_dropout: False
logits_via_embedding: True
normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1_000_000
decoder_block: "qwen3"

MaxText/configs/models/qwen3-8b.yml

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-8b

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_num_decoder_layers: 36
base_mlp_dim: 12288
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 151936
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1_000_000
decoder_block: "qwen3"
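For quick reference, the six configs differ only in the fields below; all of them share base_num_kv_heads: 8, head_dim: 128, vocab_size: 151936, normalization_layer_epsilon: 1.0e-6, rope_max_timescale: 1_000_000, and decoder_block: "qwen3".

Model       emb dim  query heads  layers  mlp dim  logits_via_embedding
qwen3-0.6b  1024     16           28      3072     True
qwen3-1.7b  2048     16           28      6144     True
qwen3-4b    2560     32           36      9728     True
qwen3-8b    4096     32           36      12288    False
qwen3-14b   5120     40           40      17408    False
qwen3-32b   5120     64           64      25600    False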

MaxText/layers/models.py

Lines changed: 5 additions & 0 deletions
@@ -362,6 +362,10 @@ def get_decoder_layers(self):
         return [llama4.Llama4ScannableBlock]
       else:
         return [llama4.Llama4DecoderLayer]
+    elif self.config.decoder_block == DecoderBlockType.QWEN3:
+      from MaxText.layers import qwen3  # pylint: disable=import-outside-toplevel
+
+      return [qwen3.Qwen3DecoderLayer]
     else:
       raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")

@@ -379,6 +383,7 @@ def get_norm_layer(self):
         DecoderBlockType.SIMPLE,
         DecoderBlockType.SIMPLE_MLP,
         DecoderBlockType.LLAMA4,
+        DecoderBlockType.QWEN3,
     ):
       return RMSNorm
     elif self.config.decoder_block == DecoderBlockType.GPT3:
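Note that the qwen3 module is imported inside its branch, so it is only loaded when that decoder block is actually selected. A standalone sketch of the same lazy-dispatch pattern, using a hypothetical pick_decoder_layer helper rather than MaxText's actual method:

from MaxText.common_types import DecoderBlockType

def pick_decoder_layer(decoder_block: DecoderBlockType):
  # Hypothetical helper mirroring get_decoder_layers above.
  if decoder_block == DecoderBlockType.QWEN3:
    from MaxText.layers import qwen3  # deferred: loaded only when selected
    return qwen3.Qwen3DecoderLayer
  raise ValueError(f"Incorrect decoder_block name {decoder_block.value=}")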

MaxText/layers/qwen3.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
"""
Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Transformer model definition."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module

from typing import Optional

import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh
# from jax.experimental.pallas.ops.tpu import flash_attention

from flax import linen as nn

from MaxText.inference import page_manager
from MaxText.layers import linears
from MaxText.layers import models
from MaxText.layers import quantizations
from MaxText.layers.attentions import Attention
from MaxText.layers.quantizations import AqtQuantization as Quant
from MaxText.layers.normalizations import RMSNorm


# -----------------------------------------
# The Decoder Layer specific for Qwen3
# -----------------------------------------


class Qwen3DecoderLayer(nn.Module):
  """Transformer decoder layer that attends to the encoder."""

  config: models.Config
  mesh: Mesh
  quant: Optional[Quant] = None

  @nn.compact
  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      slot: Optional[int] = None,
      page_state: Optional[page_manager.PageState] = None,
      previous_chunk=None,
  ):
    cfg = self.config
    mesh = self.mesh

    inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
    inputs = checkpoint_name(inputs, "decoder_layer_input")
    lnx_rms = RMSNorm(
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="pre_self_attention_layer_norm",
        kernel_axes=("norm",),
        epsilon=cfg.normalization_layer_epsilon,
    )
    lnx = lnx_rms(inputs)

    lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
    # Instead of scaling the query values in the checkpoint conversion
    # we'll do it dynamically in the forward pass of Attention
    query_pre_attn_scalar = cfg.head_dim**-0.5

    # Self-attention block
    attention_layer = Attention(
        config=cfg,
        num_query_heads=cfg.num_query_heads,
        num_kv_heads=cfg.num_kv_heads,
        head_dim=cfg.head_dim,
        max_target_length=cfg.max_target_length,
        max_prefill_predict_length=cfg.max_prefill_predict_length,
        attention_kernel=cfg.attention,
        mesh=mesh,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        dropout_rate=cfg.dropout_rate,
        name="self_attention",
        float32_qk_product=cfg.float32_qk_product,
        float32_logits=cfg.float32_logits,
        quant=self.quant,
        kv_quant=quantizations.configure_kv_quant(cfg),
        prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))),
        ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
        compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
        reshape_q=cfg.reshape_q,
        use_ragged_attention=cfg.use_ragged_attention,
        ragged_block_size=cfg.ragged_block_size,
        use_qk_norm=cfg.use_qk_norm,
        query_pre_attn_scalar=query_pre_attn_scalar,
    )

    attention_lnx = attention_layer(
        lnx,
        lnx,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        slot=slot,
        page_state=page_state,
        previous_chunk=previous_chunk,
    )

    attention_lnx = nn.with_logical_constraint(
        attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed")
    )
    intermediate_inputs = inputs + attention_lnx

    # Fully Connected
    hidden_states = models.RMSNorm(
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="post_self_attention_layer_norm",
        kernel_axes=("norm",),
        epsilon=cfg.normalization_layer_epsilon,
    )(intermediate_inputs)
    hidden_states = nn.with_logical_constraint(
        hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
    )

    # MLP block.
    mlp_lnx = linears.MlpBlock(
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="mlp",
        config=cfg,
        quant=self.quant,
    )(hidden_states, deterministic=deterministic)
    mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

    layer_output = mlp_lnx + intermediate_inputs

    layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)

    layer_output = nn.with_logical_constraint(
        layer_output,
        ("activation_batch", "activation_norm_length", "activation_embed"),
    )

    if cfg.record_internal_nn_metrics:
      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
      self.sow(
          "intermediates",
          "activation_fraction_zero",
          jnp.sum(layer_output == 0) / jnp.size(layer_output),
      )

    if cfg.scan_layers:
      return layer_output, None
    else:
      return layer_output
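Structurally, Qwen3DecoderLayer is a standard pre-norm residual block, with the query scaling head_dim**-0.5 (about 0.0884 for head_dim: 128) applied inside Attention rather than baked into converted checkpoints. A minimal dataflow sketch, with stand-in attention_fn and mlp_fn callables and the learned RMSNorm gain, dropout, and sharding constraints omitted:

import jax
import jax.numpy as jnp

def rms_norm(x, eps=1e-6):
  # RMSNorm without the learned scale, for illustration only.
  return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)

def qwen3_block(x, attention_fn, mlp_fn):
  h = x + attention_fn(rms_norm(x))  # pre_self_attention_layer_norm + residual
  return h + mlp_fn(rms_norm(h))     # post_self_attention_layer_norm + residual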

MaxText/pyconfig.py

Lines changed: 6 additions & 0 deletions
@@ -301,6 +301,12 @@ def validate_model_name(s: str) -> bool:
       "gpt3-52k",
       "llama4-17b-16e",
       "llama4-17b-128e",
+      "qwen3-0.6b",
+      "qwen3-1.7b",
+      "qwen3-4b",
+      "qwen3-8b",
+      "qwen3-14b",
+      "qwen3-32b",
   )
   if s not in valid_model_names:
     raise ValueError(f"Invalid model name was passed. Got {s}, Valid options {valid_model_names}")
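With the registry updated, the validator accepts the six new names and still rejects anything unregistered. A small usage sketch, assuming the function is importable as shown (the import path follows the repo layout above):

from MaxText.pyconfig import validate_model_name

validate_model_name("qwen3-8b")   # passes: listed in valid_model_names
validate_model_name("qwen3-72b")  # raises ValueError: not a registered model name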
