Skip to content

Fix invalid ArrayStride decorations on arrays in function/private storage #297

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,8 +741,16 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
// There's a potential for this array to be sized, but the element to be unsized, e.g. `[[u8]; 5]`.
// However, I think rust disallows all these cases, so assert this here.
assert_eq!(count, 0);
let element_spirv = cx.lookup_type(element_type);
// Calculate stride with alignment for runtime arrays
let stride = element_spirv.physical_size(cx).and_then(|_| {
element_spirv
.sizeof(cx)
.map(|size| size.align_to(element_spirv.alignof(cx)).bytes() as u32)
});
SpirvType::RuntimeArray {
element: element_type,
stride,
}
.def(span, cx)
} else if count == 0 {
Expand All @@ -756,9 +764,13 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
.expect("Unexpected unsized type in sized FieldsShape::Array")
.align_to(element_spv.alignof(cx));
assert_eq!(stride_spv, stride);
// For arrays with explicit layout, use the actual stride from Rust's layout
// which already accounts for alignment
let array_stride = element_spv.physical_size(cx).map(|_| stride.bytes() as u32);
SpirvType::Array {
element: element_type,
count: count_const,
stride: array_stride,
}
.def(span, cx)
}
Expand Down Expand Up @@ -1060,8 +1072,17 @@ fn trans_intrinsic_type<'tcx>(
// We use a generic param to indicate the underlying element type.
// The SPIR-V element type will be generated from the first generic param.
if let Some(elem_ty) = args.types().next() {
let element_type = cx.layout_of(elem_ty).spirv_type(span, cx);
let element_spirv = cx.lookup_type(element_type);
// Calculate stride with alignment for intrinsic runtime arrays
let stride = element_spirv.physical_size(cx).and_then(|_| {
element_spirv
.sizeof(cx)
.map(|size| size.align_to(element_spirv.alignof(cx)).bytes() as u32)
});
Ok(SpirvType::RuntimeArray {
element: cx.layout_of(elem_ty).spirv_type(span, cx),
element: element_type,
stride,
}
.def(span, cx))
} else {
Expand Down
8 changes: 4 additions & 4 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
)
.def(self)
}
SpirvType::Array { element, count } => {
SpirvType::Array { element, count, .. } => {
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
let count = self.builder.lookup_const_scalar(count).unwrap() as usize;
self.constant_composite(
Expand Down Expand Up @@ -301,7 +301,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
_ => self.fatal(format!("memset on float width {width} not implemented yet")),
},
SpirvType::Adt { .. } => self.fatal("memset on structs not implemented yet"),
SpirvType::Array { element, count } => {
SpirvType::Array { element, count, .. } => {
let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var);
let count = self.builder.lookup_const_scalar(count).unwrap() as usize;
self.emit()
Expand Down Expand Up @@ -590,7 +590,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}
SpirvType::Vector { element, .. }
| SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element }
| SpirvType::RuntimeArray { element, .. }
| SpirvType::Matrix { element, .. } => {
trace!("recovering access chain from Vector, Array, RuntimeArray, or Matrix");
ty = element;
Expand Down Expand Up @@ -687,7 +687,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}
// If it's an array, vector, or matrix, indexing yields the element type.
SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element }
| SpirvType::RuntimeArray { element, .. }
| SpirvType::Vector { element, .. }
| SpirvType::Matrix { element, .. } => element,
// Special case: If we started with a byte GEP (`is_byte_gep` is true) and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
element,
count,
),
SpirvType::Array { element, count } => {
SpirvType::Array { element, count, .. } => {
let count = match self.builder.lookup_const_scalar(count) {
Some(count) => count as u32,
None => return self.load_err(original_type, result_type),
Expand Down Expand Up @@ -322,7 +322,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
element,
count,
),
SpirvType::Array { element, count } => {
SpirvType::Array { element, count, .. } => {
let count = match self.builder.lookup_const_scalar(count) {
Some(count) => count as u32,
None => return self.store_err(original_type, value),
Expand Down
19 changes: 15 additions & 4 deletions crates/rustc_codegen_spirv/src/builder/spirv_asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,21 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
self.err("OpTypeArray in asm! is not supported yet");
return;
}
Op::TypeRuntimeArray => SpirvType::RuntimeArray {
element: inst.operands[0].unwrap_id_ref(),
Op::TypeRuntimeArray => {
let element_type = inst.operands[0].unwrap_id_ref();
let element_spirv = self.lookup_type(element_type);
// Calculate stride with alignment for asm runtime arrays
let stride = element_spirv.physical_size(self).and_then(|_| {
element_spirv
.sizeof(self)
.map(|size| size.align_to(element_spirv.alignof(self)).bytes() as u32)
});
SpirvType::RuntimeArray {
element: element_type,
stride,
}
.def(self.span(), self)
}
.def(self.span(), self),
Op::TypePointer => {
let storage_class = inst.operands[0].unwrap_storage_class();
if storage_class != StorageClass::Generic {
Expand Down Expand Up @@ -704,7 +715,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
};
ty = match cx.lookup_type(ty) {
SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element }
| SpirvType::RuntimeArray { element, .. }
// HACK(eddyb) this is pretty bad because it's not
// checking that the index is an `OpConstant 0`, but
// there's no other valid choice anyway.
Expand Down
4 changes: 2 additions & 2 deletions crates/rustc_codegen_spirv/src/codegen_cx/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ impl<'tcx> CodegenCx<'tcx> {
}
self.constant_composite(ty, values.into_iter())
}
SpirvType::Array { element, count } => {
SpirvType::Array { element, count, .. } => {
let count = self.builder.lookup_const_scalar(count).unwrap() as usize;
let values = (0..count).map(|_| {
self.read_from_const_alloc(alloc, offset, element)
Expand Down Expand Up @@ -522,7 +522,7 @@ impl<'tcx> CodegenCx<'tcx> {
*offset = final_offset;
result
}
SpirvType::RuntimeArray { element } => {
SpirvType::RuntimeArray { element, .. } => {
let mut values = Vec::new();
while offset.bytes_usize() != alloc.inner().len() {
values.push(
Expand Down
6 changes: 3 additions & 3 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ impl<'tcx> CodegenCx<'tcx> {
let value_spirv_type = value_layout.spirv_type(hir_param.ty_span, self);
// Some types automatically specify a storage class. Compute that here.
let element_ty = match self.lookup_type(value_spirv_type) {
SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => {
SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element, .. } => {
self.lookup_type(element)
}
ty => ty,
Expand Down Expand Up @@ -505,7 +505,7 @@ impl<'tcx> CodegenCx<'tcx> {
&& {
// Peel off arrays first (used for "descriptor indexing").
let outermost_or_array_element = match self.lookup_type(value_spirv_type) {
SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => {
SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element, .. } => {
element
}
_ => value_spirv_type,
Expand Down Expand Up @@ -966,7 +966,7 @@ impl<'tcx> CodegenCx<'tcx> {
SpirvType::Vector { element, .. }
| SpirvType::Matrix { element, .. }
| SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element }
| SpirvType::RuntimeArray { element, .. }
| SpirvType::Pointer { pointee: element }
| SpirvType::InterfaceBlock {
inner_type: element,
Expand Down
8 changes: 8 additions & 0 deletions crates/rustc_codegen_spirv/src/codegen_cx/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,17 @@ impl<'tcx> BaseTypeCodegenMethods<'tcx> for CodegenCx<'tcx> {
}

fn type_array(&self, ty: Self::Type, len: u64) -> Self::Type {
let ty_spirv = self.lookup_type(ty);
// Calculate stride with alignment
let stride = ty_spirv.physical_size(self).and_then(|_| {
ty_spirv
.sizeof(self)
.map(|size| size.align_to(ty_spirv.alignof(self)).bytes() as u32)
});
SpirvType::Array {
element: ty,
count: self.constant_u64(DUMMY_SP, len),
stride,
}
.def(DUMMY_SP, self)
}
Expand Down
81 changes: 55 additions & 26 deletions crates/rustc_codegen_spirv/src/spirv_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,15 @@ pub enum SpirvType<'tcx> {
element: Word,
/// Note: array count is ref to constant.
count: SpirvValue,
/// Whether this array has an explicit stride decoration.
/// None means no decoration, Some(stride) means decorated with that stride.
stride: Option<u32>,
},
RuntimeArray {
element: Word,
/// Whether this array has an explicit stride decoration.
/// None means no decoration, Some(stride) means decorated with that stride.
stride: Option<u32>,
},
Pointer {
pointee: Word,
Expand Down Expand Up @@ -181,16 +187,32 @@ impl SpirvType<'_> {
}
Self::Vector { element, count } => cx.emit_global().type_vector_id(id, element, count),
Self::Matrix { element, count } => cx.emit_global().type_matrix_id(id, element, count),
Self::Array { element, count } => {
Self::Array {
element,
count,
stride,
} => {
let result = cx
.emit_global()
.type_array_id(id, element, count.def_cx(cx));
Self::decorate_array_stride(result, element, cx);
if let Some(stride_bytes) = stride {
cx.emit_global().decorate(
result,
Decoration::ArrayStride,
iter::once(Operand::LiteralBit32(stride_bytes)),
);
}
result
}
Self::RuntimeArray { element } => {
Self::RuntimeArray { element, stride } => {
let result = cx.emit_global().type_runtime_array_id(id, element);
Self::decorate_array_stride(result, element, cx);
if let Some(stride_bytes) = stride {
cx.emit_global().decorate(
result,
Decoration::ArrayStride,
iter::once(Operand::LiteralBit32(stride_bytes)),
);
}
result
}
Self::Pointer { pointee } => {
Expand Down Expand Up @@ -258,19 +280,6 @@ impl SpirvType<'_> {
result
}

fn decorate_array_stride(result: u32, element: u32, cx: &CodegenCx<'_>) {
let mut emit = cx.emit_global();
let ty = cx.lookup_type(element);
if let Some(element_size) = ty.physical_size(cx) {
// ArrayStride decoration wants in *bytes*
emit.decorate(
result,
Decoration::ArrayStride,
iter::once(Operand::LiteralBit32(element_size.bytes() as u32)),
);
}
}

/// `def_with_id` is used by the `RecursivePointeeCache` to handle `OpTypeForwardPointer`: when
/// emitting the subsequent `OpTypePointer`, the ID is already known and must be re-used.
pub fn def_with_id(self, cx: &CodegenCx<'_>, def_span: Span, id: Word) -> Word {
Expand Down Expand Up @@ -332,7 +341,7 @@ impl SpirvType<'_> {
cx.lookup_type(element).sizeof(cx)? * count.next_power_of_two() as u64
}
Self::Matrix { element, count } => cx.lookup_type(element).sizeof(cx)? * count as u64,
Self::Array { element, count } => {
Self::Array { element, count, .. } => {
cx.lookup_type(element).sizeof(cx)?
* cx.builder
.lookup_const_scalar(count)
Expand Down Expand Up @@ -367,7 +376,7 @@ impl SpirvType<'_> {
)
.expect("alignof: Vectors must have power-of-2 size"),
Self::Array { element, .. }
| Self::RuntimeArray { element }
| Self::RuntimeArray { element, .. }
| Self::Matrix { element, .. } => cx.lookup_type(element).alignof(cx),
Self::Pointer { .. } => cx.tcx.data_layout.pointer_align.abi,
Self::Image { .. }
Expand All @@ -388,7 +397,11 @@ impl SpirvType<'_> {

Self::Adt { size, .. } => size,

Self::Array { element, count } => Some(
Self::Array {
element,
count,
stride: _,
} => Some(
cx.lookup_type(element).physical_size(cx)?
* cx.builder
.lookup_const_scalar(count)
Expand Down Expand Up @@ -432,8 +445,18 @@ impl SpirvType<'_> {
SpirvType::Float(width) => SpirvType::Float(width),
SpirvType::Vector { element, count } => SpirvType::Vector { element, count },
SpirvType::Matrix { element, count } => SpirvType::Matrix { element, count },
SpirvType::Array { element, count } => SpirvType::Array { element, count },
SpirvType::RuntimeArray { element } => SpirvType::RuntimeArray { element },
SpirvType::Array {
element,
count,
stride,
} => SpirvType::Array {
element,
count,
stride,
},
SpirvType::RuntimeArray { element, stride } => {
SpirvType::RuntimeArray { element, stride }
}
SpirvType::Pointer { pointee } => SpirvType::Pointer { pointee },
SpirvType::Image {
sampled_type,
Expand Down Expand Up @@ -561,7 +584,11 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
.field("element", &self.cx.debug_type(element))
.field("count", &count)
.finish(),
SpirvType::Array { element, count } => f
SpirvType::Array {
element,
count,
stride,
} => f
.debug_struct("Array")
.field("id", &self.id)
.field("element", &self.cx.debug_type(element))
Expand All @@ -573,11 +600,13 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
.lookup_const_scalar(count)
.expect("Array type has invalid count value"),
)
.field("stride", &stride)
.finish(),
SpirvType::RuntimeArray { element } => f
SpirvType::RuntimeArray { element, stride } => f
.debug_struct("RuntimeArray")
.field("id", &self.id)
.field("element", &self.cx.debug_type(element))
.field("stride", &stride)
.finish(),
SpirvType::Pointer { pointee } => f
.debug_struct("Pointer")
Expand Down Expand Up @@ -720,14 +749,14 @@ impl SpirvTypePrinter<'_, '_> {
ty(self.cx, stack, f, element)?;
write!(f, "x{count}")
}
SpirvType::Array { element, count } => {
SpirvType::Array { element, count, .. } => {
let len = self.cx.builder.lookup_const_scalar(count);
let len = len.expect("Array type has invalid count value");
f.write_str("[")?;
ty(self.cx, stack, f, element)?;
write!(f, "; {len}]")
}
SpirvType::RuntimeArray { element } => {
SpirvType::RuntimeArray { element, .. } => {
f.write_str("[")?;
ty(self.cx, stack, f, element)?;
f.write_str("]")
Expand Down
20 changes: 20 additions & 0 deletions tests/compiletests/ui/dis/array_stride_alignment.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// build-pass
// compile-flags: -C llvm-args=--disassemble-globals
// CHECK: OpDecorate %{{[0-9]+}} ArrayStride 16

use spirv_std::spirv;

// Test that array stride respects alignment requirements
// vec3<f32> has size 12 bytes but alignment 16 bytes
// So array stride should be 16, not 12
#[derive(Copy, Clone)]
pub struct AlignedBuffer {
data: [spirv_std::glam::Vec3; 4],
}

#[spirv(compute(threads(1)))]
pub fn main_cs(
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] storage: &mut AlignedBuffer,
) {
storage.data[0] = spirv_std::glam::Vec3::new(1.0, 2.0, 3.0);
}
Loading
Loading