Skip to content

Commit 7723f5d

Browse files
authored
Access the original type's fields through a pointercast, under the Logical addressing model. (#469)
* Defer pointer casts under the Logical addressing model. * Access the original type's fields through a `pointercast`, under the Logical addressing model. * Add a test using a `for` loop and a custom `Range`-like iterator.
1 parent 5cfaa00 commit 7723f5d

File tree

4 files changed

+255
-66
lines changed

4 files changed

+255
-66
lines changed

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 166 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
use super::Builder;
2-
use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt};
2+
use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt, SpirvValueKind};
33
use crate::spirv_type::SpirvType;
44
use rspirv::dr::{InsertPoint, Instruction, Operand};
5-
use rspirv::spirv::{
6-
AddressingModel, Capability, MemoryModel, MemorySemantics, Op, Scope, StorageClass, Word,
7-
};
5+
use rspirv::spirv::{Capability, MemoryModel, MemorySemantics, Op, Scope, StorageClass, Word};
86
use rustc_codegen_ssa::common::{
97
AtomicOrdering, AtomicRmwBinOp, IntPredicate, RealPredicate, SynchronizationScope,
108
};
@@ -18,6 +16,7 @@ use rustc_middle::bug;
1816
use rustc_middle::ty::Ty;
1917
use rustc_span::Span;
2018
use rustc_target::abi::{Abi, Align, Scalar, Size};
19+
use std::convert::TryInto;
2120
use std::iter::empty;
2221
use std::ops::Range;
2322

@@ -334,54 +333,74 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
334333
}
335334
}
336335

337-
fn zombie_bitcast_ptr(&self, def: Word, from_ty: Word, to_ty: Word) {
338-
let is_logical = self
339-
.emit()
340-
.module_ref()
341-
.memory_model
342-
.as_ref()
343-
.map_or(false, |inst| {
344-
inst.operands[0].unwrap_addressing_model() == AddressingModel::Logical
345-
});
346-
if is_logical {
347-
if self.is_system_crate() {
348-
self.zombie(def, "OpBitcast on ptr without AddressingModel != Logical")
349-
} else {
350-
self.struct_err("Cannot cast between pointer types")
351-
.note(&format!("from: {}", self.debug_type(from_ty)))
352-
.note(&format!("to: {}", self.debug_type(to_ty)))
353-
.emit()
354-
}
355-
}
356-
}
336+
/// If possible, return the appropriate `OpAccessChain` indices for going from
337+
/// a pointer to `ty`, to a pointer to `leaf_ty`, with an added `offset`.
338+
///
339+
/// That is, try to turn `((_: *T) as *u8).add(offset) as *Leaf` into a series
340+
/// of struct field and array/vector element accesses.
341+
fn recover_access_chain_from_offset(
342+
&self,
343+
mut ty: Word,
344+
leaf_ty: Word,
345+
mut offset: Size,
346+
) -> Option<Vec<u32>> {
347+
assert_ne!(ty, leaf_ty);
348+
349+
// NOTE(eddyb) `ty` and `ty_kind` should be kept in sync.
350+
let mut ty_kind = self.lookup_type(ty);
357351

358-
// Sometimes, when accessing the first field of a struct, vector, etc., instead of calling
359-
// struct_gep, codegen_ssa will call pointercast. This will then try to catch those cases and
360-
// translate them back to a struct_gep, instead of failing to compile the OpBitcast (which is
361-
// unsupported on shader target)
362-
fn try_pointercast_via_gep(&self, mut val: Word, field: Word) -> Option<Vec<u32>> {
363352
let mut indices = Vec::new();
364-
while val != field {
365-
match self.lookup_type(val) {
353+
loop {
354+
match ty_kind {
366355
SpirvType::Adt {
367356
field_types,
368357
field_offsets,
369358
..
370359
} => {
371-
let index = field_offsets.iter().position(|&off| off == Size::ZERO)?;
372-
indices.push(index as u32);
373-
val = field_types[index];
360+
let (i, field_ty, field_ty_kind, offset_in_field) = field_offsets
361+
.iter()
362+
.enumerate()
363+
.find_map(|(i, &field_offset)| {
364+
if field_offset > offset {
365+
return None;
366+
}
367+
368+
// Grab the actual field type to be able to confirm that
369+
// the leaf is somewhere inside the field.
370+
let field_ty = field_types[i];
371+
let field_ty_kind = self.lookup_type(field_ty);
372+
373+
let offset_in_field = offset - field_offset;
374+
if offset_in_field < field_ty_kind.sizeof(self)? {
375+
Some((i, field_ty, field_ty_kind, offset_in_field))
376+
} else {
377+
None
378+
}
379+
})?;
380+
381+
ty = field_ty;
382+
ty_kind = field_ty_kind;
383+
384+
indices.push(i as u32);
385+
offset = offset_in_field;
374386
}
375387
SpirvType::Vector { element, .. }
376388
| SpirvType::Array { element, .. }
377389
| SpirvType::RuntimeArray { element } => {
378-
indices.push(0);
379-
val = element;
390+
ty = element;
391+
ty_kind = self.lookup_type(ty);
392+
393+
let stride = ty_kind.sizeof(self)?;
394+
indices.push((offset.bytes() / stride.bytes()).try_into().ok()?);
395+
offset = Size::from_bytes(offset.bytes() % stride.bytes());
380396
}
381397
_ => return None,
382398
}
399+
400+
if offset == Size::ZERO && ty == leaf_ty {
401+
return Some(indices);
402+
}
383403
}
384-
Some(indices)
385404
}
386405
}
387406

@@ -963,26 +982,65 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
963982
}
964983

965984
fn struct_gep(&mut self, ptr: Self::Value, idx: u64) -> Self::Value {
966-
let result_pointee_type = match self.lookup_type(ptr.ty) {
967-
SpirvType::Pointer { pointee } => match self.lookup_type(pointee) {
968-
SpirvType::Adt { field_types, .. } => field_types[idx as usize],
969-
SpirvType::Array { element, .. }
970-
| SpirvType::RuntimeArray { element, .. }
971-
| SpirvType::Vector { element, .. } => element,
972-
other => self.fatal(&format!(
973-
"struct_gep not on struct, array, or vector type: {:?}, index {}",
974-
other, idx
975-
)),
976-
},
985+
let pointee = match self.lookup_type(ptr.ty) {
986+
SpirvType::Pointer { pointee } => pointee,
977987
other => self.fatal(&format!(
978988
"struct_gep not on pointer type: {:?}, index {}",
979989
other, idx
980990
)),
981991
};
992+
let pointee_kind = self.lookup_type(pointee);
993+
let result_pointee_type = match pointee_kind {
994+
SpirvType::Adt {
995+
ref field_types, ..
996+
} => field_types[idx as usize],
997+
SpirvType::Array { element, .. }
998+
| SpirvType::RuntimeArray { element, .. }
999+
| SpirvType::Vector { element, .. } => element,
1000+
other => self.fatal(&format!(
1001+
"struct_gep not on struct, array, or vector type: {:?}, index {}",
1002+
other, idx
1003+
)),
1004+
};
9821005
let result_type = SpirvType::Pointer {
9831006
pointee: result_pointee_type,
9841007
}
9851008
.def(self.span(), self);
1009+
1010+
// Special-case field accesses through a `pointercast`, to accesss the
1011+
// right field in the original type, for the `Logical` addressing model.
1012+
if let SpirvValueKind::LogicalPtrCast {
1013+
original_ptr,
1014+
original_pointee_ty,
1015+
zombie_target_undef: _,
1016+
} = ptr.kind
1017+
{
1018+
let offset = match pointee_kind {
1019+
SpirvType::Adt { field_offsets, .. } => field_offsets[idx as usize],
1020+
SpirvType::Array { element, .. }
1021+
| SpirvType::RuntimeArray { element, .. }
1022+
| SpirvType::Vector { element, .. } => {
1023+
self.lookup_type(element).sizeof(self).unwrap() * idx
1024+
}
1025+
_ => unreachable!(),
1026+
};
1027+
if let Some(indices) = self.recover_access_chain_from_offset(
1028+
original_pointee_ty,
1029+
result_pointee_type,
1030+
offset,
1031+
) {
1032+
let indices = indices
1033+
.into_iter()
1034+
.map(|idx| self.constant_u32(self.span(), idx).def(self))
1035+
.collect::<Vec<_>>();
1036+
return self
1037+
.emit()
1038+
.access_chain(result_type, None, original_ptr, indices)
1039+
.unwrap()
1040+
.with_type(result_type);
1041+
}
1042+
}
1043+
9861044
// Important! LLVM, and therefore intel-compute-runtime, require the `getelementptr` instruction (and therefore
9871045
// OpAccessChain) on structs to be a constant i32. Not i64! i32.
9881046
if idx > u32::MAX as u64 {
@@ -1131,16 +1189,34 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
11311189
if val.ty == dest_ty {
11321190
val
11331191
} else {
1192+
let val_is_ptr = matches!(self.lookup_type(val.ty), SpirvType::Pointer { .. });
1193+
let dest_is_ptr = matches!(self.lookup_type(dest_ty), SpirvType::Pointer { .. });
1194+
1195+
// Reuse the pointer-specific logic in `pointercast` for `*T -> *U`.
1196+
if val_is_ptr && dest_is_ptr {
1197+
return self.pointercast(val, dest_ty);
1198+
}
1199+
11341200
let result = self
11351201
.emit()
11361202
.bitcast(dest_ty, None, val.def(self))
11371203
.unwrap()
11381204
.with_type(dest_ty);
1139-
let val_is_ptr = matches!(self.lookup_type(val.ty), SpirvType::Pointer { .. });
1140-
let dest_is_ptr = matches!(self.lookup_type(dest_ty), SpirvType::Pointer { .. });
1141-
if val_is_ptr || dest_is_ptr {
1142-
self.zombie_bitcast_ptr(result.def(self), val.ty, dest_ty);
1205+
1206+
if (val_is_ptr || dest_is_ptr) && self.logical_addressing_model() {
1207+
if self.is_system_crate() {
1208+
self.zombie(
1209+
result.def(self),
1210+
"OpBitcast between ptr and non-ptr without AddressingModel != Logical",
1211+
)
1212+
} else {
1213+
self.struct_err("Cannot cast between pointer and non-pointer types")
1214+
.note(&format!("from: {}", self.debug_type(val.ty)))
1215+
.note(&format!("to: {}", self.debug_type(dest_ty)))
1216+
.emit()
1217+
}
11431218
}
1219+
11441220
result
11451221
}
11461222
}
@@ -1204,12 +1280,29 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
12041280
}
12051281

12061282
fn pointercast(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
1207-
let val_pointee = match self.lookup_type(val.ty) {
1208-
SpirvType::Pointer { pointee } => pointee,
1209-
other => self.fatal(&format!(
1210-
"pointercast called on non-pointer source type: {:?}",
1211-
other
1212-
)),
1283+
let (val, val_pointee) = match val.kind {
1284+
// Strip a previous `pointercast`, to reveal the original pointer type.
1285+
SpirvValueKind::LogicalPtrCast {
1286+
original_ptr,
1287+
original_pointee_ty,
1288+
zombie_target_undef: _,
1289+
} => (
1290+
original_ptr.with_type(
1291+
SpirvType::Pointer {
1292+
pointee: original_pointee_ty,
1293+
}
1294+
.def(self.span(), self),
1295+
),
1296+
original_pointee_ty,
1297+
),
1298+
1299+
_ => match self.lookup_type(val.ty) {
1300+
SpirvType::Pointer { pointee } => (val, pointee),
1301+
other => self.fatal(&format!(
1302+
"pointercast called on non-pointer source type: {:?}",
1303+
other
1304+
)),
1305+
},
12131306
};
12141307
let dest_pointee = match self.lookup_type(dest_ty) {
12151308
SpirvType::Pointer { pointee } => pointee,
@@ -1220,7 +1313,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
12201313
};
12211314
if val.ty == dest_ty {
12221315
val
1223-
} else if let Some(indices) = self.try_pointercast_via_gep(val_pointee, dest_pointee) {
1316+
} else if let Some(indices) =
1317+
self.recover_access_chain_from_offset(val_pointee, dest_pointee, Size::ZERO)
1318+
{
12241319
let indices = indices
12251320
.into_iter()
12261321
.map(|idx| self.constant_u32(self.span(), idx).def(self))
@@ -1229,14 +1324,21 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
12291324
.access_chain(dest_ty, None, val.def(self), indices)
12301325
.unwrap()
12311326
.with_type(dest_ty)
1327+
} else if self.logical_addressing_model() {
1328+
// Defer the cast so that it has a chance to be avoided.
1329+
SpirvValue {
1330+
kind: SpirvValueKind::LogicalPtrCast {
1331+
original_ptr: val.def(self),
1332+
original_pointee_ty: val_pointee,
1333+
zombie_target_undef: self.undef(dest_ty).def(self),
1334+
},
1335+
ty: dest_ty,
1336+
}
12321337
} else {
1233-
let result = self
1234-
.emit()
1338+
self.emit()
12351339
.bitcast(dest_ty, None, val.def(self))
12361340
.unwrap()
1237-
.with_type(dest_ty);
1238-
self.zombie_bitcast_ptr(result.def(self), val.ty, dest_ty);
1239-
result
1341+
.with_type(dest_ty)
12401342
}
12411343
}
12421344

crates/rustc_codegen_spirv/src/builder_spirv.rs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use std::{fs::File, io::Write, path::Path};
1313
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
1414
pub enum SpirvValueKind {
1515
Def(Word),
16+
1617
/// There are a fair number of places where `rustc_codegen_ssa` creates a pointer to something
1718
/// that cannot be pointed to in SPIR-V. For example, constant values are frequently emitted as
1819
/// a pointer to constant memory, and then dereferenced where they're used. Functions are the
@@ -28,6 +29,24 @@ pub enum SpirvValueKind {
2829
/// its initializer) to attach zombies to.
2930
global_var: Word,
3031
},
32+
33+
/// Deferred pointer cast, for the `Logical` addressing model (which doesn't
34+
/// really support raw pointers in the way Rust expects to be able to use).
35+
///
36+
/// The cast's target pointer type is the `ty` of the `SpirvValue` that has
37+
/// `LogicalPtrCast` as its `kind`, as it would be redundant to have it here.
38+
LogicalPtrCast {
39+
/// Pointer value being cast.
40+
original_ptr: Word,
41+
42+
/// Pointee type of `original_ptr`.
43+
original_pointee_ty: Word,
44+
45+
/// `OpUndef` of the right target pointer type, to attach zombies to.
46+
// FIXME(eddyb) we should be using a real `OpBitcast` here, but we can't
47+
// emit that on the fly during `SpirvValue::def`, due to builder locking.
48+
zombie_target_undef: Word,
49+
},
3150
}
3251

3352
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
@@ -49,7 +68,8 @@ impl SpirvValue {
4968
};
5069
Some(initializer.with_type(ty))
5170
}
52-
SpirvValueKind::Def(_) => None,
71+
72+
SpirvValueKind::Def(_) | SpirvValueKind::LogicalPtrCast { .. } => None,
5373
}
5474
}
5575

@@ -69,6 +89,7 @@ impl SpirvValue {
6989
pub fn def_with_span(self, cx: &CodegenCx<'_>, span: Span) -> Word {
7090
match self.kind {
7191
SpirvValueKind::Def(word) => word,
92+
7293
SpirvValueKind::ConstantPointer {
7394
initializer: _,
7495
global_var,
@@ -83,6 +104,29 @@ impl SpirvValue {
83104

84105
global_var
85106
}
107+
108+
SpirvValueKind::LogicalPtrCast {
109+
original_ptr: _,
110+
original_pointee_ty,
111+
zombie_target_undef,
112+
} => {
113+
if cx.is_system_crate() {
114+
cx.zombie_with_span(
115+
zombie_target_undef,
116+
span,
117+
"OpBitcast on ptr without AddressingModel != Logical",
118+
)
119+
} else {
120+
cx.tcx
121+
.sess
122+
.struct_span_err(span, "Cannot cast between pointer types")
123+
.note(&format!("from: *{}", cx.debug_type(original_pointee_ty)))
124+
.note(&format!("to: {}", cx.debug_type(self.ty)))
125+
.emit()
126+
}
127+
128+
zombie_target_undef
129+
}
86130
}
87131
}
88132
}

0 commit comments

Comments
 (0)