Commit 53667f6

Fix FlashQwen3 (#650)
1 parent f7aa35b · commit 53667f6

File tree

3 files changed: +4028 −4025 lines


backends/candle/src/models/flash_qwen3.rs

Lines changed: 9 additions & 6 deletions
@@ -27,7 +27,7 @@ struct Qwen3Attention {
 impl Qwen3Attention {
     pub fn load(vb: VarBuilder, config: &Qwen3Config) -> Result<Self> {
         if config.use_sliding_window {
-            candle::bail!("Sliding window is not supported for Qwen3",);
+            candle::bail!("Sliding window is not supported for Qwen3");
         }
 
         let num_attention_heads = config.num_attention_heads;
@@ -143,8 +143,8 @@ impl Qwen3Attention {
         )?;
 
         // Apply normalization layers
-        let (q, _res) = self.q_norm.forward(&q, None)?;
-        let (k, _res) = self.k_norm.forward(&k, None)?;
+        let (q, _) = self.q_norm.forward(&q, None)?;
+        let (k, _) = self.k_norm.forward(&k, None)?;
 
         apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
 
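Context for the hunk above: Qwen3 normalizes the query and key heads (q_norm, k_norm) before rotary embeddings are applied; the edit only replaces the unused residual bindings with _. Assuming these norms are RMSNorm, a minimal plain-slice sketch of what each head goes through (hypothetical helper, not this crate's API):

// Hypothetical illustration of RMSNorm over a single head vector.
// Assumes q_norm/k_norm are RMSNorm; names and types are made up here.
fn rms_norm(head: &mut [f32], weight: &[f32], eps: f32) {
    let mean_sq = head.iter().map(|v| v * v).sum::<f32>() / head.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    for (x, w) in head.iter_mut().zip(weight) {
        // scale each component by the inverse RMS, then the learned weight
        *x *= inv_rms * w;
    }
}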

@@ -158,7 +158,7 @@ impl Qwen3Attention {
             max_s,
             max_s,
             self.softmax_scale,
-            false,
+            true,
             None,
             None,
         )?;
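This hunk is the substantive fix of the commit: judging by its position after self.softmax_scale, the flipped boolean is presumably the causal flag of the flash-attention call, and Qwen3 is a decoder-only model, so its attention mask must be causal. A toy sketch of what such a flag typically gates (not this crate's kernel; shapes and names are assumptions):

// Hypothetical sketch of how a `causal` flag gates attention scores.
fn mask_scores(scores: &mut Vec<Vec<f32>>, causal: bool) {
    if !causal {
        return; // bidirectional: every position may attend to every position
    }
    for i in 0..scores.len() {
        for j in (i + 1)..scores[i].len() {
            // causal: query i must not attend to keys at later positions j > i
            scores[i][j] = f32::NEG_INFINITY;
        }
    }
}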
@@ -215,8 +215,8 @@ impl Qwen3MLP {
         let up_states = gate_up_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
 
         let gate_states = self.act.forward(&gate_states)?;
-        let r = self.down_proj.forward(&(gate_states * up_states)?);
-        r
+
+        self.down_proj.forward(&(gate_states * up_states)?)
     }
 }
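The MLP hunk above changes no behavior: the fused gate_up projection is split with narrow earlier in the function, the gate half passes through the activation, and the elementwise product feeds down_proj; the diff merely returns the expression directly instead of binding a temporary r. A plain-slice sketch of that gated dataflow (hypothetical helper, not the crate's tensor API):

// Hypothetical sketch of the gated-MLP dataflow on plain slices.
// `gate_up` holds the gate and up halves concatenated feature-wise.
fn gated_mlp(gate_up: &[f32], intermediate_size: usize, act: impl Fn(f32) -> f32) -> Vec<f32> {
    let (gate, up) = gate_up.split_at(intermediate_size);
    gate.iter()
        .zip(up)
        // activation on the gate half, then elementwise product with up
        .map(|(&g, &u)| act(g) * u)
        .collect()
    // the real code then applies self.down_proj to this product
}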

@@ -266,12 +266,15 @@ impl Qwen3Layer {
         let _enter = self.span.enter();
 
         let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, residual)?;
+
         let attn_output =
             self.attention
                 .forward(&normed_hidden_states, cu_seqlens, cos, sin, max_s)?;
+
         let (normed_attn_res_output, attn_res) = self
             .post_attention_layer_norm
             .forward(&attn_output, Some(&res))?;
+
         let mlp_output = self.mlp.forward(&normed_attn_res_output)?;
 
         Ok((mlp_output, attn_res))
