Skip to content

Commit 0e0878a

Browse files
committed
Optimize constant casts to avoid unnecessary capabilities
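Add constant folding for widening casts of constants (e.g. u8 -> u32, f32 -> f64) to avoid creating intermediate types. Remove the capability checks performed during type creation; instead, the linker now runs two passes: one removes type capabilities (Int8, Int16, Int64, Float16, Float64) that no type in the module requires, and one reports an error for any type whose required capability is not declared. Fixes #300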
1 parent b19a858 commit 0e0878a

19 files changed

+418
-46
lines changed

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2033,6 +2033,31 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
20332033
if val.ty == dest_ty {
20342034
val
20352035
} else {
2036+
// If casting a constant, directly create a constant of the target type.
2037+
// This avoids creating intermediate types that might require additional
2038+
// capabilities. For example, casting a f32 constant to f64 will directly
2039+
// create a f64 constant, so the f32 type is not needed if it is not used
2040+
// elsewhere. (f16 sources are not folded here and take the regular path.)
2041+
if let Some(const_val) = self.builder.lookup_const_scalar(val) {
2042+
if let (SpirvType::Float(src_width), SpirvType::Float(dst_width)) =
2043+
(self.lookup_type(val.ty), self.lookup_type(dest_ty))
2044+
{
2045+
if src_width < dst_width {
2046+
// Convert the bit representation to the actual float value
2047+
let float_val = match src_width {
2048+
32 => Some(f32::from_bits(const_val as u32) as f64),
2049+
64 => Some(f64::from_bits(const_val as u64)),
2050+
_ => None,
2051+
};
2052+
2053+
if let Some(val) = float_val {
2054+
return self.constant_float(dest_ty, val);
2055+
}
2056+
}
2057+
}
2058+
}
2059+
2060+
// Regular conversion
20362061
self.emit()
20372062
.f_convert(dest_ty, None, val.def(self))
20382063
.unwrap()
@@ -2198,6 +2223,46 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
21982223
// I guess?
21992224
return val;
22002225
}
2226+
2227+
// If casting a constant, directly create a constant of the target type. This
2228+
// avoids creating intermediate types that might require additional
2229+
// capabilities. For example, casting a u8 constant to u32 will directly create
2230+
// a u32 constant, avoiding the need for Int8 capability if it is not used
2231+
// elsewhere.
2232+
if let Some(const_val) = self.builder.lookup_const_scalar(val) {
2233+
let src_ty = self.lookup_type(val.ty);
2234+
let dst_ty_spv = self.lookup_type(dest_ty);
2235+
2236+
// Try to optimize the constant cast
2237+
let optimized_result = match (src_ty, dst_ty_spv) {
2238+
// Integer to integer cast
2239+
(SpirvType::Integer(src_width, _), SpirvType::Integer(dst_width, _)) => {
2240+
// Only optimize if we're widening or keeping the same width.
2241+
// This avoids creating the source type when it's safe to do so.
2242+
// For narrowing casts (e.g., u32 as u8), we need the proper truncation
2243+
// behavior that the regular cast provides.
2244+
if src_width <= dst_width {
2245+
Some(self.constant_int(dest_ty, const_val))
2246+
} else {
2247+
None
2248+
}
2249+
}
2250+
// Bool to integer cast - const_val will be 0 or 1
2251+
(SpirvType::Bool, SpirvType::Integer(_, _)) => {
2252+
Some(self.constant_int(dest_ty, const_val))
2253+
}
2254+
// Integer to bool cast - compare with zero
2255+
(SpirvType::Integer(_, _), SpirvType::Bool) => {
2256+
Some(self.constant_bool(self.span(), const_val != 0))
2257+
}
2258+
_ => None,
2259+
};
2260+
2261+
if let Some(result) = optimized_result {
2262+
return result;
2263+
}
2264+
}
2265+
22012266
match (self.lookup_type(val.ty), self.lookup_type(dest_ty)) {
22022267
// sign change
22032268
(

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,16 @@ pub fn link(
477477
simple_passes::remove_non_uniform_decorations(sess, &mut output)?;
478478
}
479479

480+
{
481+
let _timer = sess.timer("link_remove_unused_type_capabilities");
482+
simple_passes::remove_unused_type_capabilities(&mut output);
483+
}
484+
485+
{
486+
let _timer = sess.timer("link_type_capability_check");
487+
simple_passes::check_type_capabilities(sess, &output)?;
488+
}
489+
480490
// NOTE(eddyb) SPIR-T pipeline is entirely limited to this block.
481491
{
482492
let (spv_words, module_or_err, lower_from_spv_timer) =

crates/rustc_codegen_spirv/src/linker/simple_passes.rs

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,25 @@ use rustc_session::Session;
77
use std::iter::once;
88
use std::mem::take;
99

10+
/// Returns the capability required for an integer type of the given width, if any.
11+
fn capability_for_int_width(width: u32) -> Option<rspirv::spirv::Capability> {
12+
match width {
13+
8 => Some(rspirv::spirv::Capability::Int8),
14+
16 => Some(rspirv::spirv::Capability::Int16),
15+
64 => Some(rspirv::spirv::Capability::Int64),
16+
_ => None,
17+
}
18+
}
19+
20+
/// Returns the capability required for a float type of the given width, if any.
21+
fn capability_for_float_width(width: u32) -> Option<rspirv::spirv::Capability> {
22+
match width {
23+
16 => Some(rspirv::spirv::Capability::Float16),
24+
64 => Some(rspirv::spirv::Capability::Float64),
25+
_ => None,
26+
}
27+
}
28+
1029
pub fn shift_ids(module: &mut Module, add: u32) {
1130
module.all_inst_iter_mut().for_each(|inst| {
1231
if let Some(ref mut result_id) = &mut inst.result_id {
@@ -266,6 +285,111 @@ pub fn check_fragment_insts(sess: &Session, module: &Module) -> Result<()> {
266285
}
267286
}
268287

288+
/// Check that types requiring specific capabilities have those capabilities declared.
289+
///
290+
/// This function validates that if a module uses types like u8/i8 (requiring Int8),
291+
/// u16/i16 (requiring Int16), etc., the corresponding capabilities are declared.
292+
pub fn check_type_capabilities(sess: &Session, module: &Module) -> Result<()> {
293+
use rspirv::spirv::Capability;
294+
295+
// Collect declared capabilities
296+
let declared_capabilities: FxHashSet<Capability> = module
297+
.capabilities
298+
.iter()
299+
.map(|inst| inst.operands[0].unwrap_capability())
300+
.collect();
301+
302+
let mut errors = Vec::new();
303+
304+
for inst in &module.types_global_values {
305+
match inst.class.opcode {
306+
Op::TypeInt => {
307+
let width = inst.operands[0].unwrap_literal_bit32();
308+
let signedness = inst.operands[1].unwrap_literal_bit32() != 0;
309+
let type_name = if signedness { "i" } else { "u" };
310+
311+
if let Some(required_cap) = capability_for_int_width(width) {
312+
if !declared_capabilities.contains(&required_cap) {
313+
errors.push(format!(
314+
"`{type_name}{width}` type used without `OpCapability {required_cap:?}`"
315+
));
316+
}
317+
}
318+
}
319+
Op::TypeFloat => {
320+
let width = inst.operands[0].unwrap_literal_bit32();
321+
322+
if let Some(required_cap) = capability_for_float_width(width) {
323+
if !declared_capabilities.contains(&required_cap) {
324+
errors.push(format!(
325+
"`f{width}` type used without `OpCapability {required_cap:?}`"
326+
));
327+
}
328+
}
329+
}
330+
_ => {}
331+
}
332+
}
333+
334+
if !errors.is_empty() {
335+
let mut err = sess
336+
.dcx()
337+
.struct_err("Missing required capabilities for types");
338+
for error in errors {
339+
err = err.with_note(error);
340+
}
341+
Err(err.emit())
342+
} else {
343+
Ok(())
344+
}
345+
}
346+
347+
/// Remove type-related capabilities that are not required by any types in the module.
348+
///
349+
/// This function specifically targets Int8, Int16, Int64, Float16, and Float64 capabilities,
350+
/// removing them if no types in the module require them. All other capabilities are preserved.
351+
/// This is part of the fix for issue #300 where constant casts were creating unnecessary types.
352+
pub fn remove_unused_type_capabilities(module: &mut Module) {
353+
use rspirv::spirv::Capability;
354+
355+
// Collect type-related capabilities that are actually needed
356+
let mut needed_type_capabilities = FxHashSet::default();
357+
358+
// Scan all types to determine which type-related capabilities are needed
359+
for inst in &module.types_global_values {
360+
match inst.class.opcode {
361+
Op::TypeInt => {
362+
let width = inst.operands[0].unwrap_literal_bit32();
363+
if let Some(cap) = capability_for_int_width(width) {
364+
needed_type_capabilities.insert(cap);
365+
}
366+
}
367+
Op::TypeFloat => {
368+
let width = inst.operands[0].unwrap_literal_bit32();
369+
if let Some(cap) = capability_for_float_width(width) {
370+
needed_type_capabilities.insert(cap);
371+
}
372+
}
373+
_ => {}
374+
}
375+
}
376+
377+
// Remove only type-related capabilities that aren't needed
378+
module.capabilities.retain(|inst| {
379+
let cap = inst.operands[0].unwrap_capability();
380+
match cap {
381+
// Only remove these type-related capabilities if they're not used
382+
Capability::Int8
383+
| Capability::Int16
384+
| Capability::Int64
385+
| Capability::Float16
386+
| Capability::Float64 => needed_type_capabilities.contains(&cap),
387+
// Keep all other capabilities
388+
_ => true,
389+
}
390+
});
391+
}
392+
269393
/// Remove all [`Decoration::NonUniform`] if this module does *not* have [`Capability::ShaderNonUniform`].
270394
/// This allows image asm to always declare `NonUniform` and not worry about conditional compilation.
271395
pub fn remove_non_uniform_decorations(_sess: &Session, module: &mut Module) -> Result<()> {

crates/rustc_codegen_spirv/src/spirv_type.rs

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::builder_spirv::SpirvValue;
33
use crate::codegen_cx::CodegenCx;
44
use indexmap::IndexSet;
55
use rspirv::dr::Operand;
6-
use rspirv::spirv::{Capability, Decoration, Dim, ImageFormat, StorageClass, Word};
6+
use rspirv::spirv::{Decoration, Dim, ImageFormat, StorageClass, Word};
77
use rustc_data_structures::fx::FxHashMap;
88
use rustc_middle::span_bug;
99
use rustc_span::def_id::DefId;
@@ -105,21 +105,6 @@ impl SpirvType<'_> {
105105
let result = cx.emit_global().type_int_id(id, width, signedness as u32);
106106
let u_or_i = if signedness { "i" } else { "u" };
107107
match width {
108-
8 if !cx.builder.has_capability(Capability::Int8) => cx.zombie_with_span(
109-
result,
110-
def_span,
111-
&format!("`{u_or_i}8` without `OpCapability Int8`"),
112-
),
113-
16 if !cx.builder.has_capability(Capability::Int16) => cx.zombie_with_span(
114-
result,
115-
def_span,
116-
&format!("`{u_or_i}16` without `OpCapability Int16`"),
117-
),
118-
64 if !cx.builder.has_capability(Capability::Int64) => cx.zombie_with_span(
119-
result,
120-
def_span,
121-
&format!("`{u_or_i}64` without `OpCapability Int64`"),
122-
),
123108
8 | 16 | 32 | 64 => {}
124109
w => cx.zombie_with_span(
125110
result,
@@ -132,16 +117,6 @@ impl SpirvType<'_> {
132117
Self::Float(width) => {
133118
let result = cx.emit_global().type_float_id(id, width);
134119
match width {
135-
16 if !cx.builder.has_capability(Capability::Float16) => cx.zombie_with_span(
136-
result,
137-
def_span,
138-
"`f16` without `OpCapability Float16`",
139-
),
140-
64 if !cx.builder.has_capability(Capability::Float64) => cx.zombie_with_span(
141-
result,
142-
def_span,
143-
"`f64` without `OpCapability Float64`",
144-
),
145120
16 | 32 | 64 => (),
146121
other => cx.zombie_with_span(
147122
result,
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Test that constant float widening casts are optimized to avoid creating
2+
// the smaller float type when not needed elsewhere.
3+
4+
// build-pass
5+
// compile-flags: -C llvm-args=--disassemble-globals
6+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
7+
// normalize-stderr-test "OpSource .*\n" -> ""
8+
// normalize-stderr-test "OpExtension .SPV_KHR_vulkan_memory_model.\n" -> ""
9+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
10+
11+
use spirv_std::spirv;
12+
13+
// Fragment entry point: stores the f32 constant `SMALL`, widened to f64, into `output`.
#[spirv(fragment)]
14+
pub fn main(output: &mut f64) {
15+
// Widening cast of a constant: this is expected to be folded straight into an f64 constant, so the f32 type should not need to be emitted.
16+
const SMALL: f32 = 20.5;
17+
let widened = SMALL as f64;
18+
*output = widened;
19+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
OpCapability Shader
2+
OpCapability Float64
3+
OpCapability ShaderClockKHR
4+
OpExtension "SPV_KHR_shader_clock"
5+
OpMemoryModel Logical Simple
6+
OpEntryPoint Fragment %1 "main" %2
7+
OpExecutionMode %1 OriginUpperLeft
8+
%3 = OpString "$OPSTRING_FILENAME/const-float-cast-optimized.rs"
9+
OpName %2 "output"
10+
OpDecorate %2 Location 0
11+
%4 = OpTypeFloat 64
12+
%5 = OpTypePointer Output %4
13+
%6 = OpTypeVoid
14+
%7 = OpTypeFunction %6
15+
%2 = OpVariable %5 Output
16+
%8 = OpConstant %4 4626463454704697344
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Test whether float constant casts need optimization
2+
3+
// build-pass
4+
// compile-flags: -C llvm-args=--disassemble-globals
5+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
6+
// normalize-stderr-test "OpSource .*\n" -> ""
7+
// normalize-stderr-test "OpExtension .SPV_KHR_vulkan_memory_model.\n" -> ""
8+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
9+
10+
use spirv_std::spirv;
11+
12+
// Fragment entry point: accumulates several constant casts (narrowing, widening, int-to-float) into the f32 `output`.
#[spirv(fragment)]
13+
pub fn main(output: &mut f32) {
14+
// Narrowing cast of a constant (f64 -> f32).
15+
const BIG: f64 = 123.456;
16+
let narrowed = BIG as f32;
17+
*output = narrowed;
18+
19+
// Widening cast of a constant (f32 -> f64), then narrowed back so it can be added to the f32 output.
20+
const SMALL: f32 = 20.5;
21+
let widened = SMALL as f64;
22+
*output += widened as f32;
23+
24+
// `SMALL` is also used here at its native f32 width, with no cast involved.
let kept: f32 = 1.0 + SMALL;
25+
*output += kept;
26+
27+
// Integer constant converted to float (u32 -> f32).
28+
const INT: u32 = 42;
29+
let as_float = INT as f32;
30+
*output += as_float;
31+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
OpCapability Shader
2+
OpCapability Float64
3+
OpCapability ShaderClockKHR
4+
OpExtension "SPV_KHR_shader_clock"
5+
OpMemoryModel Logical Simple
6+
OpEntryPoint Fragment %1 "main" %2
7+
OpExecutionMode %1 OriginUpperLeft
8+
%3 = OpString "$OPSTRING_FILENAME/const-float-cast.rs"
9+
OpName %2 "output"
10+
OpDecorate %2 Location 0
11+
%4 = OpTypeFloat 32
12+
%5 = OpTypePointer Output %4
13+
%6 = OpTypeVoid
14+
%7 = OpTypeFunction %6
15+
%8 = OpTypeFloat 64
16+
%9 = OpConstant %8 4638387860618067575
17+
%2 = OpVariable %5 Output
18+
%10 = OpConstant %8 4626463454704697344
19+
%11 = OpConstant %4 1065353216
20+
%12 = OpConstant %4 1101266944
21+
%13 = OpTypeInt 32 0
22+
%14 = OpConstant %13 42
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Test that constant integer casts are optimized to avoid creating intermediate types
2+
// that would require additional capabilities (e.g., Int8 capability for u8).
3+
4+
// build-pass
5+
// compile-flags: -C llvm-args=--disassemble-globals
6+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
7+
// normalize-stderr-test "OpSource .*\n" -> ""
8+
// normalize-stderr-test "OpExtension .SPV_KHR_vulkan_memory_model.\n" -> ""
9+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
10+
11+
use spirv_std::spirv;
12+
13+
const K: u8 = 20;
14+
15+
// Fragment entry point: multiplies a u32 by the u8 constant `K` (cast to u32) and stores the product in `output`.
#[spirv(fragment)]
16+
pub fn main(output: &mut u32) {
17+
let position = 2u32;
18+
// `K as u32` is a widening cast of a constant: it should be folded directly into a u32 constant with value 20,
19+
// so that no u8 type (which would require the Int8 capability) has to be created.
20+
let global_y_offset_bits = position * K as u32;
21+
*output = global_y_offset_bits;
22+
}

0 commit comments

Comments
 (0)