1
1
use super :: Builder ;
2
- use crate :: builder_spirv:: { BuilderCursor , SpirvValue , SpirvValueExt } ;
2
+ use crate :: builder_spirv:: { BuilderCursor , SpirvValue , SpirvValueExt , SpirvValueKind } ;
3
3
use crate :: spirv_type:: SpirvType ;
4
4
use rspirv:: dr:: { InsertPoint , Instruction , Operand } ;
5
- use rspirv:: spirv:: {
6
- AddressingModel , Capability , MemoryModel , MemorySemantics , Op , Scope , StorageClass , Word ,
7
- } ;
5
+ use rspirv:: spirv:: { Capability , MemoryModel , MemorySemantics , Op , Scope , StorageClass , Word } ;
8
6
use rustc_codegen_ssa:: common:: {
9
7
AtomicOrdering , AtomicRmwBinOp , IntPredicate , RealPredicate , SynchronizationScope ,
10
8
} ;
@@ -18,6 +16,7 @@ use rustc_middle::bug;
18
16
use rustc_middle:: ty:: Ty ;
19
17
use rustc_span:: Span ;
20
18
use rustc_target:: abi:: { Abi , Align , Scalar , Size } ;
19
+ use std:: convert:: TryInto ;
21
20
use std:: iter:: empty;
22
21
use std:: ops:: Range ;
23
22
@@ -334,54 +333,74 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
334
333
}
335
334
}
336
335
337
- fn zombie_bitcast_ptr ( & self , def : Word , from_ty : Word , to_ty : Word ) {
338
- let is_logical = self
339
- . emit ( )
340
- . module_ref ( )
341
- . memory_model
342
- . as_ref ( )
343
- . map_or ( false , |inst| {
344
- inst. operands [ 0 ] . unwrap_addressing_model ( ) == AddressingModel :: Logical
345
- } ) ;
346
- if is_logical {
347
- if self . is_system_crate ( ) {
348
- self . zombie ( def, "OpBitcast on ptr without AddressingModel != Logical" )
349
- } else {
350
- self . struct_err ( "Cannot cast between pointer types" )
351
- . note ( & format ! ( "from: {}" , self . debug_type( from_ty) ) )
352
- . note ( & format ! ( "to: {}" , self . debug_type( to_ty) ) )
353
- . emit ( )
354
- }
355
- }
356
- }
336
+ /// If possible, return the appropriate `OpAccessChain` indices for going from
337
+ /// a pointer to `ty`, to a pointer to `leaf_ty`, with an added `offset`.
338
+ ///
339
+ /// That is, try to turn `((_: *T) as *u8).add(offset) as *Leaf` into a series
340
+ /// of struct field and array/vector element accesses.
341
+ fn recover_access_chain_from_offset (
342
+ & self ,
343
+ mut ty : Word ,
344
+ leaf_ty : Word ,
345
+ mut offset : Size ,
346
+ ) -> Option < Vec < u32 > > {
347
+ assert_ne ! ( ty, leaf_ty) ;
348
+
349
+ // NOTE(eddyb) `ty` and `ty_kind` should be kept in sync.
350
+ let mut ty_kind = self . lookup_type ( ty) ;
357
351
358
- // Sometimes, when accessing the first field of a struct, vector, etc., instead of calling
359
- // struct_gep, codegen_ssa will call pointercast. This will then try to catch those cases and
360
- // translate them back to a struct_gep, instead of failing to compile the OpBitcast (which is
361
- // unsupported on shader target)
362
- fn try_pointercast_via_gep ( & self , mut val : Word , field : Word ) -> Option < Vec < u32 > > {
363
352
let mut indices = Vec :: new ( ) ;
364
- while val != field {
365
- match self . lookup_type ( val ) {
353
+ loop {
354
+ match ty_kind {
366
355
SpirvType :: Adt {
367
356
field_types,
368
357
field_offsets,
369
358
..
370
359
} => {
371
- let index = field_offsets. iter ( ) . position ( |& off| off == Size :: ZERO ) ?;
372
- indices. push ( index as u32 ) ;
373
- val = field_types[ index] ;
360
+ let ( i, field_ty, field_ty_kind, offset_in_field) = field_offsets
361
+ . iter ( )
362
+ . enumerate ( )
363
+ . find_map ( |( i, & field_offset) | {
364
+ if field_offset > offset {
365
+ return None ;
366
+ }
367
+
368
+ // Grab the actual field type to be able to confirm that
369
+ // the leaf is somewhere inside the field.
370
+ let field_ty = field_types[ i] ;
371
+ let field_ty_kind = self . lookup_type ( field_ty) ;
372
+
373
+ let offset_in_field = offset - field_offset;
374
+ if offset_in_field < field_ty_kind. sizeof ( self ) ? {
375
+ Some ( ( i, field_ty, field_ty_kind, offset_in_field) )
376
+ } else {
377
+ None
378
+ }
379
+ } ) ?;
380
+
381
+ ty = field_ty;
382
+ ty_kind = field_ty_kind;
383
+
384
+ indices. push ( i as u32 ) ;
385
+ offset = offset_in_field;
374
386
}
375
387
SpirvType :: Vector { element, .. }
376
388
| SpirvType :: Array { element, .. }
377
389
| SpirvType :: RuntimeArray { element } => {
378
- indices. push ( 0 ) ;
379
- val = element;
390
+ ty = element;
391
+ ty_kind = self . lookup_type ( ty) ;
392
+
393
+ let stride = ty_kind. sizeof ( self ) ?;
394
+ indices. push ( ( offset. bytes ( ) / stride. bytes ( ) ) . try_into ( ) . ok ( ) ?) ;
395
+ offset = Size :: from_bytes ( offset. bytes ( ) % stride. bytes ( ) ) ;
380
396
}
381
397
_ => return None ,
382
398
}
399
+
400
+ if offset == Size :: ZERO && ty == leaf_ty {
401
+ return Some ( indices) ;
402
+ }
383
403
}
384
- Some ( indices)
385
404
}
386
405
}
387
406
@@ -963,26 +982,65 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
963
982
}
964
983
965
984
fn struct_gep ( & mut self , ptr : Self :: Value , idx : u64 ) -> Self :: Value {
966
- let result_pointee_type = match self . lookup_type ( ptr. ty ) {
967
- SpirvType :: Pointer { pointee } => match self . lookup_type ( pointee) {
968
- SpirvType :: Adt { field_types, .. } => field_types[ idx as usize ] ,
969
- SpirvType :: Array { element, .. }
970
- | SpirvType :: RuntimeArray { element, .. }
971
- | SpirvType :: Vector { element, .. } => element,
972
- other => self . fatal ( & format ! (
973
- "struct_gep not on struct, array, or vector type: {:?}, index {}" ,
974
- other, idx
975
- ) ) ,
976
- } ,
985
+ let pointee = match self . lookup_type ( ptr. ty ) {
986
+ SpirvType :: Pointer { pointee } => pointee,
977
987
other => self . fatal ( & format ! (
978
988
"struct_gep not on pointer type: {:?}, index {}" ,
979
989
other, idx
980
990
) ) ,
981
991
} ;
992
+ let pointee_kind = self . lookup_type ( pointee) ;
993
+ let result_pointee_type = match pointee_kind {
994
+ SpirvType :: Adt {
995
+ ref field_types, ..
996
+ } => field_types[ idx as usize ] ,
997
+ SpirvType :: Array { element, .. }
998
+ | SpirvType :: RuntimeArray { element, .. }
999
+ | SpirvType :: Vector { element, .. } => element,
1000
+ other => self . fatal ( & format ! (
1001
+ "struct_gep not on struct, array, or vector type: {:?}, index {}" ,
1002
+ other, idx
1003
+ ) ) ,
1004
+ } ;
982
1005
let result_type = SpirvType :: Pointer {
983
1006
pointee : result_pointee_type,
984
1007
}
985
1008
. def ( self . span ( ) , self ) ;
1009
+
1010
+ // Special-case field accesses through a `pointercast`, to accesss the
1011
+ // right field in the original type, for the `Logical` addressing model.
1012
+ if let SpirvValueKind :: LogicalPtrCast {
1013
+ original_ptr,
1014
+ original_pointee_ty,
1015
+ zombie_target_undef : _,
1016
+ } = ptr. kind
1017
+ {
1018
+ let offset = match pointee_kind {
1019
+ SpirvType :: Adt { field_offsets, .. } => field_offsets[ idx as usize ] ,
1020
+ SpirvType :: Array { element, .. }
1021
+ | SpirvType :: RuntimeArray { element, .. }
1022
+ | SpirvType :: Vector { element, .. } => {
1023
+ self . lookup_type ( element) . sizeof ( self ) . unwrap ( ) * idx
1024
+ }
1025
+ _ => unreachable ! ( ) ,
1026
+ } ;
1027
+ if let Some ( indices) = self . recover_access_chain_from_offset (
1028
+ original_pointee_ty,
1029
+ result_pointee_type,
1030
+ offset,
1031
+ ) {
1032
+ let indices = indices
1033
+ . into_iter ( )
1034
+ . map ( |idx| self . constant_u32 ( self . span ( ) , idx) . def ( self ) )
1035
+ . collect :: < Vec < _ > > ( ) ;
1036
+ return self
1037
+ . emit ( )
1038
+ . access_chain ( result_type, None , original_ptr, indices)
1039
+ . unwrap ( )
1040
+ . with_type ( result_type) ;
1041
+ }
1042
+ }
1043
+
986
1044
// Important! LLVM, and therefore intel-compute-runtime, require the `getelementptr` instruction (and therefore
987
1045
// OpAccessChain) on structs to be a constant i32. Not i64! i32.
988
1046
if idx > u32:: MAX as u64 {
@@ -1131,16 +1189,34 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
1131
1189
if val. ty == dest_ty {
1132
1190
val
1133
1191
} else {
1192
+ let val_is_ptr = matches ! ( self . lookup_type( val. ty) , SpirvType :: Pointer { .. } ) ;
1193
+ let dest_is_ptr = matches ! ( self . lookup_type( dest_ty) , SpirvType :: Pointer { .. } ) ;
1194
+
1195
+ // Reuse the pointer-specific logic in `pointercast` for `*T -> *U`.
1196
+ if val_is_ptr && dest_is_ptr {
1197
+ return self . pointercast ( val, dest_ty) ;
1198
+ }
1199
+
1134
1200
let result = self
1135
1201
. emit ( )
1136
1202
. bitcast ( dest_ty, None , val. def ( self ) )
1137
1203
. unwrap ( )
1138
1204
. with_type ( dest_ty) ;
1139
- let val_is_ptr = matches ! ( self . lookup_type( val. ty) , SpirvType :: Pointer { .. } ) ;
1140
- let dest_is_ptr = matches ! ( self . lookup_type( dest_ty) , SpirvType :: Pointer { .. } ) ;
1141
- if val_is_ptr || dest_is_ptr {
1142
- self . zombie_bitcast_ptr ( result. def ( self ) , val. ty , dest_ty) ;
1205
+
1206
+ if ( val_is_ptr || dest_is_ptr) && self . logical_addressing_model ( ) {
1207
+ if self . is_system_crate ( ) {
1208
+ self . zombie (
1209
+ result. def ( self ) ,
1210
+ "OpBitcast between ptr and non-ptr without AddressingModel != Logical" ,
1211
+ )
1212
+ } else {
1213
+ self . struct_err ( "Cannot cast between pointer and non-pointer types" )
1214
+ . note ( & format ! ( "from: {}" , self . debug_type( val. ty) ) )
1215
+ . note ( & format ! ( "to: {}" , self . debug_type( dest_ty) ) )
1216
+ . emit ( )
1217
+ }
1143
1218
}
1219
+
1144
1220
result
1145
1221
}
1146
1222
}
@@ -1204,12 +1280,29 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
1204
1280
}
1205
1281
1206
1282
fn pointercast ( & mut self , val : Self :: Value , dest_ty : Self :: Type ) -> Self :: Value {
1207
- let val_pointee = match self . lookup_type ( val. ty ) {
1208
- SpirvType :: Pointer { pointee } => pointee,
1209
- other => self . fatal ( & format ! (
1210
- "pointercast called on non-pointer source type: {:?}" ,
1211
- other
1212
- ) ) ,
1283
+ let ( val, val_pointee) = match val. kind {
1284
+ // Strip a previous `pointercast`, to reveal the original pointer type.
1285
+ SpirvValueKind :: LogicalPtrCast {
1286
+ original_ptr,
1287
+ original_pointee_ty,
1288
+ zombie_target_undef : _,
1289
+ } => (
1290
+ original_ptr. with_type (
1291
+ SpirvType :: Pointer {
1292
+ pointee : original_pointee_ty,
1293
+ }
1294
+ . def ( self . span ( ) , self ) ,
1295
+ ) ,
1296
+ original_pointee_ty,
1297
+ ) ,
1298
+
1299
+ _ => match self . lookup_type ( val. ty ) {
1300
+ SpirvType :: Pointer { pointee } => ( val, pointee) ,
1301
+ other => self . fatal ( & format ! (
1302
+ "pointercast called on non-pointer source type: {:?}" ,
1303
+ other
1304
+ ) ) ,
1305
+ } ,
1213
1306
} ;
1214
1307
let dest_pointee = match self . lookup_type ( dest_ty) {
1215
1308
SpirvType :: Pointer { pointee } => pointee,
@@ -1220,7 +1313,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
1220
1313
} ;
1221
1314
if val. ty == dest_ty {
1222
1315
val
1223
- } else if let Some ( indices) = self . try_pointercast_via_gep ( val_pointee, dest_pointee) {
1316
+ } else if let Some ( indices) =
1317
+ self . recover_access_chain_from_offset ( val_pointee, dest_pointee, Size :: ZERO )
1318
+ {
1224
1319
let indices = indices
1225
1320
. into_iter ( )
1226
1321
. map ( |idx| self . constant_u32 ( self . span ( ) , idx) . def ( self ) )
@@ -1229,14 +1324,21 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
1229
1324
. access_chain ( dest_ty, None , val. def ( self ) , indices)
1230
1325
. unwrap ( )
1231
1326
. with_type ( dest_ty)
1327
+ } else if self . logical_addressing_model ( ) {
1328
+ // Defer the cast so that it has a chance to be avoided.
1329
+ SpirvValue {
1330
+ kind : SpirvValueKind :: LogicalPtrCast {
1331
+ original_ptr : val. def ( self ) ,
1332
+ original_pointee_ty : val_pointee,
1333
+ zombie_target_undef : self . undef ( dest_ty) . def ( self ) ,
1334
+ } ,
1335
+ ty : dest_ty,
1336
+ }
1232
1337
} else {
1233
- let result = self
1234
- . emit ( )
1338
+ self . emit ( )
1235
1339
. bitcast ( dest_ty, None , val. def ( self ) )
1236
1340
. unwrap ( )
1237
- . with_type ( dest_ty) ;
1238
- self . zombie_bitcast_ptr ( result. def ( self ) , val. ty , dest_ty) ;
1239
- result
1341
+ . with_type ( dest_ty)
1240
1342
}
1241
1343
}
1242
1344
0 commit comments