Skip to content

Impl. const. eval. for "first bit" numeric built-ins #5101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ Bottom level categories:

## Unreleased

### New Features

#### Naga

* Support constant evaluation for `firstLeadingBit` and `firstTrailingBit` numeric built-ins in WGSL. Front-ends that translate to these built-ins also benefit from constant evaluation. By @ErichDonGubler in [#5101](https://github.com/gfx-rs/wgpu/pull/5101).

### Bug Fixes

#### General
Expand Down
10 changes: 6 additions & 4 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3647,8 +3647,8 @@ impl<'a, W: Write> Writer<'a, W> {

return Ok(());
}
Mf::FindLsb => "findLSB",
Mf::FindMsb => "findMSB",
Mf::FirstTrailingBit => "findLSB",
Mf::FirstLeadingBit => "findMSB",
// data packing
Mf::Pack4x8snorm => "packSnorm4x8",
Mf::Pack4x8unorm => "packUnorm4x8",
Expand Down Expand Up @@ -3722,8 +3722,10 @@ impl<'a, W: Write> Writer<'a, W> {

// Some GLSL functions always return signed integers (like findMSB),
// so they need to be cast to uint if the argument is also an uint.
let ret_might_need_int_to_uint =
matches!(fun, Mf::FindLsb | Mf::FindMsb | Mf::CountOneBits | Mf::Abs);
let ret_might_need_int_to_uint = matches!(
fun,
Mf::FirstTrailingBit | Mf::FirstLeadingBit | Mf::CountOneBits | Mf::Abs
);

// Some GLSL functions only accept signed integers (like abs),
// so they need their argument cast from uint to int.
Expand Down
4 changes: 2 additions & 2 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3063,8 +3063,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::CountLeadingZeros => Function::CountLeadingZeros,
Mf::CountOneBits => Function::MissingIntOverload("countbits"),
Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
Mf::FindLsb => Function::MissingIntReturnType("firstbitlow"),
Mf::FindMsb => Function::MissingIntReturnType("firstbithigh"),
Mf::FirstTrailingBit => Function::MissingIntReturnType("firstbitlow"),
Mf::FirstLeadingBit => Function::MissingIntReturnType("firstbithigh"),
Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
// Data Packing
Expand Down
10 changes: 5 additions & 5 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1875,8 +1875,8 @@ impl<W: Write> Writer<W> {
Mf::ReverseBits => "reverse_bits",
Mf::ExtractBits => "",
Mf::InsertBits => "",
Mf::FindLsb => "",
Mf::FindMsb => "",
Mf::FirstTrailingBit => "",
Mf::FirstLeadingBit => "",
// data packing
Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
Expand Down Expand Up @@ -1920,15 +1920,15 @@ impl<W: Write> Writer<W> {
self.put_expression(arg1.unwrap(), context, false)?;
write!(self.out, ")")?;
}
Mf::FindLsb => {
Mf::FirstTrailingBit => {
let scalar = context.resolve_type(arg).scalar().unwrap();
let constant = scalar.width * 8 + 1;

write!(self.out, "((({NAMESPACE}::ctz(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ") + 1) % {constant}) - 1)")?;
}
Mf::FindMsb => {
Mf::FirstLeadingBit => {
let inner = context.resolve_type(arg);
let scalar = inner.scalar().unwrap();
let constant = scalar.width * 8 - 1;
Expand Down Expand Up @@ -2702,7 +2702,7 @@ impl<W: Write> Writer<W> {
}
}
}
crate::MathFunction::FindMsb
crate::MathFunction::FirstLeadingBit
| crate::MathFunction::Pack4xI8
| crate::MathFunction::Pack4xU8
| crate::MathFunction::Unpack4xI8
Expand Down
6 changes: 3 additions & 3 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1183,13 +1183,13 @@ impl<'w> BlockContext<'w> {
count_id,
))
}
Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb),
Mf::FindMsb => {
Mf::FirstTrailingBit => MathOp::Ext(spirv::GLOp::FindILsb),
Mf::FirstLeadingBit => {
if arg_ty.scalar_width() == Some(4) {
let thing = match arg_scalar_kind {
Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
other => unimplemented!("Unexpected findMSB({:?})", other),
other => unimplemented!("Unexpected firstLeadingBit({:?})", other),
};
MathOp::Ext(thing)
} else {
Expand Down
4 changes: 2 additions & 2 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1710,8 +1710,8 @@ impl<W: Write> Writer<W> {
Mf::ReverseBits => Function::Regular("reverseBits"),
Mf::ExtractBits => Function::Regular("extractBits"),
Mf::InsertBits => Function::Regular("insertBits"),
Mf::FindLsb => Function::Regular("firstTrailingBit"),
Mf::FindMsb => Function::Regular("firstLeadingBit"),
Mf::FirstTrailingBit => Function::Regular("firstTrailingBit"),
Mf::FirstLeadingBit => Function::Regular("firstLeadingBit"),
// data packing
Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"),
Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"),
Expand Down
16 changes: 10 additions & 6 deletions naga/src/front/glsl/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,8 @@ fn inject_standard_builtins(
"bitfieldReverse" => MathFunction::ReverseBits,
"bitfieldExtract" => MathFunction::ExtractBits,
"bitfieldInsert" => MathFunction::InsertBits,
"findLSB" => MathFunction::FindLsb,
"findMSB" => MathFunction::FindMsb,
"findLSB" => MathFunction::FirstTrailingBit,
"findMSB" => MathFunction::FirstLeadingBit,
_ => unreachable!(),
};

Expand Down Expand Up @@ -695,8 +695,12 @@ fn inject_standard_builtins(
// we need to cast the return type of findLsb / findMsb
let mc = if scalar.kind == Sk::Uint {
match mc {
MacroCall::MathFunction(MathFunction::FindLsb) => MacroCall::FindLsbUint,
MacroCall::MathFunction(MathFunction::FindMsb) => MacroCall::FindMsbUint,
MacroCall::MathFunction(MathFunction::FirstTrailingBit) => {
MacroCall::FindLsbUint
}
MacroCall::MathFunction(MathFunction::FirstLeadingBit) => {
MacroCall::FindMsbUint
}
mc => mc,
}
} else {
Expand Down Expand Up @@ -1787,8 +1791,8 @@ impl MacroCall {
)?,
mc @ (MacroCall::FindLsbUint | MacroCall::FindMsbUint) => {
let fun = match mc {
MacroCall::FindLsbUint => MathFunction::FindLsb,
MacroCall::FindMsbUint => MathFunction::FindMsb,
MacroCall::FindLsbUint => MathFunction::FirstTrailingBit,
MacroCall::FindMsbUint => MathFunction::FirstLeadingBit,
_ => unreachable!(),
};
let res = ctx.add_expression(
Expand Down
4 changes: 2 additions & 2 deletions naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3026,8 +3026,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
Glo::UnpackHalf2x16 => Mf::Unpack2x16float,
Glo::UnpackUnorm2x16 => Mf::Unpack2x16unorm,
Glo::UnpackSnorm2x16 => Mf::Unpack2x16snorm,
Glo::FindILsb => Mf::FindLsb,
Glo::FindUMsb | Glo::FindSMsb => Mf::FindMsb,
Glo::FindILsb => Mf::FirstTrailingBit,
Glo::FindUMsb | Glo::FindSMsb => Mf::FirstLeadingBit,
// TODO: https://github.com/gfx-rs/naga/issues/2526
Glo::Modf | Glo::Frexp => return Err(Error::UnsupportedExtInst(inst_id)),
Glo::IMix
Expand Down
4 changes: 2 additions & 2 deletions naga/src/front/wgsl/parse/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
"reverseBits" => Mf::ReverseBits,
"extractBits" => Mf::ExtractBits,
"insertBits" => Mf::InsertBits,
"firstTrailingBit" => Mf::FindLsb,
"firstLeadingBit" => Mf::FindMsb,
"firstTrailingBit" => Mf::FirstTrailingBit,
"firstLeadingBit" => Mf::FirstLeadingBit,
// data packing
"pack4x8snorm" => Mf::Pack4x8snorm,
"pack4x8unorm" => Mf::Pack4x8unorm,
Expand Down
4 changes: 2 additions & 2 deletions naga/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1198,8 +1198,8 @@ pub enum MathFunction {
ReverseBits,
ExtractBits,
InsertBits,
FindLsb,
FindMsb,
FirstTrailingBit,
FirstLeadingBit,
// data packing
Pack4x8snorm,
Pack4x8unorm,
Expand Down
176 changes: 176 additions & 0 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ macro_rules! gen_component_wise_extractor {
scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
) => {
/// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
enum $target<const N: usize> {
$(
#[doc = concat!(
Expand Down Expand Up @@ -1231,6 +1233,12 @@ impl<'a> ConstantEvaluator<'a> {
crate::MathFunction::ReverseBits => {
component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
}
crate::MathFunction::FirstTrailingBit => {
component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
}
crate::MathFunction::FirstLeadingBit => {
component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
}

fun => Err(ConstantEvaluatorError::NotImplemented(format!(
"{fun:?} built-in function"
Expand Down Expand Up @@ -2096,6 +2104,174 @@ impl<'a> ConstantEvaluator<'a> {
}
}

fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
// NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, a value
// of 1 means the least significant bit is set. Therefore, an input of `0x[80 00…]` would
// return a right-to-left bit index of 0.
let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
match e {
idx @ 0..=31 => idx,
32 => u32::MAX,
_ => unreachable!(),
}
};
match concrete_int {
ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
ConcreteInt::I32([e]) => {
ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
}
}
}

#[test]
fn first_trailing_bit_smoke() {
assert_eq!(
first_trailing_bit(ConcreteInt::I32([0])),
ConcreteInt::I32([-1])
);
assert_eq!(
first_trailing_bit(ConcreteInt::I32([1])),
ConcreteInt::I32([0])
);
assert_eq!(
first_trailing_bit(ConcreteInt::I32([2])),
ConcreteInt::I32([1])
);
assert_eq!(
first_trailing_bit(ConcreteInt::I32([-1])),
ConcreteInt::I32([0]),
);
assert_eq!(
first_trailing_bit(ConcreteInt::I32([i32::MIN])),
ConcreteInt::I32([31]),
);
assert_eq!(
first_trailing_bit(ConcreteInt::I32([i32::MAX])),
ConcreteInt::I32([0]),
);
for idx in 0..32 {
assert_eq!(
first_trailing_bit(ConcreteInt::I32([1 << idx])),
ConcreteInt::I32([idx])
)
}

assert_eq!(
first_trailing_bit(ConcreteInt::U32([0])),
ConcreteInt::U32([u32::MAX])
);
assert_eq!(
first_trailing_bit(ConcreteInt::U32([1])),
ConcreteInt::U32([0])
);
assert_eq!(
first_trailing_bit(ConcreteInt::U32([2])),
ConcreteInt::U32([1])
);
assert_eq!(
first_trailing_bit(ConcreteInt::U32([1 << 31])),
ConcreteInt::U32([31]),
);
assert_eq!(
first_trailing_bit(ConcreteInt::U32([u32::MAX])),
ConcreteInt::U32([0]),
);
for idx in 0..32 {
assert_eq!(
first_trailing_bit(ConcreteInt::U32([1 << idx])),
ConcreteInt::U32([idx])
)
}
}

fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
// NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, 1 means
// the least significant bit is set. Therefore, an input of 1 would return a right-to-left bit
// index of 0.
let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
match e {
idx @ 0..=31 => 31 - idx,
32 => u32::MAX,
_ => unreachable!(),
}
};
match concrete_int {
ConcreteInt::I32([e]) => ConcreteInt::I32([{
let rtl_bit_index = if e.is_negative() {
e.leading_ones()
} else {
e.leading_zeros()
};
rtl_to_ltr_bit_idx(rtl_bit_index) as i32
}]),
ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
}
}

#[test]
fn first_leading_bit_smoke() {
assert_eq!(
first_leading_bit(ConcreteInt::I32([-1])),
ConcreteInt::I32([-1])
);
assert_eq!(
first_leading_bit(ConcreteInt::I32([0])),
ConcreteInt::I32([-1])
);
assert_eq!(
first_leading_bit(ConcreteInt::I32([1])),
ConcreteInt::I32([0])
);
assert_eq!(
first_leading_bit(ConcreteInt::I32([-2])),
ConcreteInt::I32([0])
);
assert_eq!(
first_leading_bit(ConcreteInt::I32([1234 + 4567])),
ConcreteInt::I32([12])
);
assert_eq!(
first_leading_bit(ConcreteInt::I32([i32::MAX])),
ConcreteInt::I32([30])
);
assert_eq!(
first_leading_bit(ConcreteInt::I32([i32::MIN])),
ConcreteInt::I32([30])
);
// NOTE: Ignore the sign bit, which is a separate (above) case.
for idx in 0..(32 - 1) {
assert_eq!(
first_leading_bit(ConcreteInt::I32([1 << idx])),
ConcreteInt::I32([idx])
);
}
for idx in 1..(32 - 1) {
assert_eq!(
first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
ConcreteInt::I32([idx - 1])
);
}

assert_eq!(
first_leading_bit(ConcreteInt::U32([0])),
ConcreteInt::U32([u32::MAX])
);
assert_eq!(
first_leading_bit(ConcreteInt::U32([1])),
ConcreteInt::U32([0])
);
assert_eq!(
first_leading_bit(ConcreteInt::U32([u32::MAX])),
ConcreteInt::U32([31])
);
for idx in 0..32 {
assert_eq!(
first_leading_bit(ConcreteInt::U32([1 << idx])),
ConcreteInt::U32([idx])
)
}
}

/// Trait for conversions of abstract values to concrete types.
trait TryFromAbstract<T>: Sized {
/// Convert an abstract literal `value` to `Self`.
Expand Down
4 changes: 2 additions & 2 deletions naga/src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,8 @@ impl super::MathFunction {
Self::ReverseBits => 1,
Self::ExtractBits => 3,
Self::InsertBits => 4,
Self::FindLsb => 1,
Self::FindMsb => 1,
Self::FirstTrailingBit => 1,
Self::FirstLeadingBit => 1,
// data packing
Self::Pack4x8snorm => 1,
Self::Pack4x8unorm => 1,
Expand Down
Loading
Loading