From 1e4eca2845b65a0bb16a9dd969dc93b262f523b8 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Sat, 22 Mar 2025 13:01:11 -0700 Subject: [PATCH 1/2] [naga hlsl-out] Handle additional cases of Cx2 matrices Fixes #4423 --- naga/src/back/hlsl/mod.rs | 25 +- naga/src/back/hlsl/storage.rs | 146 ++++++-- naga/src/back/hlsl/writer.rs | 85 ++++- naga/tests/in/wgsl/access.wgsl | 5 +- naga/tests/in/wgsl/hlsl_mat_cx2.toml | 1 + naga/tests/in/wgsl/hlsl_mat_cx2.wgsl | 177 ++++++++++ naga/tests/in/wgsl/hlsl_mat_cx3.toml | 1 + naga/tests/in/wgsl/hlsl_mat_cx3.wgsl | 173 ++++++++++ naga/tests/out/hlsl/wgsl-hlsl_mat_cx2.hlsl | 372 ++++++++++++++++++++ naga/tests/out/hlsl/wgsl-hlsl_mat_cx2.ron | 12 + naga/tests/out/hlsl/wgsl-hlsl_mat_cx3.hlsl | 350 +++++++++++++++++++ naga/tests/out/hlsl/wgsl-hlsl_mat_cx3.ron | 12 + naga/tests/out/spv/wgsl-access.spvasm | 373 +++++++++++---------- 13 files changed, 1503 insertions(+), 229 deletions(-) create mode 100644 naga/tests/in/wgsl/hlsl_mat_cx2.toml create mode 100644 naga/tests/in/wgsl/hlsl_mat_cx2.wgsl create mode 100644 naga/tests/in/wgsl/hlsl_mat_cx3.toml create mode 100644 naga/tests/in/wgsl/hlsl_mat_cx3.wgsl create mode 100644 naga/tests/out/hlsl/wgsl-hlsl_mat_cx2.hlsl create mode 100644 naga/tests/out/hlsl/wgsl-hlsl_mat_cx2.ron create mode 100644 naga/tests/out/hlsl/wgsl-hlsl_mat_cx3.hlsl create mode 100644 naga/tests/out/hlsl/wgsl-hlsl_mat_cx3.ron diff --git a/naga/src/back/hlsl/mod.rs b/naga/src/back/hlsl/mod.rs index 9e041ff73f..f40a5a4c5c 100644 --- a/naga/src/back/hlsl/mod.rs +++ b/naga/src/back/hlsl/mod.rs @@ -13,11 +13,17 @@ type should be stored in `uniform` and `storage` buffers. The HLSL we generate must access values in that form, even when it is not what HLSL would use normally. -The rules described here only apply to WGSL `uniform` variables. WGSL -`storage` buffers are translated as HLSL `ByteAddressBuffers`, for -which we generate `Load` and `Store` method calls with explicit byte -offsets. WGSL pipeline inputs must be scalars or vectors; they cannot -be matrices, which is where the interesting problems arise. +Matching the WGSL memory layout is a concern only for `uniform` +variables. WGSL `storage` buffers are translated as HLSL +`ByteAddressBuffers`, for which we generate `Load` and `Store` method +calls with explicit byte offsets. WGSL pipeline inputs must be scalars +or vectors; they cannot be matrices, which is where the interesting +problems arise. However, when an affected type appears in a struct +definition, the transformations described here are applied without +consideration of where the struct is used. + +Access to storage buffers is implemented in `storage.rs`. Access to +uniform buffers is implemented where applicable in `writer.rs`. ## Row- and column-major ordering for matrices @@ -57,10 +63,9 @@ that the columns of a `matKx2` need only be [aligned as required for `vec2`][ilov], which is [eight-byte alignment][8bb]. To compensate for this, any time a `matKx2` appears in a WGSL -`uniform` variable, whether directly as the variable's type or as part -of a struct/array, we actually emit `K` separate `float2` members, and -assemble/disassemble the matrix from its columns (in WGSL; rows in -HLSL) upon load and store. +`uniform` value or as part of a struct/array, we actually emit `K` +separate `float2` members, and assemble/disassemble the matrix from its +columns (in WGSL; rows in HLSL) upon load and store. For example, the following WGSL struct type: @@ -448,6 +453,8 @@ pub enum Error { Override, #[error(transparent)] ResolveArraySizeError(#[from] proc::ResolveArraySizeError), + #[error("Internal error: reached unreachable code: {0}")] + Unreachable(String), } #[derive(PartialEq, Eq, Hash)] diff --git a/naga/src/back/hlsl/storage.rs b/naga/src/back/hlsl/storage.rs index 9f92d86639..2b68f973f0 100644 --- a/naga/src/back/hlsl/storage.rs +++ b/naga/src/back/hlsl/storage.rs @@ -108,6 +108,13 @@ pub(super) enum StoreValue { base: Handle, member_index: u32, }, + // Access to a single column of a Cx2 matrix within a struct + TempColumnAccess { + depth: usize, + base: Handle, + member_index: u32, + column: u32, + }, } impl super::Writer<'_, W> { @@ -290,6 +297,15 @@ impl super::Writer<'_, W> { let name = &self.names[&NameKey::StructMember(base, member_index)]; write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}")? } + StoreValue::TempColumnAccess { + depth, + base, + member_index, + column, + } => { + let name = &self.names[&NameKey::StructMember(base, member_index)]; + write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}_{column}")? + } } Ok(()) } @@ -302,6 +318,7 @@ impl super::Writer<'_, W> { value: StoreValue, func_ctx: &FunctionCtx, level: crate::back::Level, + within_struct: Option>, ) -> BackendResult { let temp_resolution; let ty_resolution = match value { @@ -325,6 +342,11 @@ impl super::Writer<'_, W> { temp_resolution = TypeResolution::Handle(ty_handle); &temp_resolution } + StoreValue::TempColumnAccess { .. } => { + return Err(Error::Unreachable( + "attempting write_storage_store for TempColumnAccess".into(), + )); + } }; match *ty_resolution.inner_with(&module.types) { crate::TypeInner::Scalar(scalar) => { @@ -372,37 +394,92 @@ impl super::Writer<'_, W> { rows, scalar, } => { - // first, assign the value to a temporary - writeln!(self.out, "{level}{{")?; - let depth = level.0 + 1; - write!( - self.out, - "{}{}{}x{} {}{} = ", - level.next(), - scalar.to_hlsl_str()?, - columns as u8, - rows as u8, - STORE_TEMP_NAME, - depth, - )?; - self.write_store_value(module, &value, func_ctx)?; - writeln!(self.out, ";")?; - // Note: Matrices containing vec3s, due to padding, act like they contain vec4s. let row_stride = Alignment::from(rows) * scalar.width as u32; - // then iterate the stores - for i in 0..columns as u32 { - self.temp_access_chain - .push(SubAccess::Offset(i * row_stride)); - let ty_inner = crate::TypeInner::Vector { size: rows, scalar }; - let sv = StoreValue::TempIndex { - depth, - index: i, - ty: TypeResolution::Value(ty_inner), - }; - self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?; - self.temp_access_chain.pop(); + writeln!(self.out, "{level}{{")?; + + match within_struct { + Some(containing_struct) if rows == crate::VectorSize::Bi => { + // If we are within a struct, then the struct was already assigned to + // a temporary, we don't need to make another. + let mut chain = mem::take(&mut self.temp_access_chain); + for i in 0..columns as u32 { + chain.push(SubAccess::Offset(i * row_stride)); + // working around the borrow checker in `self.write_expr` + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + let StoreValue::TempAccess { member_index, .. } = value else { + return Err(Error::Unreachable( + "write_storage_store within_struct but not TempAccess".into(), + )); + }; + let column_value = StoreValue::TempColumnAccess { + depth: level.0, // note not incrementing, b/c no temp + base: containing_struct, + member_index, + column: i, + }; + // See note about DXC and Load/Store in the module's documentation. + if scalar.width == 4 { + write!( + self.out, + "{}{}.Store{}(", + level.next(), + var_name, + rows as u8 + )?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", asuint(")?; + self.write_store_value(module, &column_value, func_ctx)?; + writeln!(self.out, "));")?; + } else { + write!(self.out, "{}{var_name}.Store(", level.next())?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", ")?; + self.write_store_value(module, &column_value, func_ctx)?; + writeln!(self.out, ");")?; + } + chain.pop(); + } + self.temp_access_chain = chain; + } + _ => { + // first, assign the value to a temporary + let depth = level.0 + 1; + write!( + self.out, + "{}{}{}x{} {}{} = ", + level.next(), + scalar.to_hlsl_str()?, + columns as u8, + rows as u8, + STORE_TEMP_NAME, + depth, + )?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, ";")?; + + // then iterate the stores + for i in 0..columns as u32 { + self.temp_access_chain + .push(SubAccess::Offset(i * row_stride)); + let ty_inner = crate::TypeInner::Vector { size: rows, scalar }; + let sv = StoreValue::TempIndex { + depth, + index: i, + ty: TypeResolution::Value(ty_inner), + }; + self.write_storage_store( + module, + var_handle, + sv, + func_ctx, + level.next(), + None, + )?; + self.temp_access_chain.pop(); + } + } } // done writeln!(self.out, "{level}}}")?; @@ -415,7 +492,7 @@ impl super::Writer<'_, W> { // first, assign the value to a temporary writeln!(self.out, "{level}{{")?; write!(self.out, "{}", level.next())?; - self.write_value_type(module, &module.types[base].inner)?; + self.write_type(module, base)?; let depth = level.next().0; write!(self.out, " {STORE_TEMP_NAME}{depth}")?; self.write_array_size(module, base, crate::ArraySize::Constant(size))?; @@ -430,7 +507,7 @@ impl super::Writer<'_, W> { index: i, ty: TypeResolution::Handle(base), }; - self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?; + self.write_storage_store(module, var_handle, sv, func_ctx, level.next(), None)?; self.temp_access_chain.pop(); } // done @@ -461,7 +538,14 @@ impl super::Writer<'_, W> { base: struct_ty, member_index: i as u32, }; - self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?; + self.write_storage_store( + module, + var_handle, + sv, + func_ctx, + level.next(), + Some(struct_ty), + )?; self.temp_access_chain.pop(); } // done diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 59725df3db..5f959e37e4 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -1911,6 +1911,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { StoreValue::Expression(value), func_ctx, level, + None, )?; } else { // We treat matrices of the form `matCx2` as a sequence of C `vec2`s. @@ -2912,12 +2913,37 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { // // Note that this only works for `Load`s and we handle // `Store`s differently in `Statement::Store`. + let cx2_columns; if let Some(MatrixType { columns, rows: crate::VectorSize::Bi, width: 4, }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true) { + cx2_columns = Some(columns); + } else { + let base_tr = func_ctx + .resolve_type(base, &module.types) + .pointer_base_type(); + let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types)); + match (&func_ctx.expressions[base], base_ty) { + ( + &Expression::GlobalVariable(handle), + Some(&TypeInner::Matrix { + columns, + rows: crate::VectorSize::Bi, + .. + }), + ) if module.global_variables[handle].space + == crate::AddressSpace::Uniform => + { + cx2_columns = Some(columns); + } + _ => cx2_columns = None, + } + } + + if let Some(columns) = cx2_columns { write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?; self.write_expr(module, base, func_ctx)?; write!(self.out, ", ")?; @@ -3031,12 +3057,36 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } else { // We write the matrix column access in a special way since // the type of `base` is our special __matCx2 struct. + let is_cx2; if let Some(MatrixType { rows: crate::VectorSize::Bi, width: 4, .. }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true) { + is_cx2 = true; + } else { + let base_tr = func_ctx + .resolve_type(base, &module.types) + .pointer_base_type(); + let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types)); + match (&func_ctx.expressions[base], base_ty) { + ( + &Expression::GlobalVariable(handle), + Some(&TypeInner::Matrix { + rows: crate::VectorSize::Bi, + .. + }), + ) if module.global_variables[handle].space + == crate::AddressSpace::Uniform => + { + is_cx2 = true; + } + _ => is_cx2 = false, + } + } + + if is_cx2 { self.write_expr(module, base, func_ctx)?; write!(self.out, "._{index}")?; return Ok(()); @@ -3309,8 +3359,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx)) { let mut resolved = func_ctx.resolve_type(pointer, &module.types); - if let TypeInner::Pointer { base, .. } = *resolved { - resolved = &module.types[base].inner; + let ptr_tr = resolved.pointer_base_type(); + if let Some(ptr_ty) = + ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types)) + { + resolved = ptr_ty; } write!(self.out, "((")?; @@ -4288,6 +4341,32 @@ pub(super) fn get_inner_matrix_data( } } +fn find_matrix_in_access_chain( + module: &Module, + base: Handle, + func_ctx: &back::FunctionCtx<'_>, +) -> Option> { + let mut current_base = base; + loop { + let resolved_tr = func_ctx + .resolve_type(current_base, &module.types) + .pointer_base_type(); + let resolved = resolved_tr.as_ref()?.inner_with(&module.types); + + match *resolved { + TypeInner::Scalar(_) | TypeInner::Vector { .. } => {} + TypeInner::Matrix { .. } => return Some(current_base), + _ => return None, + } + + current_base = match func_ctx.expressions[current_base] { + crate::Expression::Access { base, .. } => base, + crate::Expression::AccessIndex { base, .. } => base, + _ => return None, + } + } +} + /// Returns the matrix data if the access chain starting at `base`: /// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true` /// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`] @@ -4355,10 +4434,10 @@ fn get_inner_matrix_of_global_uniform( base: Handle, func_ctx: &back::FunctionCtx<'_>, ) -> Option { + let mut current_base = find_matrix_in_access_chain(module, base, func_ctx)?; let mut mat_data = None; let mut array_base = None; - let mut current_base = base; loop { let mut resolved = func_ctx.resolve_type(current_base, &module.types); if let TypeInner::Pointer { base, .. } = *resolved { diff --git a/naga/tests/in/wgsl/access.wgsl b/naga/tests/in/wgsl/access.wgsl index 2ad53b6134..e80750c951 100644 --- a/naga/tests/in/wgsl/access.wgsl +++ b/naga/tests/in/wgsl/access.wgsl @@ -35,7 +35,10 @@ var baz: Baz; var qux: vec2; fn test_matrix_within_struct_accesses() { - var idx = 1; + // Test HLSL accesses to Cx2 matrices. There are additional tests + // in `hlsl_mat_cx2.wgsl`. + + var idx = 1; idx--; diff --git a/naga/tests/in/wgsl/hlsl_mat_cx2.toml b/naga/tests/in/wgsl/hlsl_mat_cx2.toml new file mode 100644 index 0000000000..3ca0b52f4e --- /dev/null +++ b/naga/tests/in/wgsl/hlsl_mat_cx2.toml @@ -0,0 +1 @@ +targets = "HLSL" diff --git a/naga/tests/in/wgsl/hlsl_mat_cx2.wgsl b/naga/tests/in/wgsl/hlsl_mat_cx2.wgsl new file mode 100644 index 0000000000..50bc188794 --- /dev/null +++ b/naga/tests/in/wgsl/hlsl_mat_cx2.wgsl @@ -0,0 +1,177 @@ +// Test HLSL handling of N-by-2 matrices. +// See the doc comment on `naga::back::hlsl` for details. +// +// There are additional tests in `access.wgsl`. +// +// Tests that we don't apply this handling to other sizes are in hlsl_mat_cx3.wgsl. + +// Access type (3rd item in variable names) +// S = Struct +// M = Matrix +// C = Column +// E = Element + +// Index type (4th item in variable names) +// C = Constant +// V = Variable + +alias Mat = mat2x2; + +@group(0) @binding(0) +var s_m: Mat; + +@group(0) @binding(1) +var u_m: Mat; + +fn access_m() { + var idx = 1; + idx--; + + // loads from storage + let l_s_m = s_m; + let l_s_c_c = s_m[0]; + let l_s_c_v = s_m[idx]; + let l_s_e_cc = s_m[0][0]; + let l_s_e_cv = s_m[0][idx]; + let l_s_e_vc = s_m[idx][0]; + let l_s_e_vv = s_m[idx][idx]; + + // loads from uniform + let l_u_m = u_m; + let l_u_c_c = u_m[0]; + let l_u_c_v = u_m[idx]; + let l_u_e_cc = u_m[0][0]; + let l_u_e_cv = u_m[0][idx]; + let l_u_e_vc = u_m[idx][0]; + let l_u_e_vv = u_m[idx][idx]; + + // stores to storage + s_m = l_u_m; + s_m[0] = l_u_c_c; + s_m[idx] = l_u_c_v; + s_m[0][0] = l_u_e_cc; + s_m[0][idx] = l_u_e_cv; + s_m[idx][0] = l_u_e_vc; + s_m[idx][idx] = l_u_e_vv; +} + +struct StructWithMat { + m: Mat, +} + +@group(1) @binding(0) +var s_sm: StructWithMat; + +@group(1) @binding(1) +var u_sm: StructWithMat; + +fn access_sm() { + var idx = 1; + idx--; + + // loads from storage + let l_s_s = s_sm; + let l_s_m = s_sm.m; + let l_s_c_c = s_sm.m[0]; + let l_s_c_v = s_sm.m[idx]; + let l_s_e_cc = s_sm.m[0][0]; + let l_s_e_cv = s_sm.m[0][idx]; + let l_s_e_vc = s_sm.m[idx][0]; + let l_s_e_vv = s_sm.m[idx][idx]; + + // loads from uniform + let l_u_s = u_sm; + let l_u_m = u_sm.m; + let l_u_c_c = u_sm.m[0]; + let l_u_c_v = u_sm.m[idx]; + let l_u_e_cc = u_sm.m[0][0]; + let l_u_e_cv = u_sm.m[0][idx]; + let l_u_e_vc = u_sm.m[idx][0]; + let l_u_e_vv = u_sm.m[idx][idx]; + + // stores to storage + s_sm = l_u_s; + s_sm.m = l_u_m; + s_sm.m[0] = l_u_c_c; + s_sm.m[idx] = l_u_c_v; + s_sm.m[0][0] = l_u_e_cc; + s_sm.m[0][idx] = l_u_e_cv; + s_sm.m[idx][0] = l_u_e_vc; + s_sm.m[idx][idx] = l_u_e_vv; +} + +struct StructWithArrayOfStructOfMat { + a: array, +} + +@group(2) @binding(0) +var s_sasm: StructWithArrayOfStructOfMat; + +@group(2) @binding(1) +var u_sasm: StructWithArrayOfStructOfMat; + +fn access_sasm() { + var idx = 1; + idx--; + + // loads from storage + let l_s_s = s_sasm; + let l_s_a = s_sasm.a; + let l_s_m_c = s_sasm.a[0].m; + let l_s_m_v = s_sasm.a[idx].m; + let l_s_c_cc = s_sasm.a[0].m[0]; + let l_s_c_cv = s_sasm.a[0].m[idx]; + let l_s_c_vc = s_sasm.a[idx].m[0]; + let l_s_c_vv = s_sasm.a[idx].m[idx]; + let l_s_e_ccc = s_sasm.a[0].m[0][0]; + let l_s_e_ccv = s_sasm.a[0].m[0][idx]; + let l_s_e_cvc = s_sasm.a[0].m[idx][0]; + let l_s_e_cvv = s_sasm.a[0].m[idx][idx]; + let l_s_e_vcc = s_sasm.a[idx].m[0][0]; + let l_s_e_vcv = s_sasm.a[idx].m[0][idx]; + let l_s_e_vvc = s_sasm.a[idx].m[idx][0]; + let l_s_e_vvv = s_sasm.a[idx].m[idx][idx]; + + // loads from uniform + let l_u_s = u_sasm; + let l_u_a = u_sasm.a; + let l_u_m_c = u_sasm.a[0].m; + let l_u_m_v = u_sasm.a[idx].m; + let l_u_c_cc = u_sasm.a[0].m[0]; + let l_u_c_cv = u_sasm.a[0].m[idx]; + let l_u_c_vc = u_sasm.a[idx].m[0]; + let l_u_c_vv = u_sasm.a[idx].m[idx]; + let l_u_e_ccc = u_sasm.a[0].m[0][0]; + let l_u_e_ccv = u_sasm.a[0].m[0][idx]; + let l_u_e_cvc = u_sasm.a[0].m[idx][0]; + let l_u_e_cvv = u_sasm.a[0].m[idx][idx]; + let l_u_e_vcc = u_sasm.a[idx].m[0][0]; + let l_u_e_vcv = u_sasm.a[idx].m[0][idx]; + let l_u_e_vvc = u_sasm.a[idx].m[idx][0]; + let l_u_e_vvv = u_sasm.a[idx].m[idx][idx]; + + // stores to storage + s_sasm = l_u_s; + s_sasm.a = l_u_a; + s_sasm.a[0].m = l_u_m_c; + s_sasm.a[idx].m = l_u_m_v; + s_sasm.a[0].m[0] = l_u_c_cc; + s_sasm.a[0].m[idx] = l_u_c_cv; + s_sasm.a[idx].m[0] = l_u_c_vc; + s_sasm.a[idx].m[idx] = l_u_c_vv; + s_sasm.a[0].m[0][0] = l_u_e_ccc; + s_sasm.a[0].m[0][idx] = l_u_e_ccv; + s_sasm.a[0].m[idx][0] = l_u_e_cvc; + s_sasm.a[0].m[idx][idx] = l_u_e_cvv; + s_sasm.a[idx].m[0][0] = l_u_e_vcc; + s_sasm.a[idx].m[0][idx] = l_u_e_vcv; + s_sasm.a[idx].m[idx][0] = l_u_e_vvc; + s_sasm.a[idx].m[idx][idx] = l_u_e_vvv; +} + +@compute @workgroup_size(1) +fn main() { + access_m(); + access_sm(); + access_sasm(); +} diff --git a/naga/tests/in/wgsl/hlsl_mat_cx3.toml b/naga/tests/in/wgsl/hlsl_mat_cx3.toml new file mode 100644 index 0000000000..3ca0b52f4e --- /dev/null +++ b/naga/tests/in/wgsl/hlsl_mat_cx3.toml @@ -0,0 +1 @@ +targets = "HLSL" diff --git a/naga/tests/in/wgsl/hlsl_mat_cx3.wgsl b/naga/tests/in/wgsl/hlsl_mat_cx3.wgsl new file mode 100644 index 0000000000..e33f10fc9c --- /dev/null +++ b/naga/tests/in/wgsl/hlsl_mat_cx3.wgsl @@ -0,0 +1,173 @@ +// Test HLSL handling of N-by-3 matrices. These should not receive the special +// treatment that N-by-2 matrices receive (which is tested in hlsl_mat_cx2). + +// Access type (3rd item in variable names) +// S = Struct +// M = Matrix +// C = Column +// E = Element + +// Index type (4th item in variable names) +// C = Constant +// V = Variable + +alias Mat = mat3x3; + +@group(0) @binding(0) +var s_m: Mat; + +@group(0) @binding(1) +var u_m: Mat; + +fn access_m() { + var idx = 1; + idx--; + + // loads from storage + let l_s_m = s_m; + let l_s_c_c = s_m[0]; + let l_s_c_v = s_m[idx]; + let l_s_e_cc = s_m[0][0]; + let l_s_e_cv = s_m[0][idx]; + let l_s_e_vc = s_m[idx][0]; + let l_s_e_vv = s_m[idx][idx]; + + // loads from uniform + let l_u_m = u_m; + let l_u_c_c = u_m[0]; + let l_u_c_v = u_m[idx]; + let l_u_e_cc = u_m[0][0]; + let l_u_e_cv = u_m[0][idx]; + let l_u_e_vc = u_m[idx][0]; + let l_u_e_vv = u_m[idx][idx]; + + // stores to storage + s_m = l_u_m; + s_m[0] = l_u_c_c; + s_m[idx] = l_u_c_v; + s_m[0][0] = l_u_e_cc; + s_m[0][idx] = l_u_e_cv; + s_m[idx][0] = l_u_e_vc; + s_m[idx][idx] = l_u_e_vv; +} + +struct StructWithMat { + m: Mat, +} + +@group(1) @binding(0) +var s_sm: StructWithMat; + +@group(1) @binding(1) +var u_sm: StructWithMat; + +fn access_sm() { + var idx = 1; + idx--; + + // loads from storage + let l_s_s = s_sm; + let l_s_m = s_sm.m; + let l_s_c_c = s_sm.m[0]; + let l_s_c_v = s_sm.m[idx]; + let l_s_e_cc = s_sm.m[0][0]; + let l_s_e_cv = s_sm.m[0][idx]; + let l_s_e_vc = s_sm.m[idx][0]; + let l_s_e_vv = s_sm.m[idx][idx]; + + // loads from uniform + let l_u_s = u_sm; + let l_u_m = u_sm.m; + let l_u_c_c = u_sm.m[0]; + let l_u_c_v = u_sm.m[idx]; + let l_u_e_cc = u_sm.m[0][0]; + let l_u_e_cv = u_sm.m[0][idx]; + let l_u_e_vc = u_sm.m[idx][0]; + let l_u_e_vv = u_sm.m[idx][idx]; + + // stores to storage + s_sm = l_u_s; + s_sm.m = l_u_m; + s_sm.m[0] = l_u_c_c; + s_sm.m[idx] = l_u_c_v; + s_sm.m[0][0] = l_u_e_cc; + s_sm.m[0][idx] = l_u_e_cv; + s_sm.m[idx][0] = l_u_e_vc; + s_sm.m[idx][idx] = l_u_e_vv; +} + +struct StructWithArrayOfStructOfMat { + a: array, +} + +@group(2) @binding(0) +var s_sasm: StructWithArrayOfStructOfMat; + +@group(2) @binding(1) +var u_sasm: StructWithArrayOfStructOfMat; + +fn access_sasm() { + var idx = 1; + idx--; + + // loads from storage + let l_s_s = s_sasm; + let l_s_a = s_sasm.a; + let l_s_m_c = s_sasm.a[0].m; + let l_s_m_v = s_sasm.a[idx].m; + let l_s_c_cc = s_sasm.a[0].m[0]; + let l_s_c_cv = s_sasm.a[0].m[idx]; + let l_s_c_vc = s_sasm.a[idx].m[0]; + let l_s_c_vv = s_sasm.a[idx].m[idx]; + let l_s_e_ccc = s_sasm.a[0].m[0][0]; + let l_s_e_ccv = s_sasm.a[0].m[0][idx]; + let l_s_e_cvc = s_sasm.a[0].m[idx][0]; + let l_s_e_cvv = s_sasm.a[0].m[idx][idx]; + let l_s_e_vcc = s_sasm.a[idx].m[0][0]; + let l_s_e_vcv = s_sasm.a[idx].m[0][idx]; + let l_s_e_vvc = s_sasm.a[idx].m[idx][0]; + let l_s_e_vvv = s_sasm.a[idx].m[idx][idx]; + + // loads from uniform + let l_u_s = u_sasm; + let l_u_a = u_sasm.a; + let l_u_m_c = u_sasm.a[0].m; + let l_u_m_v = u_sasm.a[idx].m; + let l_u_c_cc = u_sasm.a[0].m[0]; + let l_u_c_cv = u_sasm.a[0].m[idx]; + let l_u_c_vc = u_sasm.a[idx].m[0]; + let l_u_c_vv = u_sasm.a[idx].m[idx]; + let l_u_e_ccc = u_sasm.a[0].m[0][0]; + let l_u_e_ccv = u_sasm.a[0].m[0][idx]; + let l_u_e_cvc = u_sasm.a[0].m[idx][0]; + let l_u_e_cvv = u_sasm.a[0].m[idx][idx]; + let l_u_e_vcc = u_sasm.a[idx].m[0][0]; + let l_u_e_vcv = u_sasm.a[idx].m[0][idx]; + let l_u_e_vvc = u_sasm.a[idx].m[idx][0]; + let l_u_e_vvv = u_sasm.a[idx].m[idx][idx]; + + // stores to storage + s_sasm = l_u_s; + s_sasm.a = l_u_a; + s_sasm.a[0].m = l_u_m_c; + s_sasm.a[idx].m = l_u_m_v; + s_sasm.a[0].m[0] = l_u_c_cc; + s_sasm.a[0].m[idx] = l_u_c_cv; + s_sasm.a[idx].m[0] = l_u_c_vc; + s_sasm.a[idx].m[idx] = l_u_c_vv; + s_sasm.a[0].m[0][0] = l_u_e_ccc; + s_sasm.a[0].m[0][idx] = l_u_e_ccv; + s_sasm.a[0].m[idx][0] = l_u_e_cvc; + s_sasm.a[0].m[idx][idx] = l_u_e_cvv; + s_sasm.a[idx].m[0][0] = l_u_e_vcc; + s_sasm.a[idx].m[0][idx] = l_u_e_vcv; + s_sasm.a[idx].m[idx][0] = l_u_e_vvc; + s_sasm.a[idx].m[idx][idx] = l_u_e_vvv; +} + +@compute @workgroup_size(1) +fn main() { + access_m(); + access_sm(); + access_sasm(); +} diff --git a/naga/tests/out/hlsl/wgsl-hlsl_mat_cx2.hlsl b/naga/tests/out/hlsl/wgsl-hlsl_mat_cx2.hlsl new file mode 100644 index 0000000000..1a6e932a8d --- /dev/null +++ b/naga/tests/out/hlsl/wgsl-hlsl_mat_cx2.hlsl @@ -0,0 +1,372 @@ +typedef struct { float2 _0; float2 _1; } __mat2x2; +float2 __get_col_of_mat2x2(__mat2x2 mat, uint idx) { + switch(idx) { + case 0: { return mat._0; } + case 1: { return mat._1; } + default: { return (float2)0; } + } +} +void __set_col_of_mat2x2(__mat2x2 mat, uint idx, float2 value) { + switch(idx) { + case 0: { mat._0 = value; break; } + case 1: { mat._1 = value; break; } + } +} +void __set_el_of_mat2x2(__mat2x2 mat, uint idx, uint vec_idx, float value) { + switch(idx) { + case 0: { mat._0[vec_idx] = value; break; } + case 1: { mat._1[vec_idx] = value; break; } + } +} + +struct StructWithMat { + float2 m_0; float2 m_1; +}; + +struct StructWithArrayOfStructOfMat { + StructWithMat a[4]; +}; + +RWByteAddressBuffer s_m : register(u0); +cbuffer u_m : register(b1) { __mat2x2 u_m; } +RWByteAddressBuffer s_sm : register(u0, space1); +cbuffer u_sm : register(b1, space1) { StructWithMat u_sm; } +RWByteAddressBuffer s_sasm : register(u0, space2); +cbuffer u_sasm : register(b1, space2) { StructWithArrayOfStructOfMat u_sasm; } + +void access_m() +{ + int idx = int(1); + + int _e3 = idx; + idx = asint(asuint(_e3) - asuint(int(1))); + float2x2 l_s_m = float2x2(asfloat(s_m.Load2(0)), asfloat(s_m.Load2(8))); + float2 l_s_c_c = asfloat(s_m.Load2(0)); + int _e11 = idx; + float2 l_s_c_v = asfloat(s_m.Load2(_e11*8)); + float l_s_e_cc = asfloat(s_m.Load(0+0)); + int _e20 = idx; + float l_s_e_cv = asfloat(s_m.Load(_e20*4+0)); + int _e24 = idx; + float l_s_e_vc = asfloat(s_m.Load(0+_e24*8)); + int _e29 = idx; + int _e31 = idx; + float l_s_e_vv = asfloat(s_m.Load(_e31*4+_e29*8)); + float2x2 l_u_m = ((float2x2)u_m); + float2 l_u_c_c = ((float2)u_m._0); + int _e40 = idx; + float2 l_u_c_v = ((float2)__get_col_of_mat2x2(u_m, _e40)); + float l_u_e_cc = ((float)u_m._0.x); + int _e49 = idx; + float l_u_e_cv = ((float)u_m._0[_e49]); + int _e53 = idx; + float l_u_e_vc = ((float)__get_col_of_mat2x2(u_m, _e53).x); + int _e58 = idx; + int _e60 = idx; + float l_u_e_vv = ((float)__get_col_of_mat2x2(u_m, _e58)[_e60]); + { + float2x2 _value2 = l_u_m; + s_m.Store2(0, asuint(_value2[0])); + s_m.Store2(8, asuint(_value2[1])); + } + s_m.Store2(0, asuint(l_u_c_c)); + int _e67 = idx; + s_m.Store2(_e67*8, asuint(l_u_c_v)); + s_m.Store(0+0, asuint(l_u_e_cc)); + int _e74 = idx; + s_m.Store(_e74*4+0, asuint(l_u_e_cv)); + int _e77 = idx; + s_m.Store(0+_e77*8, asuint(l_u_e_vc)); + int _e81 = idx; + int _e83 = idx; + s_m.Store(_e83*4+_e81*8, asuint(l_u_e_vv)); + return; +} + +StructWithMat ConstructStructWithMat(float2x2 arg0) { + StructWithMat ret = (StructWithMat)0; + ret.m_0 = arg0[0]; + ret.m_1 = arg0[1]; + return ret; +} + +float2x2 GetMatmOnStructWithMat(StructWithMat obj) { + return float2x2(obj.m_0, obj.m_1); +} + +void SetMatmOnStructWithMat(StructWithMat obj, float2x2 mat) { + obj.m_0 = mat[0]; + obj.m_1 = mat[1]; +} + +void SetMatVecmOnStructWithMat(StructWithMat obj, float2 vec, uint mat_idx) { + switch(mat_idx) { + case 0: { obj.m_0 = vec; break; } + case 1: { obj.m_1 = vec; break; } + } +} + +void SetMatScalarmOnStructWithMat(StructWithMat obj, float scalar, uint mat_idx, uint vec_idx) { + switch(mat_idx) { + case 0: { obj.m_0[vec_idx] = scalar; break; } + case 1: { obj.m_1[vec_idx] = scalar; break; } + } +} + +void access_sm() +{ + int idx_1 = int(1); + + int _e3 = idx_1; + idx_1 = asint(asuint(_e3) - asuint(int(1))); + StructWithMat l_s_s = ConstructStructWithMat(float2x2(asfloat(s_sm.Load2(0+0)), asfloat(s_sm.Load2(0+8)))); + float2x2 l_s_m_1 = float2x2(asfloat(s_sm.Load2(0+0)), asfloat(s_sm.Load2(0+8))); + float2 l_s_c_c_1 = asfloat(s_sm.Load2(0+0)); + int _e16 = idx_1; + float2 l_s_c_v_1 = asfloat(s_sm.Load2(_e16*8+0)); + float l_s_e_cc_1 = asfloat(s_sm.Load(0+0+0)); + int _e27 = idx_1; + float l_s_e_cv_1 = asfloat(s_sm.Load(_e27*4+0+0)); + int _e32 = idx_1; + float l_s_e_vc_1 = asfloat(s_sm.Load(0+_e32*8+0)); + int _e38 = idx_1; + int _e40 = idx_1; + float l_s_e_vv_1 = asfloat(s_sm.Load(_e40*4+_e38*8+0)); + StructWithMat l_u_s = u_sm; + float2x2 l_u_m_1 = GetMatmOnStructWithMat(u_sm); + float2 l_u_c_c_1 = GetMatmOnStructWithMat(u_sm)[0]; + int _e54 = idx_1; + float2 l_u_c_v_1 = GetMatmOnStructWithMat(u_sm)[_e54]; + float l_u_e_cc_1 = GetMatmOnStructWithMat(u_sm)[0].x; + int _e65 = idx_1; + float l_u_e_cv_1 = GetMatmOnStructWithMat(u_sm)[0][_e65]; + int _e70 = idx_1; + float l_u_e_vc_1 = GetMatmOnStructWithMat(u_sm)[_e70].x; + int _e76 = idx_1; + int _e78 = idx_1; + float l_u_e_vv_1 = GetMatmOnStructWithMat(u_sm)[_e76][_e78]; + { + StructWithMat _value2 = l_u_s; + { + s_sm.Store2(0+0, asuint(_value2.m_0)); + s_sm.Store2(0+8, asuint(_value2.m_1)); + } + } + { + float2x2 _value2 = l_u_m_1; + s_sm.Store2(0+0, asuint(_value2[0])); + s_sm.Store2(0+8, asuint(_value2[1])); + } + s_sm.Store2(0+0, asuint(l_u_c_c_1)); + int _e89 = idx_1; + s_sm.Store2(_e89*8+0, asuint(l_u_c_v_1)); + s_sm.Store(0+0+0, asuint(l_u_e_cc_1)); + int _e98 = idx_1; + s_sm.Store(_e98*4+0+0, asuint(l_u_e_cv_1)); + int _e102 = idx_1; + s_sm.Store(0+_e102*8+0, asuint(l_u_e_vc_1)); + int _e107 = idx_1; + int _e109 = idx_1; + s_sm.Store(_e109*4+_e107*8+0, asuint(l_u_e_vv_1)); + return; +} + +typedef StructWithMat ret_Constructarray4_StructWithMat_[4]; +ret_Constructarray4_StructWithMat_ Constructarray4_StructWithMat_(StructWithMat arg0, StructWithMat arg1, StructWithMat arg2, StructWithMat arg3) { + StructWithMat ret[4] = { arg0, arg1, arg2, arg3 }; + return ret; +} + +StructWithArrayOfStructOfMat ConstructStructWithArrayOfStructOfMat(StructWithMat arg0[4]) { + StructWithArrayOfStructOfMat ret = (StructWithArrayOfStructOfMat)0; + ret.a = arg0; + return ret; +} + +void access_sasm() +{ + int idx_2 = int(1); + + int _e3 = idx_2; + idx_2 = asint(asuint(_e3) - asuint(int(1))); + StructWithArrayOfStructOfMat l_s_s_1 = ConstructStructWithArrayOfStructOfMat(Constructarray4_StructWithMat_(ConstructStructWithMat(float2x2(asfloat(s_sasm.Load2(0+0+0+0)), asfloat(s_sasm.Load2(0+0+0+8)))), ConstructStructWithMat(float2x2(asfloat(s_sasm.Load2(0+16+0+0)), asfloat(s_sasm.Load2(0+16+0+8)))), ConstructStructWithMat(float2x2(asfloat(s_sasm.Load2(0+32+0+0)), asfloat(s_sasm.Load2(0+32+0+8)))), ConstructStructWithMat(float2x2(asfloat(s_sasm.Load2(0+48+0+0)), asfloat(s_sasm.Load2(0+48+0+8)))))); + StructWithMat l_s_a[4] = Constructarray4_StructWithMat_(ConstructStructWithMat(float2x2(asfloat(s_sasm.Load2(0+0+0+0)), asfloat(s_sasm.Load2(0+0+0+8)))), ConstructStructWithMat(float2x2(asfloat(s_sasm.Load2(0+16+0+0)), asfloat(s_sasm.Load2(0+16+0+8)))), ConstructStructWithMat(float2x2(asfloat(s_sasm.Load2(0+32+0+0)), asfloat(s_sasm.Load2(0+32+0+8)))), ConstructStructWithMat(float2x2(asfloat(s_sasm.Load2(0+48+0+0)), asfloat(s_sasm.Load2(0+48+0+8))))); + float2x2 l_s_m_c = float2x2(asfloat(s_sasm.Load2(0+0+0+0)), asfloat(s_sasm.Load2(0+0+0+8))); + int _e17 = idx_2; + float2x2 l_s_m_v = float2x2(asfloat(s_sasm.Load2(0+_e17*16+0+0)), asfloat(s_sasm.Load2(0+_e17*16+0+8))); + float2 l_s_c_cc = asfloat(s_sasm.Load2(0+0+0+0)); + int _e31 = idx_2; + float2 l_s_c_cv = asfloat(s_sasm.Load2(_e31*8+0+0+0)); + int _e36 = idx_2; + float2 l_s_c_vc = asfloat(s_sasm.Load2(0+0+_e36*16+0)); + int _e43 = idx_2; + int _e46 = idx_2; + float2 l_s_c_vv = asfloat(s_sasm.Load2(_e46*8+0+_e43*16+0)); + float l_s_e_ccc = asfloat(s_sasm.Load(0+0+0+0+0)); + int _e61 = idx_2; + float l_s_e_ccv = asfloat(s_sasm.Load(_e61*4+0+0+0+0)); + int _e68 = idx_2; + float l_s_e_cvc = asfloat(s_sasm.Load(0+_e68*8+0+0+0)); + int _e76 = idx_2; + int _e78 = idx_2; + float l_s_e_cvv = asfloat(s_sasm.Load(_e78*4+_e76*8+0+0+0)); + int _e83 = idx_2; + float l_s_e_vcc = asfloat(s_sasm.Load(0+0+0+_e83*16+0)); + int _e91 = idx_2; + int _e95 = idx_2; + float l_s_e_vcv = asfloat(s_sasm.Load(_e95*4+0+0+_e91*16+0)); + int _e100 = idx_2; + int _e103 = idx_2; + float l_s_e_vvc = asfloat(s_sasm.Load(0+_e103*8+0+_e100*16+0)); + int _e109 = idx_2; + int _e112 = idx_2; + int _e114 = idx_2; + float l_s_e_vvv = asfloat(s_sasm.Load(_e114*4+_e112*8+0+_e109*16+0)); + StructWithArrayOfStructOfMat l_u_s_1 = u_sasm; + StructWithMat l_u_a[4] = u_sasm.a; + float2x2 l_u_m_c = GetMatmOnStructWithMat(u_sasm.a[0]); + int _e129 = idx_2; + float2x2 l_u_m_v = GetMatmOnStructWithMat(u_sasm.a[_e129]); + float2 l_u_c_cc = GetMatmOnStructWithMat(u_sasm.a[0])[0]; + int _e143 = idx_2; + float2 l_u_c_cv = GetMatmOnStructWithMat(u_sasm.a[0])[_e143]; + int _e148 = idx_2; + float2 l_u_c_vc = GetMatmOnStructWithMat(u_sasm.a[_e148])[0]; + int _e155 = idx_2; + int _e158 = idx_2; + float2 l_u_c_vv = GetMatmOnStructWithMat(u_sasm.a[_e155])[_e158]; + float l_u_e_ccc = GetMatmOnStructWithMat(u_sasm.a[0])[0].x; + int _e173 = idx_2; + float l_u_e_ccv = GetMatmOnStructWithMat(u_sasm.a[0])[0][_e173]; + int _e180 = idx_2; + float l_u_e_cvc = GetMatmOnStructWithMat(u_sasm.a[0])[_e180].x; + int _e188 = idx_2; + int _e190 = idx_2; + float l_u_e_cvv = GetMatmOnStructWithMat(u_sasm.a[0])[_e188][_e190]; + int _e195 = idx_2; + float l_u_e_vcc = GetMatmOnStructWithMat(u_sasm.a[_e195])[0].x; + int _e203 = idx_2; + int _e207 = idx_2; + float l_u_e_vcv = GetMatmOnStructWithMat(u_sasm.a[_e203])[0][_e207]; + int _e212 = idx_2; + int _e215 = idx_2; + float l_u_e_vvc = GetMatmOnStructWithMat(u_sasm.a[_e212])[_e215].x; + int _e221 = idx_2; + int _e224 = idx_2; + int _e226 = idx_2; + float l_u_e_vvv = GetMatmOnStructWithMat(u_sasm.a[_e221])[_e224][_e226]; + { + StructWithArrayOfStructOfMat _value2 = l_u_s_1; + { + StructWithMat _value3[4] = _value2.a; + { + StructWithMat _value4 = _value3[0]; + { + s_sasm.Store2(0+0+0+0, asuint(_value4.m_0)); + s_sasm.Store2(0+0+0+8, asuint(_value4.m_1)); + } + } + { + StructWithMat _value4 = _value3[1]; + { + s_sasm.Store2(0+16+0+0, asuint(_value4.m_0)); + s_sasm.Store2(0+16+0+8, asuint(_value4.m_1)); + } + } + { + StructWithMat _value4 = _value3[2]; + { + s_sasm.Store2(0+32+0+0, asuint(_value4.m_0)); + s_sasm.Store2(0+32+0+8, asuint(_value4.m_1)); + } + } + { + StructWithMat _value4 = _value3[3]; + { + s_sasm.Store2(0+48+0+0, asuint(_value4.m_0)); + s_sasm.Store2(0+48+0+8, asuint(_value4.m_1)); + } + } + } + } + { + StructWithMat _value2[4] = l_u_a; + { + StructWithMat _value3 = _value2[0]; + { + s_sasm.Store2(0+0+0+0, asuint(_value3.m_0)); + s_sasm.Store2(0+0+0+8, asuint(_value3.m_1)); + } + } + { + StructWithMat _value3 = _value2[1]; + { + s_sasm.Store2(0+16+0+0, asuint(_value3.m_0)); + s_sasm.Store2(0+16+0+8, asuint(_value3.m_1)); + } + } + { + StructWithMat _value3 = _value2[2]; + { + s_sasm.Store2(0+32+0+0, asuint(_value3.m_0)); + s_sasm.Store2(0+32+0+8, asuint(_value3.m_1)); + } + } + { + StructWithMat _value3 = _value2[3]; + { + s_sasm.Store2(0+48+0+0, asuint(_value3.m_0)); + s_sasm.Store2(0+48+0+8, asuint(_value3.m_1)); + } + } + } + { + float2x2 _value2 = l_u_m_c; + s_sasm.Store2(0+0+0+0, asuint(_value2[0])); + s_sasm.Store2(0+0+0+8, asuint(_value2[1])); + } + int _e238 = idx_2; + { + float2x2 _value2 = l_u_m_v; + s_sasm.Store2(0+_e238*16+0+0, asuint(_value2[0])); + s_sasm.Store2(0+_e238*16+0+8, asuint(_value2[1])); + } + s_sasm.Store2(0+0+0+0, asuint(l_u_c_cc)); + int _e250 = idx_2; + s_sasm.Store2(_e250*8+0+0+0, asuint(l_u_c_cv)); + int _e254 = idx_2; + s_sasm.Store2(0+0+_e254*16+0, asuint(l_u_c_vc)); + int _e260 = idx_2; + int _e263 = idx_2; + s_sasm.Store2(_e263*8+0+_e260*16+0, asuint(l_u_c_vv)); + s_sasm.Store(0+0+0+0+0, asuint(l_u_e_ccc)); + int _e276 = idx_2; + s_sasm.Store(_e276*4+0+0+0+0, asuint(l_u_e_ccv)); + int _e282 = idx_2; + s_sasm.Store(0+_e282*8+0+0+0, asuint(l_u_e_cvc)); + int _e289 = idx_2; + int _e291 = idx_2; + s_sasm.Store(_e291*4+_e289*8+0+0+0, asuint(l_u_e_cvv)); + int _e295 = idx_2; + s_sasm.Store(0+0+0+_e295*16+0, asuint(l_u_e_vcc)); + int _e302 = idx_2; + int _e306 = idx_2; + s_sasm.Store(_e306*4+0+0+_e302*16+0, asuint(l_u_e_vcv)); + int _e310 = idx_2; + int _e313 = idx_2; + s_sasm.Store(0+_e313*8+0+_e310*16+0, asuint(l_u_e_vvc)); + int _e318 = idx_2; + int _e321 = idx_2; + int _e323 = idx_2; + s_sasm.Store(_e323*4+_e321*8+0+_e318*16+0, asuint(l_u_e_vvv)); + return; +} + +[numthreads(1, 1, 1)] +void main() +{ + access_m(); + access_sm(); + access_sasm(); + return; +} diff --git a/naga/tests/out/hlsl/wgsl-hlsl_mat_cx2.ron b/naga/tests/out/hlsl/wgsl-hlsl_mat_cx2.ron new file mode 100644 index 0000000000..a07b03300b --- /dev/null +++ b/naga/tests/out/hlsl/wgsl-hlsl_mat_cx2.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_5_1", + ), + ], +) diff --git a/naga/tests/out/hlsl/wgsl-hlsl_mat_cx3.hlsl b/naga/tests/out/hlsl/wgsl-hlsl_mat_cx3.hlsl new file mode 100644 index 0000000000..f90cdff6e1 --- /dev/null +++ b/naga/tests/out/hlsl/wgsl-hlsl_mat_cx3.hlsl @@ -0,0 +1,350 @@ +struct StructWithMat { + row_major float3x3 m; + int _end_pad_0; +}; + +struct StructWithArrayOfStructOfMat { + StructWithMat a[4]; +}; + +RWByteAddressBuffer s_m : register(u0); +cbuffer u_m : register(b1) { row_major float3x3 u_m; } +RWByteAddressBuffer s_sm : register(u0, space1); +cbuffer u_sm : register(b1, space1) { StructWithMat u_sm; } +RWByteAddressBuffer s_sasm : register(u0, space2); +cbuffer u_sasm : register(b1, space2) { StructWithArrayOfStructOfMat u_sasm; } + +void access_m() +{ + int idx = int(1); + + int _e3 = idx; + idx = asint(asuint(_e3) - asuint(int(1))); + float3x3 l_s_m = float3x3(asfloat(s_m.Load3(0)), asfloat(s_m.Load3(16)), asfloat(s_m.Load3(32))); + float3 l_s_c_c = asfloat(s_m.Load3(0)); + int _e11 = idx; + float3 l_s_c_v = asfloat(s_m.Load3(_e11*16)); + float l_s_e_cc = asfloat(s_m.Load(0+0)); + int _e20 = idx; + float l_s_e_cv = asfloat(s_m.Load(_e20*4+0)); + int _e24 = idx; + float l_s_e_vc = asfloat(s_m.Load(0+_e24*16)); + int _e29 = idx; + int _e31 = idx; + float l_s_e_vv = asfloat(s_m.Load(_e31*4+_e29*16)); + float3x3 l_u_m = u_m; + float3 l_u_c_c = u_m[0]; + int _e40 = idx; + float3 l_u_c_v = u_m[_e40]; + float l_u_e_cc = u_m[0].x; + int _e49 = idx; + float l_u_e_cv = u_m[0][_e49]; + int _e53 = idx; + float l_u_e_vc = u_m[_e53].x; + int _e58 = idx; + int _e60 = idx; + float l_u_e_vv = u_m[_e58][_e60]; + { + float3x3 _value2 = l_u_m; + s_m.Store3(0, asuint(_value2[0])); + s_m.Store3(16, asuint(_value2[1])); + s_m.Store3(32, asuint(_value2[2])); + } + s_m.Store3(0, asuint(l_u_c_c)); + int _e67 = idx; + s_m.Store3(_e67*16, asuint(l_u_c_v)); + s_m.Store(0+0, asuint(l_u_e_cc)); + int _e74 = idx; + s_m.Store(_e74*4+0, asuint(l_u_e_cv)); + int _e77 = idx; + s_m.Store(0+_e77*16, asuint(l_u_e_vc)); + int _e81 = idx; + int _e83 = idx; + s_m.Store(_e83*4+_e81*16, asuint(l_u_e_vv)); + return; +} + +StructWithMat ConstructStructWithMat(float3x3 arg0) { + StructWithMat ret = (StructWithMat)0; + ret.m = arg0; + return ret; +} + +void access_sm() +{ + int idx_1 = int(1); + + int _e3 = idx_1; + idx_1 = asint(asuint(_e3) - asuint(int(1))); + StructWithMat l_s_s = ConstructStructWithMat(float3x3(asfloat(s_sm.Load3(0+0)), asfloat(s_sm.Load3(0+16)), asfloat(s_sm.Load3(0+32)))); + float3x3 l_s_m_1 = float3x3(asfloat(s_sm.Load3(0+0)), asfloat(s_sm.Load3(0+16)), asfloat(s_sm.Load3(0+32))); + float3 l_s_c_c_1 = asfloat(s_sm.Load3(0+0)); + int _e16 = idx_1; + float3 l_s_c_v_1 = asfloat(s_sm.Load3(_e16*16+0)); + float l_s_e_cc_1 = asfloat(s_sm.Load(0+0+0)); + int _e27 = idx_1; + float l_s_e_cv_1 = asfloat(s_sm.Load(_e27*4+0+0)); + int _e32 = idx_1; + float l_s_e_vc_1 = asfloat(s_sm.Load(0+_e32*16+0)); + int _e38 = idx_1; + int _e40 = idx_1; + float l_s_e_vv_1 = asfloat(s_sm.Load(_e40*4+_e38*16+0)); + StructWithMat l_u_s = u_sm; + float3x3 l_u_m_1 = u_sm.m; + float3 l_u_c_c_1 = u_sm.m[0]; + int _e54 = idx_1; + float3 l_u_c_v_1 = u_sm.m[_e54]; + float l_u_e_cc_1 = u_sm.m[0].x; + int _e65 = idx_1; + float l_u_e_cv_1 = u_sm.m[0][_e65]; + int _e70 = idx_1; + float l_u_e_vc_1 = u_sm.m[_e70].x; + int _e76 = idx_1; + int _e78 = idx_1; + float l_u_e_vv_1 = u_sm.m[_e76][_e78]; + { + StructWithMat _value2 = l_u_s; + { + float3x3 _value3 = _value2.m; + s_sm.Store3(0+0, asuint(_value3[0])); + s_sm.Store3(0+16, asuint(_value3[1])); + s_sm.Store3(0+32, asuint(_value3[2])); + } + } + { + float3x3 _value2 = l_u_m_1; + s_sm.Store3(0+0, asuint(_value2[0])); + s_sm.Store3(0+16, asuint(_value2[1])); + s_sm.Store3(0+32, asuint(_value2[2])); + } + s_sm.Store3(0+0, asuint(l_u_c_c_1)); + int _e89 = idx_1; + s_sm.Store3(_e89*16+0, asuint(l_u_c_v_1)); + s_sm.Store(0+0+0, asuint(l_u_e_cc_1)); + int _e98 = idx_1; + s_sm.Store(_e98*4+0+0, asuint(l_u_e_cv_1)); + int _e102 = idx_1; + s_sm.Store(0+_e102*16+0, asuint(l_u_e_vc_1)); + int _e107 = idx_1; + int _e109 = idx_1; + s_sm.Store(_e109*4+_e107*16+0, asuint(l_u_e_vv_1)); + return; +} + +typedef StructWithMat ret_Constructarray4_StructWithMat_[4]; +ret_Constructarray4_StructWithMat_ Constructarray4_StructWithMat_(StructWithMat arg0, StructWithMat arg1, StructWithMat arg2, StructWithMat arg3) { + StructWithMat ret[4] = { arg0, arg1, arg2, arg3 }; + return ret; +} + +StructWithArrayOfStructOfMat ConstructStructWithArrayOfStructOfMat(StructWithMat arg0[4]) { + StructWithArrayOfStructOfMat ret = (StructWithArrayOfStructOfMat)0; + ret.a = arg0; + return ret; +} + +void access_sasm() +{ + int idx_2 = int(1); + + int _e3 = idx_2; + idx_2 = asint(asuint(_e3) - asuint(int(1))); + StructWithArrayOfStructOfMat l_s_s_1 = ConstructStructWithArrayOfStructOfMat(Constructarray4_StructWithMat_(ConstructStructWithMat(float3x3(asfloat(s_sasm.Load3(0+0+0+0)), asfloat(s_sasm.Load3(0+0+0+16)), asfloat(s_sasm.Load3(0+0+0+32)))), ConstructStructWithMat(float3x3(asfloat(s_sasm.Load3(0+48+0+0)), asfloat(s_sasm.Load3(0+48+0+16)), asfloat(s_sasm.Load3(0+48+0+32)))), ConstructStructWithMat(float3x3(asfloat(s_sasm.Load3(0+96+0+0)), asfloat(s_sasm.Load3(0+96+0+16)), asfloat(s_sasm.Load3(0+96+0+32)))), ConstructStructWithMat(float3x3(asfloat(s_sasm.Load3(0+144+0+0)), asfloat(s_sasm.Load3(0+144+0+16)), asfloat(s_sasm.Load3(0+144+0+32)))))); + StructWithMat l_s_a[4] = Constructarray4_StructWithMat_(ConstructStructWithMat(float3x3(asfloat(s_sasm.Load3(0+0+0+0)), asfloat(s_sasm.Load3(0+0+0+16)), asfloat(s_sasm.Load3(0+0+0+32)))), ConstructStructWithMat(float3x3(asfloat(s_sasm.Load3(0+48+0+0)), asfloat(s_sasm.Load3(0+48+0+16)), asfloat(s_sasm.Load3(0+48+0+32)))), ConstructStructWithMat(float3x3(asfloat(s_sasm.Load3(0+96+0+0)), asfloat(s_sasm.Load3(0+96+0+16)), asfloat(s_sasm.Load3(0+96+0+32)))), ConstructStructWithMat(float3x3(asfloat(s_sasm.Load3(0+144+0+0)), asfloat(s_sasm.Load3(0+144+0+16)), asfloat(s_sasm.Load3(0+144+0+32))))); + float3x3 l_s_m_c = float3x3(asfloat(s_sasm.Load3(0+0+0+0)), asfloat(s_sasm.Load3(0+0+0+16)), asfloat(s_sasm.Load3(0+0+0+32))); + int _e17 = idx_2; + float3x3 l_s_m_v = float3x3(asfloat(s_sasm.Load3(0+_e17*48+0+0)), asfloat(s_sasm.Load3(0+_e17*48+0+16)), asfloat(s_sasm.Load3(0+_e17*48+0+32))); + float3 l_s_c_cc = asfloat(s_sasm.Load3(0+0+0+0)); + int _e31 = idx_2; + float3 l_s_c_cv = asfloat(s_sasm.Load3(_e31*16+0+0+0)); + int _e36 = idx_2; + float3 l_s_c_vc = asfloat(s_sasm.Load3(0+0+_e36*48+0)); + int _e43 = idx_2; + int _e46 = idx_2; + float3 l_s_c_vv = asfloat(s_sasm.Load3(_e46*16+0+_e43*48+0)); + float l_s_e_ccc = asfloat(s_sasm.Load(0+0+0+0+0)); + int _e61 = idx_2; + float l_s_e_ccv = asfloat(s_sasm.Load(_e61*4+0+0+0+0)); + int _e68 = idx_2; + float l_s_e_cvc = asfloat(s_sasm.Load(0+_e68*16+0+0+0)); + int _e76 = idx_2; + int _e78 = idx_2; + float l_s_e_cvv = asfloat(s_sasm.Load(_e78*4+_e76*16+0+0+0)); + int _e83 = idx_2; + float l_s_e_vcc = asfloat(s_sasm.Load(0+0+0+_e83*48+0)); + int _e91 = idx_2; + int _e95 = idx_2; + float l_s_e_vcv = asfloat(s_sasm.Load(_e95*4+0+0+_e91*48+0)); + int _e100 = idx_2; + int _e103 = idx_2; + float l_s_e_vvc = asfloat(s_sasm.Load(0+_e103*16+0+_e100*48+0)); + int _e109 = idx_2; + int _e112 = idx_2; + int _e114 = idx_2; + float l_s_e_vvv = asfloat(s_sasm.Load(_e114*4+_e112*16+0+_e109*48+0)); + StructWithArrayOfStructOfMat l_u_s_1 = u_sasm; + StructWithMat l_u_a[4] = u_sasm.a; + float3x3 l_u_m_c = u_sasm.a[0].m; + int _e129 = idx_2; + float3x3 l_u_m_v = u_sasm.a[_e129].m; + float3 l_u_c_cc = u_sasm.a[0].m[0]; + int _e143 = idx_2; + float3 l_u_c_cv = u_sasm.a[0].m[_e143]; + int _e148 = idx_2; + float3 l_u_c_vc = u_sasm.a[_e148].m[0]; + int _e155 = idx_2; + int _e158 = idx_2; + float3 l_u_c_vv = u_sasm.a[_e155].m[_e158]; + float l_u_e_ccc = u_sasm.a[0].m[0].x; + int _e173 = idx_2; + float l_u_e_ccv = u_sasm.a[0].m[0][_e173]; + int _e180 = idx_2; + float l_u_e_cvc = u_sasm.a[0].m[_e180].x; + int _e188 = idx_2; + int _e190 = idx_2; + float l_u_e_cvv = u_sasm.a[0].m[_e188][_e190]; + int _e195 = idx_2; + float l_u_e_vcc = u_sasm.a[_e195].m[0].x; + int _e203 = idx_2; + int _e207 = idx_2; + float l_u_e_vcv = u_sasm.a[_e203].m[0][_e207]; + int _e212 = idx_2; + int _e215 = idx_2; + float l_u_e_vvc = u_sasm.a[_e212].m[_e215].x; + int _e221 = idx_2; + int _e224 = idx_2; + int _e226 = idx_2; + float l_u_e_vvv = u_sasm.a[_e221].m[_e224][_e226]; + { + StructWithArrayOfStructOfMat _value2 = l_u_s_1; + { + StructWithMat _value3[4] = _value2.a; + { + StructWithMat _value4 = _value3[0]; + { + float3x3 _value5 = _value4.m; + s_sasm.Store3(0+0+0+0, asuint(_value5[0])); + s_sasm.Store3(0+0+0+16, asuint(_value5[1])); + s_sasm.Store3(0+0+0+32, asuint(_value5[2])); + } + } + { + StructWithMat _value4 = _value3[1]; + { + float3x3 _value5 = _value4.m; + s_sasm.Store3(0+48+0+0, asuint(_value5[0])); + s_sasm.Store3(0+48+0+16, asuint(_value5[1])); + s_sasm.Store3(0+48+0+32, asuint(_value5[2])); + } + } + { + StructWithMat _value4 = _value3[2]; + { + float3x3 _value5 = _value4.m; + s_sasm.Store3(0+96+0+0, asuint(_value5[0])); + s_sasm.Store3(0+96+0+16, asuint(_value5[1])); + s_sasm.Store3(0+96+0+32, asuint(_value5[2])); + } + } + { + StructWithMat _value4 = _value3[3]; + { + float3x3 _value5 = _value4.m; + s_sasm.Store3(0+144+0+0, asuint(_value5[0])); + s_sasm.Store3(0+144+0+16, asuint(_value5[1])); + s_sasm.Store3(0+144+0+32, asuint(_value5[2])); + } + } + } + } + { + StructWithMat _value2[4] = l_u_a; + { + StructWithMat _value3 = _value2[0]; + { + float3x3 _value4 = _value3.m; + s_sasm.Store3(0+0+0+0, asuint(_value4[0])); + s_sasm.Store3(0+0+0+16, asuint(_value4[1])); + s_sasm.Store3(0+0+0+32, asuint(_value4[2])); + } + } + { + StructWithMat _value3 = _value2[1]; + { + float3x3 _value4 = _value3.m; + s_sasm.Store3(0+48+0+0, asuint(_value4[0])); + s_sasm.Store3(0+48+0+16, asuint(_value4[1])); + s_sasm.Store3(0+48+0+32, asuint(_value4[2])); + } + } + { + StructWithMat _value3 = _value2[2]; + { + float3x3 _value4 = _value3.m; + s_sasm.Store3(0+96+0+0, asuint(_value4[0])); + s_sasm.Store3(0+96+0+16, asuint(_value4[1])); + s_sasm.Store3(0+96+0+32, asuint(_value4[2])); + } + } + { + StructWithMat _value3 = _value2[3]; + { + float3x3 _value4 = _value3.m; + s_sasm.Store3(0+144+0+0, asuint(_value4[0])); + s_sasm.Store3(0+144+0+16, asuint(_value4[1])); + s_sasm.Store3(0+144+0+32, asuint(_value4[2])); + } + } + } + { + float3x3 _value2 = l_u_m_c; + s_sasm.Store3(0+0+0+0, asuint(_value2[0])); + s_sasm.Store3(0+0+0+16, asuint(_value2[1])); + s_sasm.Store3(0+0+0+32, asuint(_value2[2])); + } + int _e238 = idx_2; + { + float3x3 _value2 = l_u_m_v; + s_sasm.Store3(0+_e238*48+0+0, asuint(_value2[0])); + s_sasm.Store3(0+_e238*48+0+16, asuint(_value2[1])); + s_sasm.Store3(0+_e238*48+0+32, asuint(_value2[2])); + } + s_sasm.Store3(0+0+0+0, asuint(l_u_c_cc)); + int _e250 = idx_2; + s_sasm.Store3(_e250*16+0+0+0, asuint(l_u_c_cv)); + int _e254 = idx_2; + s_sasm.Store3(0+0+_e254*48+0, asuint(l_u_c_vc)); + int _e260 = idx_2; + int _e263 = idx_2; + s_sasm.Store3(_e263*16+0+_e260*48+0, asuint(l_u_c_vv)); + s_sasm.Store(0+0+0+0+0, asuint(l_u_e_ccc)); + int _e276 = idx_2; + s_sasm.Store(_e276*4+0+0+0+0, asuint(l_u_e_ccv)); + int _e282 = idx_2; + s_sasm.Store(0+_e282*16+0+0+0, asuint(l_u_e_cvc)); + int _e289 = idx_2; + int _e291 = idx_2; + s_sasm.Store(_e291*4+_e289*16+0+0+0, asuint(l_u_e_cvv)); + int _e295 = idx_2; + s_sasm.Store(0+0+0+_e295*48+0, asuint(l_u_e_vcc)); + int _e302 = idx_2; + int _e306 = idx_2; + s_sasm.Store(_e306*4+0+0+_e302*48+0, asuint(l_u_e_vcv)); + int _e310 = idx_2; + int _e313 = idx_2; + s_sasm.Store(0+_e313*16+0+_e310*48+0, asuint(l_u_e_vvc)); + int _e318 = idx_2; + int _e321 = idx_2; + int _e323 = idx_2; + s_sasm.Store(_e323*4+_e321*16+0+_e318*48+0, asuint(l_u_e_vvv)); + return; +} + +[numthreads(1, 1, 1)] +void main() +{ + access_m(); + access_sm(); + access_sasm(); + return; +} diff --git a/naga/tests/out/hlsl/wgsl-hlsl_mat_cx3.ron b/naga/tests/out/hlsl/wgsl-hlsl_mat_cx3.ron new file mode 100644 index 0000000000..a07b03300b --- /dev/null +++ b/naga/tests/out/hlsl/wgsl-hlsl_mat_cx3.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_5_1", + ), + ], +) diff --git a/naga/tests/out/spv/wgsl-access.spvasm b/naga/tests/out/spv/wgsl-access.spvasm index ac979146f5..17afd11c5f 100644 --- a/naga/tests/out/spv/wgsl-access.spvasm +++ b/naga/tests/out/spv/wgsl-access.spvasm @@ -51,7 +51,10 @@ var baz: Baz; var qux: vec2; fn test_matrix_within_struct_accesses() { -\tvar idx = 1; + // Test HLSL accesses to Cx2 matrices. There are additional tests + // in `hlsl_mat_cx2.wgsl`. + + var idx = 1; idx--; @@ -567,91 +570,91 @@ OpDecorate %365 Location 0 %69 = OpAccessChain %68 %56 %48 OpBranch %98 %98 = OpLabel -OpLine %3 40 5 +OpLine %3 43 5 %99 = OpLoad %6 %94 %100 = OpISub %6 %99 %70 -OpLine %3 40 5 +OpLine %3 43 5 OpStore %94 %100 -OpLine %3 43 14 +OpLine %3 46 14 %102 = OpAccessChain %101 %69 %48 %103 = OpLoad %22 %102 -OpLine %3 44 14 -OpLine %3 44 14 +OpLine %3 47 14 +OpLine %3 47 14 %105 = OpAccessChain %104 %69 %48 %48 %106 = OpLoad %13 %105 -OpLine %3 45 14 +OpLine %3 48 14 %107 = OpLoad %6 %94 %108 = OpAccessChain %104 %69 %48 %107 %109 = OpLoad %13 %108 -OpLine %3 46 14 -OpLine %3 46 14 -OpLine %3 46 14 +OpLine %3 49 14 +OpLine %3 49 14 +OpLine %3 49 14 %111 = OpAccessChain %110 %69 %48 %48 %44 %112 = OpLoad %9 %111 -OpLine %3 47 14 -OpLine %3 47 14 +OpLine %3 50 14 +OpLine %3 50 14 %113 = OpLoad %6 %94 %114 = OpAccessChain %110 %69 %48 %48 %113 %115 = OpLoad %9 %114 -OpLine %3 48 14 +OpLine %3 51 14 %116 = OpLoad %6 %94 -OpLine %3 48 14 +OpLine %3 51 14 %117 = OpAccessChain %110 %69 %48 %116 %44 %118 = OpLoad %9 %117 -OpLine %3 49 14 +OpLine %3 52 14 %119 = OpLoad %6 %94 %120 = OpLoad %6 %94 %121 = OpAccessChain %110 %69 %48 %119 %120 %122 = OpLoad %9 %121 -OpLine %3 51 29 -OpLine %3 51 45 -OpLine %3 51 13 -OpLine %3 53 5 +OpLine %3 54 29 +OpLine %3 54 45 +OpLine %3 54 13 +OpLine %3 56 5 %123 = OpLoad %6 %94 %124 = OpIAdd %6 %123 %70 -OpLine %3 53 5 -OpStore %94 %124 -OpLine %3 56 5 -OpLine %3 56 23 -OpLine %3 56 39 -OpLine %3 56 11 OpLine %3 56 5 +OpStore %94 %124 +OpLine %3 59 5 +OpLine %3 59 23 +OpLine %3 59 39 +OpLine %3 59 11 +OpLine %3 59 5 %126 = OpAccessChain %125 %96 %48 OpStore %126 %85 -OpLine %3 57 5 -OpLine %3 57 5 -OpLine %3 57 14 -OpLine %3 57 5 +OpLine %3 60 5 +OpLine %3 60 5 +OpLine %3 60 14 +OpLine %3 60 5 %128 = OpAccessChain %127 %96 %48 %48 OpStore %128 %87 -OpLine %3 58 5 +OpLine %3 61 5 %129 = OpLoad %6 %94 -OpLine %3 58 16 -OpLine %3 58 5 +OpLine %3 61 16 +OpLine %3 61 5 %130 = OpAccessChain %127 %96 %48 %129 OpStore %130 %89 -OpLine %3 59 5 -OpLine %3 59 5 -OpLine %3 59 5 -OpLine %3 59 5 +OpLine %3 62 5 +OpLine %3 62 5 +OpLine %3 62 5 +OpLine %3 62 5 %131 = OpAccessChain %28 %96 %48 %48 %44 OpStore %131 %90 -OpLine %3 60 5 -OpLine %3 60 5 +OpLine %3 63 5 +OpLine %3 63 5 %132 = OpLoad %6 %94 -OpLine %3 60 5 +OpLine %3 63 5 %133 = OpAccessChain %28 %96 %48 %48 %132 OpStore %133 %91 -OpLine %3 61 5 +OpLine %3 64 5 %134 = OpLoad %6 %94 -OpLine %3 61 5 -OpLine %3 61 5 +OpLine %3 64 5 +OpLine %3 64 5 %135 = OpAccessChain %28 %96 %48 %134 %44 OpStore %135 %92 -OpLine %3 62 5 +OpLine %3 65 5 %136 = OpLoad %6 %94 %137 = OpLoad %6 %94 -OpLine %3 62 5 +OpLine %3 65 5 %138 = OpAccessChain %28 %96 %48 %136 %137 OpStore %138 %93 OpReturn @@ -663,111 +666,111 @@ OpFunctionEnd %142 = OpAccessChain %141 %62 %48 OpBranch %153 %153 = OpLabel -OpLine %3 75 5 +OpLine %3 78 5 %154 = OpLoad %6 %150 %155 = OpISub %6 %154 %70 -OpLine %3 75 5 +OpLine %3 78 5 OpStore %150 %155 -OpLine %3 78 14 +OpLine %3 81 14 %157 = OpAccessChain %156 %142 %48 %158 = OpLoad %26 %157 -OpLine %3 79 14 -OpLine %3 79 14 +OpLine %3 82 14 +OpLine %3 82 14 %160 = OpAccessChain %159 %142 %48 %48 %161 = OpLoad %25 %160 -OpLine %3 80 14 -OpLine %3 80 14 -OpLine %3 80 14 +OpLine %3 83 14 +OpLine %3 83 14 +OpLine %3 83 14 %162 = OpAccessChain %104 %142 %48 %48 %48 %163 = OpLoad %13 %162 -OpLine %3 81 14 -OpLine %3 81 14 +OpLine %3 84 14 +OpLine %3 84 14 %164 = OpLoad %6 %150 %165 = OpAccessChain %104 %142 %48 %48 %164 %166 = OpLoad %13 %165 -OpLine %3 82 14 -OpLine %3 82 14 -OpLine %3 82 14 -OpLine %3 82 14 +OpLine %3 85 14 +OpLine %3 85 14 +OpLine %3 85 14 +OpLine %3 85 14 %167 = OpAccessChain %110 %142 %48 %48 %48 %44 %168 = OpLoad %9 %167 -OpLine %3 83 14 -OpLine %3 83 14 -OpLine %3 83 14 +OpLine %3 86 14 +OpLine %3 86 14 +OpLine %3 86 14 %169 = OpLoad %6 %150 %170 = OpAccessChain %110 %142 %48 %48 %48 %169 %171 = OpLoad %9 %170 -OpLine %3 84 14 -OpLine %3 84 14 +OpLine %3 87 14 +OpLine %3 87 14 %172 = OpLoad %6 %150 -OpLine %3 84 14 +OpLine %3 87 14 %173 = OpAccessChain %110 %142 %48 %48 %172 %44 %174 = OpLoad %9 %173 -OpLine %3 85 14 -OpLine %3 85 14 +OpLine %3 88 14 +OpLine %3 88 14 %175 = OpLoad %6 %150 %176 = OpLoad %6 %150 %177 = OpAccessChain %110 %142 %48 %48 %175 %176 %178 = OpLoad %9 %177 -OpLine %3 87 13 -OpLine %3 89 5 +OpLine %3 90 13 +OpLine %3 92 5 %179 = OpLoad %6 %150 %180 = OpIAdd %6 %179 %70 -OpLine %3 89 5 -OpStore %150 %180 -OpLine %3 92 5 OpLine %3 92 5 +OpStore %150 %180 +OpLine %3 95 5 +OpLine %3 95 5 %182 = OpAccessChain %181 %151 %48 OpStore %182 %143 -OpLine %3 93 5 -OpLine %3 93 5 -OpLine %3 93 27 -OpLine %3 93 43 -OpLine %3 93 59 -OpLine %3 93 15 -OpLine %3 93 5 +OpLine %3 96 5 +OpLine %3 96 5 +OpLine %3 96 27 +OpLine %3 96 43 +OpLine %3 96 59 +OpLine %3 96 15 +OpLine %3 96 5 %184 = OpAccessChain %183 %151 %48 %48 OpStore %184 %149 -OpLine %3 94 5 -OpLine %3 94 5 -OpLine %3 94 5 -OpLine %3 94 18 -OpLine %3 94 5 +OpLine %3 97 5 +OpLine %3 97 5 +OpLine %3 97 5 +OpLine %3 97 18 +OpLine %3 97 5 %185 = OpAccessChain %127 %151 %48 %48 %48 OpStore %185 %87 -OpLine %3 95 5 -OpLine %3 95 5 +OpLine %3 98 5 +OpLine %3 98 5 %186 = OpLoad %6 %150 -OpLine %3 95 20 -OpLine %3 95 5 +OpLine %3 98 20 +OpLine %3 98 5 %187 = OpAccessChain %127 %151 %48 %48 %186 OpStore %187 %89 -OpLine %3 96 5 -OpLine %3 96 5 -OpLine %3 96 5 -OpLine %3 96 5 -OpLine %3 96 5 +OpLine %3 99 5 +OpLine %3 99 5 +OpLine %3 99 5 +OpLine %3 99 5 +OpLine %3 99 5 %188 = OpAccessChain %28 %151 %48 %48 %48 %44 OpStore %188 %90 -OpLine %3 97 5 -OpLine %3 97 5 -OpLine %3 97 5 +OpLine %3 100 5 +OpLine %3 100 5 +OpLine %3 100 5 %189 = OpLoad %6 %150 -OpLine %3 97 5 +OpLine %3 100 5 %190 = OpAccessChain %28 %151 %48 %48 %48 %189 OpStore %190 %91 -OpLine %3 98 5 -OpLine %3 98 5 +OpLine %3 101 5 +OpLine %3 101 5 %191 = OpLoad %6 %150 -OpLine %3 98 5 -OpLine %3 98 5 +OpLine %3 101 5 +OpLine %3 101 5 %192 = OpAccessChain %28 %151 %48 %48 %191 %44 OpStore %192 %92 -OpLine %3 99 5 -OpLine %3 99 5 +OpLine %3 102 5 +OpLine %3 102 5 %193 = OpLoad %6 %150 %194 = OpLoad %6 %150 -OpLine %3 99 5 +OpLine %3 102 5 %195 = OpAccessChain %28 %151 %48 %48 %193 %194 OpStore %195 %93 OpReturn @@ -777,7 +780,7 @@ OpFunctionEnd %196 = OpLabel OpBranch %200 %200 = OpLabel -OpLine %3 102 22 +OpLine %3 105 22 %201 = OpLoad %9 %197 OpReturnValue %201 OpFunctionEnd @@ -786,9 +789,9 @@ OpFunctionEnd %202 = OpLabel OpBranch %206 %206 = OpLabel -OpLine %3 107 12 +OpLine %3 110 12 %207 = OpCompositeExtract %29 %203 4 -OpLine %3 107 12 +OpLine %3 110 12 %208 = OpCompositeExtract %9 %207 9 OpReturnValue %208 OpFunctionEnd @@ -797,7 +800,7 @@ OpFunctionEnd %209 = OpLabel OpBranch %214 %214 = OpLabel -OpLine %3 155 5 +OpLine %3 158 5 OpStore %210 %213 OpReturn OpFunctionEnd @@ -806,11 +809,11 @@ OpFunctionEnd %215 = OpLabel OpBranch %222 %222 = OpLabel -OpLine %3 159 32 -OpLine %3 159 43 -OpLine %3 159 32 -OpLine %3 159 12 -OpLine %3 159 5 +OpLine %3 162 32 +OpLine %3 162 43 +OpLine %3 162 32 +OpLine %3 162 12 +OpLine %3 162 5 OpStore %216 %221 OpReturn OpFunctionEnd @@ -819,7 +822,7 @@ OpFunctionEnd %223 = OpLabel OpBranch %227 %227 = OpLabel -OpLine %3 176 10 +OpLine %3 179 10 %228 = OpAccessChain %34 %224 %48 %229 = OpLoad %4 %228 OpReturnValue %229 @@ -829,8 +832,8 @@ OpFunctionEnd %230 = OpLabel OpBranch %234 %234 = OpLabel -OpLine %3 180 3 -OpLine %3 180 3 +OpLine %3 183 3 +OpLine %3 183 3 %235 = OpAccessChain %34 %231 %48 OpStore %235 %17 OpReturn @@ -840,7 +843,7 @@ OpFunctionEnd %236 = OpLabel OpBranch %240 %240 = OpLabel -OpLine %3 184 10 +OpLine %3 187 10 %241 = OpAccessChain %34 %237 %44 %242 = OpLoad %4 %241 OpReturnValue %242 @@ -850,8 +853,8 @@ OpFunctionEnd %243 = OpLabel OpBranch %247 %247 = OpLabel -OpLine %3 188 3 -OpLine %3 188 3 +OpLine %3 191 3 +OpLine %3 191 3 %248 = OpAccessChain %34 %244 %44 OpStore %248 %17 OpReturn @@ -862,11 +865,11 @@ OpFunctionEnd %253 = OpVariable %254 Function %255 OpBranch %256 %256 = OpLabel -OpLine %3 203 13 +OpLine %3 206 13 %257 = OpCompositeConstruct %43 %250 -OpLine %3 203 5 +OpLine %3 206 5 OpStore %253 %257 -OpLine %3 205 12 +OpLine %3 208 12 %259 = OpAccessChain %258 %253 %48 %260 = OpLoad %42 %259 OpReturnValue %260 @@ -876,8 +879,8 @@ OpFunctionEnd %266 = OpVariable %267 Function %265 OpBranch %268 %268 = OpLabel -OpLine %3 211 16 -OpLine %3 213 12 +OpLine %3 214 16 +OpLine %3 216 12 %269 = OpAccessChain %95 %266 %48 %270 = OpLoad %6 %269 OpReturnValue %270 @@ -886,19 +889,19 @@ OpFunctionEnd %271 = OpLabel OpBranch %274 %274 = OpLabel -OpLine %3 223 17 +OpLine %3 226 17 %275 = OpCompositeExtract %46 %273 0 -OpLine %3 224 20 +OpLine %3 227 20 %276 = OpCompositeExtract %6 %275 0 -OpLine %3 226 9 +OpLine %3 229 9 %277 = OpCompositeExtract %4 %273 1 %278 = OpBitcast %4 %276 %279 = OpINotEqual %42 %277 %278 -OpLine %3 226 5 +OpLine %3 229 5 OpSelectionMerge %280 None OpBranchConditional %279 %280 %280 %280 = OpLabel -OpLine %3 230 12 +OpLine %3 233 12 %281 = OpCompositeExtract %46 %273 0 %282 = OpCompositeExtract %6 %281 0 OpReturnValue %282 @@ -910,27 +913,27 @@ OpFunctionEnd %290 = OpVariable %95 Function %291 OpBranch %292 %292 = OpLabel -OpLine %3 236 17 +OpLine %3 239 17 %293 = OpAccessChain %288 %285 %48 %294 = OpLoad %46 %293 -OpLine %3 236 5 +OpLine %3 239 5 OpStore %287 %294 -OpLine %3 237 20 +OpLine %3 240 20 %295 = OpAccessChain %95 %287 %48 %296 = OpLoad %6 %295 -OpLine %3 237 5 +OpLine %3 240 5 OpStore %290 %296 -OpLine %3 239 9 +OpLine %3 242 9 %297 = OpAccessChain %34 %285 %44 %298 = OpLoad %4 %297 %299 = OpLoad %6 %290 %300 = OpBitcast %4 %299 %301 = OpINotEqual %42 %298 %300 -OpLine %3 239 5 +OpLine %3 242 5 OpSelectionMerge %302 None OpBranchConditional %301 %302 %302 %302 = OpLabel -OpLine %3 243 12 +OpLine %3 246 12 %303 = OpAccessChain %95 %285 %48 %48 %304 = OpLoad %6 %303 OpReturnValue %304 @@ -947,56 +950,56 @@ OpBranch %326 %326 = OpLabel OpLine %3 1 1 %327 = OpLoad %9 %322 -OpLine %3 115 5 +OpLine %3 118 5 OpStore %322 %71 -OpLine %3 117 2 +OpLine %3 120 2 %328 = OpFunctionCall %2 %66 -OpLine %3 118 2 +OpLine %3 121 2 %329 = OpFunctionCall %2 %140 -OpLine %3 121 16 +OpLine %3 124 16 %331 = OpAccessChain %330 %54 %48 %332 = OpLoad %10 %331 -OpLine %3 122 12 +OpLine %3 125 12 %334 = OpAccessChain %333 %54 %40 %335 = OpLoad %19 %334 -OpLine %3 124 10 +OpLine %3 127 10 %338 = OpAccessChain %337 %54 %48 %317 %48 %339 = OpLoad %9 %338 -OpLine %3 125 10 -OpLine %3 125 19 +OpLine %3 128 10 +OpLine %3 128 19 %341 = OpArrayLength %4 %54 5 -OpLine %3 125 10 +OpLine %3 128 10 %342 = OpISub %4 %341 %15 %345 = OpAccessChain %344 %54 %31 %342 %48 %346 = OpLoad %6 %345 -OpLine %3 126 10 +OpLine %3 129 10 %347 = OpLoad %24 %314 -OpLine %3 129 53 -OpLine %3 129 53 -OpLine %3 130 18 +OpLine %3 132 53 +OpLine %3 132 53 +OpLine %3 133 18 %348 = OpFunctionCall %9 %198 %322 -OpLine %3 133 28 +OpLine %3 136 28 %351 = OpExtInst %9 %1 FClamp %339 %349 %350 %352 = OpConvertFToS %6 %351 -OpLine %3 133 11 +OpLine %3 136 11 %353 = OpCompositeConstruct %33 %346 %352 %318 %319 %320 -OpLine %3 133 2 +OpLine %3 136 2 OpStore %323 %353 -OpLine %3 134 2 +OpLine %3 137 2 %354 = OpIAdd %4 %308 %44 -OpLine %3 134 2 +OpLine %3 137 2 %355 = OpAccessChain %95 %323 %354 OpStore %355 %264 -OpLine %3 135 14 +OpLine %3 138 14 %356 = OpAccessChain %95 %323 %308 %357 = OpLoad %6 %356 -OpLine %3 137 2 +OpLine %3 140 2 %358 = OpFunctionCall %9 %204 %321 -OpLine %3 139 19 +OpLine %3 142 19 %360 = OpCompositeConstruct %359 %357 %357 %357 %357 %361 = OpConvertSToF %32 %360 %362 = OpMatrixTimesVector %11 %332 %361 -OpLine %3 139 9 +OpLine %3 142 9 %363 = OpCompositeConstruct %32 %362 %73 OpStore %309 %363 OpReturn @@ -1006,33 +1009,33 @@ OpFunctionEnd %367 = OpAccessChain %313 %59 %48 OpBranch %378 %378 = OpLabel -OpLine %3 145 2 -OpLine %3 145 2 -OpLine %3 145 2 +OpLine %3 148 2 +OpLine %3 148 2 +OpLine %3 148 2 %379 = OpAccessChain %337 %54 %48 %44 %15 OpStore %379 %71 -OpLine %3 146 2 -OpLine %3 146 28 -OpLine %3 146 44 -OpLine %3 146 60 -OpLine %3 146 16 -OpLine %3 146 2 +OpLine %3 149 2 +OpLine %3 149 28 +OpLine %3 149 44 +OpLine %3 149 60 +OpLine %3 149 16 +OpLine %3 149 2 %380 = OpAccessChain %330 %54 %48 OpStore %380 %372 -OpLine %3 147 2 -OpLine %3 147 32 -OpLine %3 147 12 -OpLine %3 147 2 +OpLine %3 150 2 +OpLine %3 150 32 +OpLine %3 150 12 +OpLine %3 150 2 %381 = OpAccessChain %333 %54 %40 OpStore %381 %375 -OpLine %3 148 2 -OpLine %3 148 2 -OpLine %3 148 2 +OpLine %3 151 2 +OpLine %3 151 2 +OpLine %3 151 2 %382 = OpAccessChain %344 %54 %31 %44 %48 OpStore %382 %70 -OpLine %3 149 2 +OpLine %3 152 2 OpStore %367 %376 -OpLine %3 151 9 +OpLine %3 154 9 OpStore %365 %377 OpReturn OpFunctionEnd @@ -1042,13 +1045,13 @@ OpFunctionEnd %390 = OpVariable %36 Function %388 OpBranch %391 %391 = OpLabel -OpLine %3 165 5 -%392 = OpFunctionCall %2 %211 %389 -OpLine %3 167 32 -OpLine %3 167 43 -OpLine %3 167 32 -OpLine %3 167 12 OpLine %3 168 5 +%392 = OpFunctionCall %2 %211 %389 +OpLine %3 170 32 +OpLine %3 170 43 +OpLine %3 170 32 +OpLine %3 170 12 +OpLine %3 171 5 %393 = OpFunctionCall %2 %217 %390 OpReturn OpFunctionEnd @@ -1058,13 +1061,13 @@ OpFunctionEnd %398 = OpVariable %41 Function %399 OpBranch %400 %400 = OpLabel -OpLine %3 194 4 +OpLine %3 197 4 %401 = OpFunctionCall %2 %232 %396 -OpLine %3 195 4 -%402 = OpFunctionCall %4 %225 %396 OpLine %3 198 4 +%402 = OpFunctionCall %4 %225 %396 +OpLine %3 201 4 %403 = OpFunctionCall %2 %245 %398 -OpLine %3 199 4 +OpLine %3 202 4 %404 = OpFunctionCall %4 %238 %398 OpReturn OpFunctionEnd \ No newline at end of file From da45585c175f358d68889eb2a32b864bb68771c2 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Tue, 25 Mar 2025 19:09:43 -0700 Subject: [PATCH 2/2] [naga hlsl-out] Factor out some repetitive code --- naga/src/back/hlsl/writer.rs | 360 +++++++++++++++-------------------- 1 file changed, 152 insertions(+), 208 deletions(-) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 5f959e37e4..65b8dc0ba7 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -43,6 +43,11 @@ pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32"; pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64"; pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64"; +enum Index { + Expression(Handle), + Static(u32), +} + struct EpStructMember { name: String, ty: Handle, @@ -1766,6 +1771,23 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Ok(()) } + fn write_index( + &mut self, + module: &Module, + index: Index, + func_ctx: &back::FunctionCtx<'_>, + ) -> BackendResult { + match index { + Index::Static(index) => { + write!(self.out, "{index}")?; + } + Index::Expression(index) => { + self.write_expr(module, index, func_ctx)?; + } + } + Ok(()) + } + /// Helper method used to write statements /// /// # Notes @@ -1919,13 +1941,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { // // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars). // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads). - struct MatrixAccess { - base: Handle, - index: u32, - } - enum Index { - Expression(Handle), - Static(u32), + enum MatrixAccess { + Direct { + base: Handle, + index: u32, + }, + Struct { + columns: crate::VectorSize, + base: Handle, + }, } let get_members = |expr: Handle| { @@ -1939,187 +1963,28 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } }; - let mut matrix = None; - let mut vector = None; - let mut scalar = None; - - let mut current_expr = pointer; - for _ in 0..3 { - let resolved = func_ctx.resolve_type(current_expr, &module.types); - - match (resolved, &func_ctx.expressions[current_expr]) { - ( - &TypeInner::Pointer { base: ty, .. }, - &crate::Expression::AccessIndex { base, index }, - ) if matches!( - module.types[ty].inner, - TypeInner::Matrix { - rows: crate::VectorSize::Bi, - .. - } - ) && get_members(base) - .map(|members| members[index as usize].binding.is_none()) - == Some(true) => - { - matrix = Some(MatrixAccess { base, index }); - break; - } - ( - &TypeInner::ValuePointer { - size: Some(crate::VectorSize::Bi), - .. - }, - &crate::Expression::Access { base, index }, - ) => { - vector = Some(Index::Expression(index)); - current_expr = base; - } - ( - &TypeInner::ValuePointer { - size: Some(crate::VectorSize::Bi), - .. - }, - &crate::Expression::AccessIndex { base, index }, - ) => { - vector = Some(Index::Static(index)); - current_expr = base; - } - ( - &TypeInner::ValuePointer { size: None, .. }, - &crate::Expression::Access { base, index }, - ) => { - scalar = Some(Index::Expression(index)); - current_expr = base; - } - ( - &TypeInner::ValuePointer { size: None, .. }, - &crate::Expression::AccessIndex { base, index }, - ) => { - scalar = Some(Index::Static(index)); - current_expr = base; - } - _ => break, - } - } - write!(self.out, "{level}")?; - if let Some(MatrixAccess { index, base }) = matrix { - let base_ty_res = &func_ctx.info[base].ty; - let resolved = base_ty_res.inner_with(&module.types); - let ty = match *resolved { - TypeInner::Pointer { base, .. } => base, - _ => base_ty_res.handle().unwrap(), - }; - - if let Some(Index::Static(vec_index)) = vector { - self.write_expr(module, base, func_ctx)?; - write!( - self.out, - ".{}_{}", - &self.names[&NameKey::StructMember(ty, index)], - vec_index - )?; - - if let Some(scalar_index) = scalar { - write!(self.out, "[")?; - match scalar_index { - Index::Static(index) => { - write!(self.out, "{index}")?; - } - Index::Expression(index) => { - self.write_expr(module, index, func_ctx)?; - } - } - write!(self.out, "]")?; - } - - write!(self.out, " = ")?; - self.write_expr(module, value, func_ctx)?; - writeln!(self.out, ";")?; - } else { - let access = WrappedStructMatrixAccess { ty, index }; - match (&vector, &scalar) { - (&Some(_), &Some(_)) => { - self.write_wrapped_struct_matrix_set_scalar_function_name( - access, - )?; - } - (&Some(_), &None) => { - self.write_wrapped_struct_matrix_set_vec_function_name(access)?; - } - (&None, _) => { - self.write_wrapped_struct_matrix_set_function_name(access)?; - } - } - - write!(self.out, "(")?; - self.write_expr(module, base, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, value, func_ctx)?; - - if let Some(Index::Expression(vec_index)) = vector { - write!(self.out, ", ")?; - self.write_expr(module, vec_index, func_ctx)?; - - if let Some(scalar_index) = scalar { - write!(self.out, ", ")?; - match scalar_index { - Index::Static(index) => { - write!(self.out, "{index}")?; - } - Index::Expression(index) => { - self.write_expr(module, index, func_ctx)?; - } - } - } - } - writeln!(self.out, ");")?; - } - } else { - // We handle `Store`s to __matCx2 column vectors and scalar elements via - // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2. - struct MatrixData { - columns: crate::VectorSize, - base: Handle, - } - - enum Index { - Expression(Handle), - Static(u32), - } - - let mut matrix = None; - let mut vector = None; - let mut scalar = None; - - let mut current_expr = pointer; - for _ in 0..3 { - let resolved = func_ctx.resolve_type(current_expr, &module.types); - match (resolved, &func_ctx.expressions[current_expr]) { - ( - &TypeInner::ValuePointer { - size: Some(crate::VectorSize::Bi), - .. - }, - &crate::Expression::Access { base, index }, - ) => { - vector = Some(index); - current_expr = base; - } - ( - &TypeInner::ValuePointer { size: None, .. }, - &crate::Expression::Access { base, index }, - ) => { - scalar = Some(Index::Expression(index)); - current_expr = base; - } + let matrix_access_on_lhs = + find_matrix_in_access_chain(module, pointer, func_ctx).and_then( + |(matrix_expr, vector, scalar)| match ( + func_ctx.resolve_type(matrix_expr, &module.types), + &func_ctx.expressions[matrix_expr], + ) { ( - &TypeInner::ValuePointer { size: None, .. }, + &TypeInner::Pointer { base: ty, .. }, &crate::Expression::AccessIndex { base, index }, - ) => { - scalar = Some(Index::Static(index)); - current_expr = base; + ) if matches!( + module.types[ty].inner, + TypeInner::Matrix { + rows: crate::VectorSize::Bi, + .. + } + ) && get_members(base) + .map(|members| members[index as usize].binding.is_none()) + == Some(true) => + { + Some((MatrixAccess::Direct { base, index }, vector, scalar)) } _ => { if let Some(MatrixType { @@ -2128,24 +1993,95 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { width: 4, }) = get_inner_matrix_of_struct_array_member( module, - current_expr, + matrix_expr, func_ctx, true, ) { - matrix = Some(MatrixData { - columns, - base: current_expr, - }); + Some(( + MatrixAccess::Struct { + columns, + base: matrix_expr, + }, + vector, + scalar, + )) + } else { + None + } + } + }, + ); + + match matrix_access_on_lhs { + Some((MatrixAccess::Direct { index, base }, vector, scalar)) => { + let base_ty_res = &func_ctx.info[base].ty; + let resolved = base_ty_res.inner_with(&module.types); + let ty = match *resolved { + TypeInner::Pointer { base, .. } => base, + _ => base_ty_res.handle().unwrap(), + }; + + if let Some(Index::Static(vec_index)) = vector { + self.write_expr(module, base, func_ctx)?; + write!( + self.out, + ".{}_{}", + &self.names[&NameKey::StructMember(ty, index)], + vec_index + )?; + + if let Some(scalar_index) = scalar { + write!(self.out, "[")?; + self.write_index(module, scalar_index, func_ctx)?; + write!(self.out, "]")?; + } + + write!(self.out, " = ")?; + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ";")?; + } else { + let access = WrappedStructMatrixAccess { ty, index }; + match (&vector, &scalar) { + (&Some(_), &Some(_)) => { + self.write_wrapped_struct_matrix_set_scalar_function_name( + access, + )?; + } + (&Some(_), &None) => { + self.write_wrapped_struct_matrix_set_vec_function_name( + access, + )?; } + (&None, _) => { + self.write_wrapped_struct_matrix_set_function_name(access)?; + } + } + + write!(self.out, "(")?; + self.write_expr(module, base, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; - break; + if let Some(Index::Expression(vec_index)) = vector { + write!(self.out, ", ")?; + self.write_expr(module, vec_index, func_ctx)?; + + if let Some(scalar_index) = scalar { + write!(self.out, ", ")?; + self.write_index(module, scalar_index, func_ctx)?; + } } + writeln!(self.out, ");")?; } } + Some(( + MatrixAccess::Struct { columns, base }, + Some(Index::Expression(vec_index)), + scalar, + )) => { + // We handle `Store`s to __matCx2 column vectors and scalar elements via + // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2. - if let (Some(MatrixData { columns, base }), Some(vec_index)) = - (matrix, vector) - { if scalar.is_some() { write!(self.out, "__set_el_of_mat{}x2", columns as u8)?; } else { @@ -2158,21 +2094,17 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { if let Some(scalar_index) = scalar { write!(self.out, ", ")?; - match scalar_index { - Index::Static(index) => { - write!(self.out, "{index}")?; - } - Index::Expression(index) => { - self.write_expr(module, index, func_ctx)?; - } - } + self.write_index(module, scalar_index, func_ctx)?; } write!(self.out, ", ")?; self.write_expr(module, value, func_ctx)?; writeln!(self.out, ");")?; - } else { + } + Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _)) + | Some((MatrixAccess::Struct { .. }, None, _)) + | None => { self.write_expr(module, pointer, func_ctx)?; write!(self.out, " = ")?; @@ -4341,12 +4273,17 @@ pub(super) fn get_inner_matrix_data( } } +/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`, +/// returns a tuple of the matrix, the column (vector) index (if present), and +/// the row (scalar) index (if present). fn find_matrix_in_access_chain( module: &Module, base: Handle, func_ctx: &back::FunctionCtx<'_>, -) -> Option> { +) -> Option<(Handle, Option, Option)> { let mut current_base = base; + let mut vector = None; + let mut scalar = None; loop { let resolved_tr = func_ctx .resolve_type(current_base, &module.types) @@ -4354,15 +4291,22 @@ fn find_matrix_in_access_chain( let resolved = resolved_tr.as_ref()?.inner_with(&module.types); match *resolved { + TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)), TypeInner::Scalar(_) | TypeInner::Vector { .. } => {} - TypeInner::Matrix { .. } => return Some(current_base), _ => return None, } - current_base = match func_ctx.expressions[current_base] { - crate::Expression::Access { base, .. } => base, - crate::Expression::AccessIndex { base, .. } => base, + let index; + (current_base, index) = match func_ctx.expressions[current_base] { + crate::Expression::Access { base, index } => (base, Index::Expression(index)), + crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)), _ => return None, + }; + + match *resolved { + TypeInner::Scalar(_) => scalar = Some(index), + TypeInner::Vector { .. } => vector = Some(index), + _ => unreachable!(), } } } @@ -4434,7 +4378,7 @@ fn get_inner_matrix_of_global_uniform( base: Handle, func_ctx: &back::FunctionCtx<'_>, ) -> Option { - let mut current_base = find_matrix_in_access_chain(module, base, func_ctx)?; + let (mut current_base, _, _) = find_matrix_in_access_chain(module, base, func_ctx)?; let mut mat_data = None; let mut array_base = None;