
Commit 26fe7d7

Narsil and alvarobartt authored
Make flake work on metal (#654)
Co-authored-by: Alvaro Bartolome <[email protected]>
1 parent 53667f6 commit 26fe7d7

File tree

3 files changed: +52 −308 lines

backends/candle/src/models/qwen3.rs

Lines changed: 12 additions & 10 deletions
@@ -2,7 +2,7 @@ use crate::layers::{
     apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, Linear, RMSNorm,
 };
 use crate::models::Model;
-use candle::{Device, IndexOp, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
@@ -382,10 +382,12 @@ pub struct Qwen3Model {
     rotary_cache: (Tensor, Tensor),
     rotary_dim: usize,
     pool: Pool,
-    pub device: Device,
     num_attention_heads: usize,
     pad_token_id: u32,
 
+    dtype: DType,
+    device: Device,
+
     span: tracing::Span,
 }
 
@@ -435,30 +437,30 @@ impl Qwen3Model {
             rotary_dim,
             pool,
             pad_token_id: config.eos_token_id as u32,
-            device: vb.device().clone(),
             num_attention_heads: config.num_attention_heads,
+            dtype: vb.dtype(),
+            device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
     }
 
     fn get_causal_attention_bias(&self, attention_bias: Tensor) -> Result<Tensor> {
         let (bs, dim, seq_len, _) = attention_bias.dims4()?;
 
-        let device = attention_bias.device();
-
         let mask: Vec<u8> = (0..seq_len)
             .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
             .collect();
 
         let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?;
         let causal_mask = causal_mask.expand(&[bs, dim, seq_len, seq_len])?;
 
-        let negatives = Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?;
-        let zeros = Tensor::zeros_like(&attention_bias)?.to_device(&Device::Cpu)?;
+        let negatives =
+            Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?.to_dtype(self.dtype)?;
+        let zeros = Tensor::zeros_like(&attention_bias)?.to_dtype(self.dtype)?;
 
         let causal_mask = causal_mask
             .where_cond(&negatives, &zeros)?
-            .to_device(device)?;
+            .to_device(&self.device)?;
 
         attention_bias.broadcast_add(&causal_mask)
     }
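
For readability, the mask construction in get_causal_attention_bias can be restated as a standalone sketch. This is not code from the commit: the free function causal_bias, its parameters, and the CPU-allocated zeros tensor are framing choices of mine, while the tensor calls themselves (from_slice, expand, full, to_dtype, where_cond, to_device, broadcast_add) are the ones used in the diff. The apparent point of the change is that both value branches are converted to the stored model dtype before the finished bias is moved to the stored device, so half-precision runs (for example on Metal) do not mix an F32 mask with F16/BF16 scores.

use candle::{DType, Device, Result, Tensor};

// Sketch only: `causal_bias` is not a function in the repository.
// `attention_bias` is the (bs, heads, seq_len, seq_len) additive bias the caller
// already built, assumed to live at `dtype` on `device`; `dtype`/`device` stand
// in for the model's stored fields.
fn causal_bias(attention_bias: &Tensor, dtype: DType, device: &Device) -> Result<Tensor> {
    let (bs, dim, seq_len, _) = attention_bias.dims4()?;

    // 1 wherever position j lies in the future of position i, 0 elsewhere.
    let mask: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
        .collect();
    let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?
        .expand(&[bs, dim, seq_len, seq_len])?;

    // Both branches of `where_cond` are created at the model dtype, so the
    // resulting bias can be added to F16/BF16 scores without a dtype mismatch.
    let negatives =
        Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?.to_dtype(dtype)?;
    let zeros = Tensor::zeros(attention_bias.shape(), dtype, &Device::Cpu)?;

    // Only the finished bias is shipped to the model device.
    let bias = causal_mask.where_cond(&negatives, &zeros)?.to_device(device)?;
    attention_bias.broadcast_add(&bias)
}

Building the comparison mask on the CPU and only moving the finished bias to the device is unchanged from the previous code; the new parts are the to_dtype(self.dtype) conversions and the use of the stored self.device handle.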
@@ -494,7 +496,7 @@ impl Qwen3Model {
             for _ in 0..padding {
                 input_ids.push(self.pad_token_id);
                 position_ids.push(0);
-                attention_bias.push(f32::MIN);
+                attention_bias.push(f32::NEG_INFINITY);
             }
         }
 
@@ -539,7 +541,7 @@ impl Qwen3Model {
         // Create attention bias for causal masking even for single sequences
         let attention_bias = Tensor::zeros(
             (1, self.num_attention_heads, seq_len, seq_len),
-            candle::DType::F32,
+            self.dtype,
             &self.device,
         )?;
 
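
The remaining hunks pair with this: padded positions now push f32::NEG_INFINITY instead of the finite sentinel f32::MIN, and the single-sequence zero bias is allocated at self.dtype rather than a hard-coded candle::DType::F32. A compact way to picture the stored dtype/device pattern is the sketch below; MaskFactory, from_varbuilder, and zero_bias are hypothetical names of mine, while vb.dtype(), vb.device(), and Tensor::zeros are exactly the calls that appear in the diff.

// Assumes the `candle` / `candle_nn` crates under the names imported in the diff above.
use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;

// Hypothetical holder for the load-time dtype/device, mirroring the new model fields.
struct MaskFactory {
    dtype: DType,
    device: Device,
}

impl MaskFactory {
    fn from_varbuilder(vb: &VarBuilder) -> Self {
        Self {
            // e.g. F16/BF16 when the checkpoint is half precision (common on Metal).
            dtype: vb.dtype(),
            device: vb.device().clone(),
        }
    }

    // Zero additive bias for a single sequence, allocated at the model dtype
    // instead of a hard-coded DType::F32.
    fn zero_bias(&self, num_heads: usize, seq_len: usize) -> Result<Tensor> {
        Tensor::zeros((1, num_heads, seq_len, seq_len), self.dtype, &self.device)
    }
}

f32::NEG_INFINITY also converts to a proper negative infinity in F16/BF16, while f32::MIN is a finite f32 value, which is presumably why the padding sentinel was swapped.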

flake.lock

Lines changed: 5 additions & 76 deletions
Generated file; diff not rendered.
