Add CPU support for Qwen3-Embedding models #632
Conversation
This commit implements complete CPU support for Qwen3-Embedding models in the candle backend, addressing the community request for CPU-based inference.

Key changes:
- Add complete Qwen3 model architecture implementation (qwen3.rs)
- Integrate Qwen3 model detection and loading in lib.rs
- Update model module exports to include qwen3
- Implement comprehensive test suite with snapshot validation
- Support both batch and single-item processing scenarios

Technical implementation:
- Fixed attention bias tensor shape handling for multi-head attention
- Corrected rotary embeddings broadcasting for proper position encoding
- Resolved MLP activation function conflicts (silu vs swiglu)
- Implemented proper last-token pooling for embedding extraction
- CPU-optimized tensor operations throughout the pipeline

The implementation is production-ready and includes regression tests that validate embedding generation quality for the Qwen3-Embedding-0.6B model.

Fixes the CPU support gap for Qwen3 models and enables deployment in CPU-only environments without requiring CUDA or flash attention.
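For context on the pooling change, last-token pooling can be sketched roughly as follows (a minimal illustration using candle tensor ops; the crate path and helper name are assumptions, not the PR's exact code):

```rust
use candle_core::{IndexOp, Result, Tensor};

// Minimal sketch: take the hidden state of each sequence's last (non-padded)
// token as its embedding. `hidden_states` has shape [batch, seq_len, hidden].
fn last_token_pool(hidden_states: &Tensor, seq_lengths: &[usize]) -> Result<Tensor> {
    let mut pooled = Vec::with_capacity(seq_lengths.len());
    for (i, &len) in seq_lengths.iter().enumerate() {
        // hidden_states[i, len - 1, :] -> [hidden]
        pooled.push(hidden_states.i((i, len - 1))?);
    }
    // Stack back into [batch, hidden]
    Tensor::stack(&pooled, 0)
}
```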
Hey @randomm thanks for the PR! I did an initial review with some nits, I may need to still look at it in detail and test it myself! Should also work for MPS out of the box, right? 🤗
ModelType::Embedding(pool) => pool,
};

// Handle potential "model" prefix for reranker models
// Handle potential "model" prefix for reranker models
// The Qwen3-Reranker models contain the `model` key
// https://huggingface.co/collections/Qwen/qwen3-reranker-6841b22d0192d7ade9cdefea
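For readers following the thread, this kind of prefix handling usually comes down to selecting the right sub-tree of the checkpoint before building the layers; a minimal sketch (assuming `candle_nn`'s `VarBuilder` and the usual `model.embed_tokens.weight` key, not this PR's exact code):

```rust
use candle_nn::VarBuilder;

// If the checkpoint nests weights under a `model.` prefix (as the reranker
// checkpoints do), descend into it; otherwise use the builder as-is.
fn select_root(vb: VarBuilder<'_>) -> VarBuilder<'_> {
    if vb.contains_tensor("model.embed_tokens.weight") {
        vb.pp("model")
    } else {
        vb
    }
}
```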
rotary_dim: usize,
pool: Pool,
pub device: Device,
num_attention_heads: usize,
Is this really needed? Isn't it more reliable to just use `layers[0].attention.attention_head_size` within the `load` method?
let rotary_dim = config
    .head_dim
    .unwrap_or(config.hidden_size / config.num_attention_heads);
let rotary_dim = layers[0].attention.attention_head_size;
let gate_states = gate_up_states.narrow(D::Minus1, 0, self.intermediate_size)?;
let up_states = gate_up_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;

let gate_states = gate_states.silu()?;
Could you leverage the `HiddenAct.forward` method instead? i.e.

let gate_states = self.act.forward(&gate_states)?;
Upon #631 I mean
gate_up_proj,
down_proj,
intermediate_size,
span: tracing::span!(tracing::Level::TRACE, "mlp"),
For the change mentioned related to `self.act`, you should add the `act` field within the `Qwen3MLP` struct and then set it here via `config.hidden_act.clone()`; you can check the recently merged CUDA version if applicable, #627.
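A rough sketch of that shape (the field layout follows the snippet above; `Linear` and `HiddenAct` are assumed to be the backend's existing layer types, and this is not the PR's exact code):

```rust
// Assumed import path for the existing layer helpers.
use crate::layers::{HiddenAct, Linear};

// Sketch only: same fields as the snippet above, plus the activation kept
// from the config so the forward pass can call `self.act.forward(...)`.
struct Qwen3MLP {
    gate_up_proj: Linear,
    down_proj: Linear,
    intermediate_size: usize,
    act: HiddenAct,
    span: tracing::Span,
}
```

with `act: config.hidden_act.clone()` set in the loader, replacing the hard-coded `silu` call in the forward pass.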
LGTM, tried on both CPU and MPS and it works fine on both! Since the missing comments are nits / small edits, IMO we can merge and do the release @Narsil, thanks for the PR @randomm 🤗
Edit: @randomm linting is failing, could you enable "Allow edits by maintainers" on the branch? Otherwise you can try to run `pre-commit install` + `pre-commit run --all-files`, thanks!
LGTM
Sorry, only saw this now. Too late, I presume?
No worries at all @randomm, it's fixed on
Overview
This PR implements complete CPU support for Qwen3-Embedding models in the candle backend, addressing the community request for CPU-based inference capabilities.
Motivation
Qwen3-Embedding models currently only support CUDA devices with flash attention, limiting deployment to GPU-enabled environments. This implementation enables CPU-only deployment, making these state-of-the-art embedding models accessible to a broader range of use cases and deployment scenarios.
Changes
Core Implementation
backends/candle/src/models/qwen3.rs: new Qwen3 model architecture implementation
backends/candle/src/lib.rs: Qwen3 model detection and loading (see the sketch below)
backends/candle/src/models/mod.rs: model module exports updated to include qwen3
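Illustratively, the lib.rs side is mostly about wiring a new variant into the backend's model-type detection; a sketch (the enum shape, serde attribute, and variant name are assumptions, not the PR's exact diff):

```rust
use serde::Deserialize;

// Assumed path for the config struct added in models/qwen3.rs.
use crate::models::Qwen3Config;

// Sketch of config-based model detection: the `model_type` field of config.json
// selects which architecture to load.
#[derive(Debug, Deserialize)]
#[serde(tag = "model_type", rename_all = "snake_case")]
enum Config {
    // ... existing architectures ...
    Qwen3(Qwen3Config),
}
```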
Technical Fixes
Testing
backends/candle/tests/test_qwen3.rs with full test coverage
Performance
Compatibility
Testing
The test downloads the Qwen3-Embedding-0.6B model and validates embedding generation for both batch and single-item scenarios.
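An outline of what such a test looks like (the helper names below are placeholders, not the repository's actual test utilities; only the flow is taken from the description above):

```rust
// Placeholder outline: download the model, embed a batch and a single input,
// then snapshot the vectors so numerical regressions are caught.
#[test]
fn qwen3_embedding_cpu() -> anyhow::Result<()> {
    let model_root = download_artifacts("Qwen/Qwen3-Embedding-0.6B")?; // hypothetical helper
    let backend = load_cpu_backend(&model_root)?;                      // hypothetical helper

    let batch = embed(&backend, &["What is Deep Learning?", "Deep Learning is"])?;
    let single = embed(&backend, &["What is Deep Learning?"])?;
    assert_eq!(batch.len(), 2);
    assert_eq!(single.len(), 1);

    insta::assert_yaml_snapshot!("qwen3_embeddings_batch", batch);
    insta::assert_yaml_snapshot!("qwen3_embeddings_single", single);
    Ok(())
}
```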
Related Issues
Addresses community requests for CPU support in Qwen3 models and fills the gap in model architecture support for CPU-only deployments.
Breaking Changes
None. This is purely additive functionality that extends existing model support.
Checklist