Skip to content

Commit 1de2c4f

Browse files
committed
[naga] Vectorize [un]pack4x{I, U}8[Clamp] on spv
Emits vectorized SPIR-V code for the WGSL functions `unpack4xI8`, `unpack4xU8`, `pack4xI8`, `pack4xU8`, `pack4xI8Clamp`, `pack4xU8Clamp`.

Exploits the following facts about SPIR-V ops:
- `SClamp`, `UClamp`, and `OpUConvert` accept vector arguments, in which case results are computed per component; and
- `OpBitcast` can cast between vectors and scalars, with a well-defined bit order that matches that required by the WGSL spec, see below.

WGSL spec for `pack4xI8` [1]:
> Component e[i] of the input is mapped to bits 8 × i through 8 × i + 7 of the result.

SPIR-V spec for `OpBitcast` [2]:
> Within this mapping, any single component of `S` [remark: the type with fewer but wider components] (mapping to multiple components of `L` [remark: the type with more but narrower components]) maps its lower-ordered bits to the lower-numbered components of `L`.

[1] https://www.w3.org/TR/WGSL/#pack4xI8-builtin
[2] https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast
1 parent 50eb207 commit 1de2c4f

File tree

3 files changed

+317
-404
lines changed

3 files changed

+317
-404
lines changed

naga/src/back/spv/block.rs

Lines changed: 92 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,162 +1557,120 @@ impl BlockContext<'_> {
15571557
Mf::Pack4xU8 | Mf::Pack4xU8Clamp => (crate::ScalarKind::Uint, false),
15581558
_ => unreachable!(),
15591559
};
1560+
15601561
let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp);
1561-
let uint_type_id =
1562-
self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
15631562

1564-
let int_type_id =
1565-
self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
1563+
let wide_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
1564+
size: crate::VectorSize::Quad,
1565+
scalar: crate::Scalar {
15661566
kind: int_type,
15671567
width: 4,
1568-
}));
1569-
1570-
let mut last_instruction = Instruction::new(spirv::Op::Nop);
1571-
1572-
let zero = self.writer.get_constant_scalar(crate::Literal::U32(0));
1573-
let mut preresult = zero;
1574-
block
1575-
.body
1576-
.reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed)));
1577-
1578-
let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1579-
const VEC_LENGTH: u8 = 4;
1580-
for i in 0..u32::from(VEC_LENGTH) {
1581-
let offset =
1582-
self.writer.get_constant_scalar(crate::Literal::U32(i * 8));
1583-
let mut extracted = self.gen_id();
1584-
block.body.push(Instruction::binary(
1585-
spirv::Op::CompositeExtract,
1586-
int_type_id,
1587-
extracted,
1588-
arg0_id,
1589-
i,
1590-
));
1591-
if is_signed {
1592-
let casted = self.gen_id();
1593-
block.body.push(Instruction::unary(
1594-
spirv::Op::Bitcast,
1595-
uint_type_id,
1596-
casted,
1597-
extracted,
1598-
));
1599-
extracted = casted;
1600-
}
1601-
if should_clamp {
1602-
let (min, max, clamp_op) = if is_signed {
1603-
(
1604-
crate::Literal::I32(-128),
1605-
crate::Literal::I32(127),
1606-
spirv::GLOp::SClamp,
1607-
)
1608-
} else {
1609-
(
1610-
crate::Literal::U32(0),
1611-
crate::Literal::U32(255),
1612-
spirv::GLOp::UClamp,
1613-
)
1614-
};
1615-
let [min, max] =
1616-
[min, max].map(|lit| self.writer.get_constant_scalar(lit));
1617-
1618-
let clamp_id = self.gen_id();
1619-
block.body.push(Instruction::ext_inst(
1620-
self.writer.gl450_ext_inst_id,
1621-
clamp_op,
1622-
result_type_id,
1623-
clamp_id,
1624-
&[extracted, min, max],
1625-
));
1568+
},
1569+
});
1570+
let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
1571+
size: crate::VectorSize::Quad,
1572+
scalar: crate::Scalar {
1573+
kind: crate::ScalarKind::Uint,
1574+
width: 1,
1575+
},
1576+
});
16261577

1627-
extracted = clamp_id;
1628-
}
1629-
let is_last = i == u32::from(VEC_LENGTH - 1);
1630-
if is_last {
1631-
last_instruction = Instruction::quaternary(
1632-
spirv::Op::BitFieldInsert,
1633-
result_type_id,
1634-
id,
1635-
preresult,
1636-
extracted,
1637-
offset,
1638-
eight,
1578+
let mut wide_vector = arg0_id;
1579+
if should_clamp {
1580+
let (min, max, clamp_op) = if is_signed {
1581+
(
1582+
crate::Literal::I32(-128),
1583+
crate::Literal::I32(127),
1584+
spirv::GLOp::SClamp,
16391585
)
16401586
} else {
1641-
let new_preresult = self.gen_id();
1642-
block.body.push(Instruction::quaternary(
1643-
spirv::Op::BitFieldInsert,
1644-
result_type_id,
1645-
new_preresult,
1646-
preresult,
1647-
extracted,
1648-
offset,
1649-
eight,
1587+
(
1588+
crate::Literal::U32(0),
1589+
crate::Literal::U32(255),
1590+
spirv::GLOp::UClamp,
1591+
)
1592+
};
1593+
let [min, max] = [min, max].map(|lit| {
1594+
let scalar = self.writer.get_constant_scalar(lit);
1595+
// TODO: can we cache these constant vectors somehow?
1596+
let id = self.gen_id();
1597+
block.body.push(Instruction::composite_construct(
1598+
wide_vector_type_id,
1599+
id,
1600+
&[scalar; 4],
16501601
));
1651-
preresult = new_preresult;
1652-
}
1602+
id
1603+
});
1604+
1605+
let clamp_id = self.gen_id();
1606+
block.body.push(Instruction::ext_inst(
1607+
self.writer.gl450_ext_inst_id,
1608+
clamp_op,
1609+
wide_vector_type_id,
1610+
clamp_id,
1611+
&[wide_vector, min, max],
1612+
));
1613+
1614+
wide_vector = clamp_id;
16531615
}
16541616

1655-
MathOp::Custom(last_instruction)
1617+
let packed_vector = self.gen_id();
1618+
block.body.push(Instruction::unary(
1619+
spirv::Op::UConvert, // We truncate, so `UConvert` and `SConvert` behave identically.
1620+
packed_vector_type_id,
1621+
packed_vector,
1622+
wide_vector,
1623+
));
1624+
1625+
// The SPIR-V spec [1] defines the bit order for bit casting between a vector
1626+
// and a scalar precisely as required by the WGSL spec [2].
1627+
// [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast
1628+
// [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin
1629+
MathOp::Custom(Instruction::unary(
1630+
spirv::Op::Bitcast,
1631+
result_type_id,
1632+
id,
1633+
packed_vector,
1634+
))
16561635
}
16571636
Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
16581637
Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
16591638
Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
16601639
Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16),
16611640
Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
16621641
fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
1663-
let (int_type, extract_op, is_signed) = match fun {
1664-
Mf::Unpack4xI8 => {
1665-
(crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract, true)
1666-
}
1667-
Mf::Unpack4xU8 => {
1668-
(crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract, false)
1669-
}
1642+
let (int_type, convert_op) = match fun {
1643+
Mf::Unpack4xI8 => (crate::ScalarKind::Sint, spirv::Op::SConvert),
1644+
Mf::Unpack4xU8 => (crate::ScalarKind::Uint, spirv::Op::UConvert),
16701645
_ => unreachable!(),
16711646
};
16721647

1673-
let sint_type_id =
1674-
self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));
1675-
1676-
let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1677-
let int_type_id =
1678-
self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
1648+
let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
1649+
size: crate::VectorSize::Quad,
1650+
scalar: crate::Scalar {
16791651
kind: int_type,
1680-
width: 4,
1681-
}));
1682-
block
1683-
.body
1684-
.reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed));
1685-
let arg_id = if is_signed {
1686-
let new_arg_id = self.gen_id();
1687-
block.body.push(Instruction::unary(
1688-
spirv::Op::Bitcast,
1689-
sint_type_id,
1690-
new_arg_id,
1691-
arg0_id,
1692-
));
1693-
new_arg_id
1694-
} else {
1695-
arg0_id
1696-
};
1697-
1698-
const VEC_LENGTH: u8 = 4;
1699-
let parts: [_; VEC_LENGTH as usize] =
1700-
core::array::from_fn(|_| self.gen_id());
1701-
for (i, part_id) in parts.into_iter().enumerate() {
1702-
let index = self
1703-
.writer
1704-
.get_constant_scalar(crate::Literal::U32(i as u32 * 8));
1705-
block.body.push(Instruction::ternary(
1706-
extract_op,
1707-
int_type_id,
1708-
part_id,
1709-
arg_id,
1710-
index,
1711-
eight,
1712-
));
1713-
}
1652+
width: 1,
1653+
},
1654+
});
1655+
1656+
// The SPIR-V spec [1] defines the bit order for bit casting between a vector
1657+
// and a scalar precisely as required by the WGSL spec [2].
1658+
// [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast
1659+
// [2]: https://www.w3.org/TR/WGSL/#unpack4xI8-builtin
1660+
let packed_vector = self.gen_id();
1661+
block.body.push(Instruction::unary(
1662+
spirv::Op::Bitcast,
1663+
packed_vector_type_id,
1664+
packed_vector,
1665+
arg0_id,
1666+
));
17141667

1715-
MathOp::Custom(Instruction::composite_construct(result_type_id, id, &parts))
1668+
MathOp::Custom(Instruction::unary(
1669+
convert_op,
1670+
result_type_id,
1671+
id,
1672+
packed_vector,
1673+
))
17161674
}
17171675
};
17181676

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
; SPIR-V
22
; Version: 1.1
33
; Generator: rspirv
4-
; Bound: 30
4+
; Bound: 23
55
OpCapability Shader
6+
OpCapability Int8
67
%1 = OpExtInstImport "GLSL.std.450"
78
OpMemoryModel Logical GLSL450
89
OpEntryPoint GLCompute %4 "main"
@@ -14,27 +15,20 @@ OpExecutionMode %4 LocalSize 1 1 1
1415
%8 = OpTypeInt 32 0
1516
%9 = OpConstant %8 12
1617
%11 = OpTypeVector %6 4
17-
%13 = OpConstant %8 8
18-
%19 = OpConstant %8 0
19-
%20 = OpConstant %8 16
20-
%21 = OpConstant %8 24
21-
%23 = OpTypeVector %8 4
18+
%14 = OpTypeInt 8 1
19+
%13 = OpTypeVector %14 4
20+
%17 = OpTypeVector %8 4
21+
%20 = OpTypeInt 8 0
22+
%19 = OpTypeVector %20 4
2223
%4 = OpFunction %2 None %5
2324
%3 = OpLabel
2425
OpBranch %10
2526
%10 = OpLabel
26-
%14 = OpBitcast %6 %9
27-
%15 = OpBitFieldSExtract %6 %14 %19 %13
28-
%16 = OpBitFieldSExtract %6 %14 %13 %13
29-
%17 = OpBitFieldSExtract %6 %14 %20 %13
30-
%18 = OpBitFieldSExtract %6 %14 %21 %13
31-
%12 = OpCompositeConstruct %11 %15 %16 %17 %18
32-
%22 = OpCompositeExtract %6 %12 2
33-
%25 = OpBitFieldUExtract %8 %9 %19 %13
34-
%26 = OpBitFieldUExtract %8 %9 %13 %13
35-
%27 = OpBitFieldUExtract %8 %9 %20 %13
36-
%28 = OpBitFieldUExtract %8 %9 %21 %13
37-
%24 = OpCompositeConstruct %23 %25 %26 %27 %28
38-
%29 = OpCompositeExtract %8 %24 1
27+
%15 = OpBitcast %13 %9
28+
%12 = OpSConvert %11 %15
29+
%16 = OpCompositeExtract %6 %12 2
30+
%21 = OpBitcast %19 %9
31+
%18 = OpUConvert %17 %21
32+
%22 = OpCompositeExtract %8 %18 1
3933
OpReturn
4034
OpFunctionEnd

0 commit comments

Comments
 (0)