@@ -1557,162 +1557,120 @@ impl BlockContext<'_> {
1557
1557
Mf :: Pack4xU8 | Mf :: Pack4xU8Clamp => ( crate :: ScalarKind :: Uint , false ) ,
1558
1558
_ => unreachable ! ( ) ,
1559
1559
} ;
1560
+
1560
1561
let should_clamp = matches ! ( fun, Mf :: Pack4xI8Clamp | Mf :: Pack4xU8Clamp ) ;
1561
- let uint_type_id =
1562
- self . get_numeric_type_id ( NumericType :: Scalar ( crate :: Scalar :: U32 ) ) ;
1563
1562
1564
- let int_type_id =
1565
- self . get_numeric_type_id ( NumericType :: Scalar ( crate :: Scalar {
1563
+ let wide_vector_type_id = self . get_numeric_type_id ( NumericType :: Vector {
1564
+ size : crate :: VectorSize :: Quad ,
1565
+ scalar : crate :: Scalar {
1566
1566
kind : int_type,
1567
1567
width : 4 ,
1568
- } ) ) ;
1569
-
1570
- let mut last_instruction = Instruction :: new ( spirv:: Op :: Nop ) ;
1571
-
1572
- let zero = self . writer . get_constant_scalar ( crate :: Literal :: U32 ( 0 ) ) ;
1573
- let mut preresult = zero;
1574
- block
1575
- . body
1576
- . reserve ( usize:: from ( VEC_LENGTH ) * ( 2 + usize:: from ( is_signed) ) ) ;
1577
-
1578
- let eight = self . writer . get_constant_scalar ( crate :: Literal :: U32 ( 8 ) ) ;
1579
- const VEC_LENGTH : u8 = 4 ;
1580
- for i in 0 ..u32:: from ( VEC_LENGTH ) {
1581
- let offset =
1582
- self . writer . get_constant_scalar ( crate :: Literal :: U32 ( i * 8 ) ) ;
1583
- let mut extracted = self . gen_id ( ) ;
1584
- block. body . push ( Instruction :: binary (
1585
- spirv:: Op :: CompositeExtract ,
1586
- int_type_id,
1587
- extracted,
1588
- arg0_id,
1589
- i,
1590
- ) ) ;
1591
- if is_signed {
1592
- let casted = self . gen_id ( ) ;
1593
- block. body . push ( Instruction :: unary (
1594
- spirv:: Op :: Bitcast ,
1595
- uint_type_id,
1596
- casted,
1597
- extracted,
1598
- ) ) ;
1599
- extracted = casted;
1600
- }
1601
- if should_clamp {
1602
- let ( min, max, clamp_op) = if is_signed {
1603
- (
1604
- crate :: Literal :: I32 ( -128 ) ,
1605
- crate :: Literal :: I32 ( 127 ) ,
1606
- spirv:: GLOp :: SClamp ,
1607
- )
1608
- } else {
1609
- (
1610
- crate :: Literal :: U32 ( 0 ) ,
1611
- crate :: Literal :: U32 ( 255 ) ,
1612
- spirv:: GLOp :: UClamp ,
1613
- )
1614
- } ;
1615
- let [ min, max] =
1616
- [ min, max] . map ( |lit| self . writer . get_constant_scalar ( lit) ) ;
1617
-
1618
- let clamp_id = self . gen_id ( ) ;
1619
- block. body . push ( Instruction :: ext_inst (
1620
- self . writer . gl450_ext_inst_id ,
1621
- clamp_op,
1622
- result_type_id,
1623
- clamp_id,
1624
- & [ extracted, min, max] ,
1625
- ) ) ;
1568
+ } ,
1569
+ } ) ;
1570
+ let packed_vector_type_id = self . get_numeric_type_id ( NumericType :: Vector {
1571
+ size : crate :: VectorSize :: Quad ,
1572
+ scalar : crate :: Scalar {
1573
+ kind : crate :: ScalarKind :: Uint ,
1574
+ width : 1 ,
1575
+ } ,
1576
+ } ) ;
1626
1577
1627
- extracted = clamp_id;
1628
- }
1629
- let is_last = i == u32:: from ( VEC_LENGTH - 1 ) ;
1630
- if is_last {
1631
- last_instruction = Instruction :: quaternary (
1632
- spirv:: Op :: BitFieldInsert ,
1633
- result_type_id,
1634
- id,
1635
- preresult,
1636
- extracted,
1637
- offset,
1638
- eight,
1578
+ let mut wide_vector = arg0_id;
1579
+ if should_clamp {
1580
+ let ( min, max, clamp_op) = if is_signed {
1581
+ (
1582
+ crate :: Literal :: I32 ( -128 ) ,
1583
+ crate :: Literal :: I32 ( 127 ) ,
1584
+ spirv:: GLOp :: SClamp ,
1639
1585
)
1640
1586
} else {
1641
- let new_preresult = self . gen_id ( ) ;
1642
- block. body . push ( Instruction :: quaternary (
1643
- spirv:: Op :: BitFieldInsert ,
1644
- result_type_id,
1645
- new_preresult,
1646
- preresult,
1647
- extracted,
1648
- offset,
1649
- eight,
1587
+ (
1588
+ crate :: Literal :: U32 ( 0 ) ,
1589
+ crate :: Literal :: U32 ( 255 ) ,
1590
+ spirv:: GLOp :: UClamp ,
1591
+ )
1592
+ } ;
1593
+ let [ min, max] = [ min, max] . map ( |lit| {
1594
+ let scalar = self . writer . get_constant_scalar ( lit) ;
1595
+ // TODO: can we cache these constant vectors somehow?
1596
+ let id = self . gen_id ( ) ;
1597
+ block. body . push ( Instruction :: composite_construct (
1598
+ wide_vector_type_id,
1599
+ id,
1600
+ & [ scalar; 4 ] ,
1650
1601
) ) ;
1651
- preresult = new_preresult;
1652
- }
1602
+ id
1603
+ } ) ;
1604
+
1605
+ let clamp_id = self . gen_id ( ) ;
1606
+ block. body . push ( Instruction :: ext_inst (
1607
+ self . writer . gl450_ext_inst_id ,
1608
+ clamp_op,
1609
+ wide_vector_type_id,
1610
+ clamp_id,
1611
+ & [ wide_vector, min, max] ,
1612
+ ) ) ;
1613
+
1614
+ wide_vector = clamp_id;
1653
1615
}
1654
1616
1655
- MathOp :: Custom ( last_instruction)
1617
+ let packed_vector = self . gen_id ( ) ;
1618
+ block. body . push ( Instruction :: unary (
1619
+ spirv:: Op :: UConvert , // We truncate, so `UConvert` and `SConvert` behave identically.
1620
+ packed_vector_type_id,
1621
+ packed_vector,
1622
+ wide_vector,
1623
+ ) ) ;
1624
+
1625
+ // The SPIR-V spec [1] defines the bit order for bit casting between a vector
1626
+ // and a scalar precisely as required by the WGSL spec [2].
1627
+ // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast
1628
+ // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin
1629
+ MathOp :: Custom ( Instruction :: unary (
1630
+ spirv:: Op :: Bitcast ,
1631
+ result_type_id,
1632
+ id,
1633
+ packed_vector,
1634
+ ) )
1656
1635
}
1657
1636
Mf :: Unpack4x8unorm => MathOp :: Ext ( spirv:: GLOp :: UnpackUnorm4x8 ) ,
1658
1637
Mf :: Unpack4x8snorm => MathOp :: Ext ( spirv:: GLOp :: UnpackSnorm4x8 ) ,
1659
1638
Mf :: Unpack2x16float => MathOp :: Ext ( spirv:: GLOp :: UnpackHalf2x16 ) ,
1660
1639
Mf :: Unpack2x16unorm => MathOp :: Ext ( spirv:: GLOp :: UnpackUnorm2x16 ) ,
1661
1640
Mf :: Unpack2x16snorm => MathOp :: Ext ( spirv:: GLOp :: UnpackSnorm2x16 ) ,
1662
1641
fun @ ( Mf :: Unpack4xI8 | Mf :: Unpack4xU8 ) => {
1663
- let ( int_type, extract_op, is_signed) = match fun {
1664
- Mf :: Unpack4xI8 => {
1665
- ( crate :: ScalarKind :: Sint , spirv:: Op :: BitFieldSExtract , true )
1666
- }
1667
- Mf :: Unpack4xU8 => {
1668
- ( crate :: ScalarKind :: Uint , spirv:: Op :: BitFieldUExtract , false )
1669
- }
1642
+ let ( int_type, convert_op) = match fun {
1643
+ Mf :: Unpack4xI8 => ( crate :: ScalarKind :: Sint , spirv:: Op :: SConvert ) ,
1644
+ Mf :: Unpack4xU8 => ( crate :: ScalarKind :: Uint , spirv:: Op :: UConvert ) ,
1670
1645
_ => unreachable ! ( ) ,
1671
1646
} ;
1672
1647
1673
- let sint_type_id =
1674
- self . get_numeric_type_id ( NumericType :: Scalar ( crate :: Scalar :: I32 ) ) ;
1675
-
1676
- let eight = self . writer . get_constant_scalar ( crate :: Literal :: U32 ( 8 ) ) ;
1677
- let int_type_id =
1678
- self . get_numeric_type_id ( NumericType :: Scalar ( crate :: Scalar {
1648
+ let packed_vector_type_id = self . get_numeric_type_id ( NumericType :: Vector {
1649
+ size : crate :: VectorSize :: Quad ,
1650
+ scalar : crate :: Scalar {
1679
1651
kind : int_type,
1680
- width : 4 ,
1681
- } ) ) ;
1682
- block
1683
- . body
1684
- . reserve ( usize:: from ( VEC_LENGTH ) * 2 + usize:: from ( is_signed) ) ;
1685
- let arg_id = if is_signed {
1686
- let new_arg_id = self . gen_id ( ) ;
1687
- block. body . push ( Instruction :: unary (
1688
- spirv:: Op :: Bitcast ,
1689
- sint_type_id,
1690
- new_arg_id,
1691
- arg0_id,
1692
- ) ) ;
1693
- new_arg_id
1694
- } else {
1695
- arg0_id
1696
- } ;
1697
-
1698
- const VEC_LENGTH : u8 = 4 ;
1699
- let parts: [ _ ; VEC_LENGTH as usize ] =
1700
- core:: array:: from_fn ( |_| self . gen_id ( ) ) ;
1701
- for ( i, part_id) in parts. into_iter ( ) . enumerate ( ) {
1702
- let index = self
1703
- . writer
1704
- . get_constant_scalar ( crate :: Literal :: U32 ( i as u32 * 8 ) ) ;
1705
- block. body . push ( Instruction :: ternary (
1706
- extract_op,
1707
- int_type_id,
1708
- part_id,
1709
- arg_id,
1710
- index,
1711
- eight,
1712
- ) ) ;
1713
- }
1652
+ width : 1 ,
1653
+ } ,
1654
+ } ) ;
1655
+
1656
+ // The SPIR-V spec [1] defines the bit order for bit casting between a vector
1657
+ // and a scalar precisely as required by the WGSL spec [2].
1658
+ // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast
1659
+ // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin
1660
+ let packed_vector = self . gen_id ( ) ;
1661
+ block. body . push ( Instruction :: unary (
1662
+ spirv:: Op :: Bitcast ,
1663
+ packed_vector_type_id,
1664
+ packed_vector,
1665
+ arg0_id,
1666
+ ) ) ;
1714
1667
1715
- MathOp :: Custom ( Instruction :: composite_construct ( result_type_id, id, & parts) )
1668
+ MathOp :: Custom ( Instruction :: unary (
1669
+ convert_op,
1670
+ result_type_id,
1671
+ id,
1672
+ packed_vector,
1673
+ ) )
1716
1674
}
1717
1675
} ;
1718
1676
0 commit comments