@@ -2,7 +2,7 @@ use crate::layers::{
     apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, Linear, RMSNorm,
 };
 use crate::models::Model;
-use candle::{Device, IndexOp, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
@@ -382,10 +382,12 @@ pub struct Qwen3Model {
     rotary_cache: (Tensor, Tensor),
     rotary_dim: usize,
     pool: Pool,
-    pub device: Device,
     num_attention_heads: usize,
     pad_token_id: u32,

+    dtype: DType,
+    device: Device,
+
     span: tracing::Span,
 }

@@ -435,30 +437,30 @@ impl Qwen3Model {
             rotary_dim,
             pool,
             pad_token_id: config.eos_token_id as u32,
-            device: vb.device().clone(),
             num_attention_heads: config.num_attention_heads,
+            dtype: vb.dtype(),
+            device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
     }

     fn get_causal_attention_bias(&self, attention_bias: Tensor) -> Result<Tensor> {
         let (bs, dim, seq_len, _) = attention_bias.dims4()?;

-        let device = attention_bias.device();
-
         let mask: Vec<u8> = (0..seq_len)
             .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
             .collect();

         let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?;
         let causal_mask = causal_mask.expand(&[bs, dim, seq_len, seq_len])?;

-        let negatives = Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?;
-        let zeros = Tensor::zeros_like(&attention_bias)?.to_device(&Device::Cpu)?;
+        let negatives =
+            Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?.to_dtype(self.dtype)?;
+        let zeros = Tensor::zeros_like(&attention_bias)?.to_dtype(self.dtype)?;

         let causal_mask = causal_mask
             .where_cond(&negatives, &zeros)?
-            .to_device(device)?;
+            .to_device(&self.device)?;

         attention_bias.broadcast_add(&causal_mask)
     }
@@ -494,7 +496,7 @@ impl Qwen3Model {
                 for _ in 0..padding {
                     input_ids.push(self.pad_token_id);
                     position_ids.push(0);
-                    attention_bias.push(f32::MIN);
+                    attention_bias.push(f32::NEG_INFINITY);
                 }
             }

@@ -539,7 +541,7 @@ impl Qwen3Model {
             // Create attention bias for causal masking even for single sequences
             let attention_bias = Tensor::zeros(
                 (1, self.num_attention_heads, seq_len, seq_len),
-                candle::DType::F32,
+                self.dtype,
                 &self.device,
             )?;

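For reference, a minimal standalone sketch of the dtype-aware causal bias construction that this patch introduces. It is a hypothetical helper, not part of the patch: the causal_bias name and the free-standing signature are assumptions, but the candle calls mirror the ones in get_causal_attention_bias above, building the mask on the CPU, converting the additive bias to the model dtype so the later broadcast_add against attention scores stays dtype-consistent (e.g. when the model runs in F16/BF16), and only then moving it to the model device.

use candle::{DType, Device, Result, Tensor};

// Hypothetical free-standing helper mirroring the patched logic above.
fn causal_bias(seq_len: usize, dtype: DType, device: &Device) -> Result<Tensor> {
    // 1 where position j lies in the future of position i, 0 elsewhere.
    let mask: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
        .collect();
    let mask = Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?;

    // Build the additive bias in the target dtype so it matches the attention
    // scores it will later be added to.
    let negatives = Tensor::full(f32::MIN, (seq_len, seq_len), &Device::Cpu)?.to_dtype(dtype)?;
    let zeros = Tensor::zeros((seq_len, seq_len), dtype, &Device::Cpu)?;

    // Select the large negative value at masked (future) positions, zero elsewhere,
    // then move the finished bias to the model device.
    mask.where_cond(&negatives, &zeros)?.to_device(device)
}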