Skip to content

Commit 50ae017

Browse files
committed
[naga hlsl-out] Handle additional cases of Cx2 matrices
Fixes gfx-rs#4423
1 parent 90afc88 commit 50ae017

13 files changed

+2055
-228
lines changed

naga/src/back/hlsl/mod.rs

+16-9
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,17 @@ type should be stored in `uniform` and `storage` buffers. The HLSL we
1313
generate must access values in that form, even when it is not what
1414
HLSL would use normally.
1515
16-
The rules described here only apply to WGSL `uniform` variables. WGSL
17-
`storage` buffers are translated as HLSL `ByteAddressBuffers`, for
18-
which we generate `Load` and `Store` method calls with explicit byte
19-
offsets. WGSL pipeline inputs must be scalars or vectors; they cannot
20-
be matrices, which is where the interesting problems arise.
16+
Matching the WGSL memory layout is a concern only for `uniform`
17+
variables. WGSL `storage` buffers are translated as HLSL
18+
`ByteAddressBuffers`, for which we generate `Load` and `Store` method
19+
calls with explicit byte offsets. WGSL pipeline inputs must be scalars
20+
or vectors; they cannot be matrices, which is where the interesting
21+
problems arise. However, when an affected type appears in a struct
22+
definition, the transformations described here are applied without
23+
consideration of where the struct is used.
24+
25+
Access to storage buffers is implemented in `storage.rs`. Access to
26+
uniform buffers is implemented where applicable in `writer.rs`.
2127
2228
## Row- and column-major ordering for matrices
2329
@@ -57,10 +63,9 @@ that the columns of a `matKx2<f32>` need only be [aligned as required
5763
for `vec2<f32>`][ilov], which is [eight-byte alignment][8bb].
5864
5965
To compensate for this, any time a `matKx2<f32>` appears in a WGSL
60-
`uniform` variable, whether directly as the variable's type or as part
61-
of a struct/array, we actually emit `K` separate `float2` members, and
62-
assemble/disassemble the matrix from its columns (in WGSL; rows in
63-
HLSL) upon load and store.
66+
`uniform` value or as part of a struct/array, we actually emit `K`
67+
separate `float2` members, and assemble/disassemble the matrix from its
68+
columns (in WGSL; rows in HLSL) upon load and store.
6469
6570
For example, the following WGSL struct type:
6671
@@ -448,6 +453,8 @@ pub enum Error {
448453
Override,
449454
#[error(transparent)]
450455
ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
456+
#[error("Internal error: reached unreachable code: {0}")]
457+
Unreachable(String),
451458
}
452459

453460
#[derive(PartialEq, Eq, Hash)]

naga/src/back/hlsl/storage.rs

+111-30
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,13 @@ pub(super) enum StoreValue {
108108
base: Handle<crate::Type>,
109109
member_index: u32,
110110
},
111+
// Access to a single column of a Cx2 matrix within a struct
112+
TempColumnAccess {
113+
depth: usize,
114+
base: Handle<crate::Type>,
115+
member_index: u32,
116+
column: u32,
117+
},
111118
}
112119

113120
impl<W: fmt::Write> super::Writer<'_, W> {
@@ -290,6 +297,15 @@ impl<W: fmt::Write> super::Writer<'_, W> {
290297
let name = &self.names[&NameKey::StructMember(base, member_index)];
291298
write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}")?
292299
}
300+
StoreValue::TempColumnAccess {
301+
depth,
302+
base,
303+
member_index,
304+
column,
305+
} => {
306+
let name = &self.names[&NameKey::StructMember(base, member_index)];
307+
write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}_{column}")?
308+
}
293309
}
294310
Ok(())
295311
}
@@ -302,6 +318,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
302318
value: StoreValue,
303319
func_ctx: &FunctionCtx,
304320
level: crate::back::Level,
321+
within_struct: Option<Handle<crate::Type>>,
305322
) -> BackendResult {
306323
let temp_resolution;
307324
let ty_resolution = match value {
@@ -325,6 +342,11 @@ impl<W: fmt::Write> super::Writer<'_, W> {
325342
temp_resolution = TypeResolution::Handle(ty_handle);
326343
&temp_resolution
327344
}
345+
StoreValue::TempColumnAccess { .. } => {
346+
return Err(Error::Unreachable(
347+
"attempting write_storage_store for TempColumnAccess".into(),
348+
));
349+
}
328350
};
329351
match *ty_resolution.inner_with(&module.types) {
330352
crate::TypeInner::Scalar(scalar) => {
@@ -372,37 +394,89 @@ impl<W: fmt::Write> super::Writer<'_, W> {
372394
rows,
373395
scalar,
374396
} => {
375-
// first, assign the value to a temporary
376-
writeln!(self.out, "{level}{{")?;
377-
let depth = level.0 + 1;
378-
write!(
379-
self.out,
380-
"{}{}{}x{} {}{} = ",
381-
level.next(),
382-
scalar.to_hlsl_str()?,
383-
columns as u8,
384-
rows as u8,
385-
STORE_TEMP_NAME,
386-
depth,
387-
)?;
388-
self.write_store_value(module, &value, func_ctx)?;
389-
writeln!(self.out, ";")?;
390-
391397
// Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
392398
let row_stride = Alignment::from(rows) * scalar.width as u32;
393399

394-
// then iterate the stores
395-
for i in 0..columns as u32 {
396-
self.temp_access_chain
397-
.push(SubAccess::Offset(i * row_stride));
398-
let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
399-
let sv = StoreValue::TempIndex {
400+
writeln!(self.out, "{level}{{")?;
401+
402+
if let Some(containing_struct) = within_struct {
403+
// If we are within a struct, then the struct was already assigned to
404+
// a temporary, we don't need to make another.
405+
let mut chain = mem::take(&mut self.temp_access_chain);
406+
for i in 0..columns as u32 {
407+
chain.push(SubAccess::Offset(i * row_stride));
408+
// working around the borrow checker in `self.write_expr`
409+
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
410+
let StoreValue::TempAccess { member_index, .. } = value else {
411+
return Err(Error::Unreachable(
412+
"write_storage_store within_struct but not TempAccess".into(),
413+
));
414+
};
415+
let column_value = StoreValue::TempColumnAccess {
416+
depth: level.0, // note not incrementing, b/c no temp
417+
base: containing_struct,
418+
member_index,
419+
column: i,
420+
};
421+
// See note about DXC and Load/Store in the module's documentation.
422+
if scalar.width == 4 {
423+
write!(
424+
self.out,
425+
"{}{}.Store{}(",
426+
level.next(),
427+
var_name,
428+
rows as u8
429+
)?;
430+
self.write_storage_address(module, &chain, func_ctx)?;
431+
write!(self.out, ", asuint(")?;
432+
self.write_store_value(module, &column_value, func_ctx)?;
433+
writeln!(self.out, "));")?;
434+
} else {
435+
write!(self.out, "{}{var_name}.Store(", level.next())?;
436+
self.write_storage_address(module, &chain, func_ctx)?;
437+
write!(self.out, ", ")?;
438+
self.write_store_value(module, &column_value, func_ctx)?;
439+
writeln!(self.out, ");")?;
440+
}
441+
chain.pop();
442+
}
443+
self.temp_access_chain = chain;
444+
} else {
445+
// first, assign the value to a temporary
446+
let depth = level.0 + 1;
447+
write!(
448+
self.out,
449+
"{}{}{}x{} {}{} = ",
450+
level.next(),
451+
scalar.to_hlsl_str()?,
452+
columns as u8,
453+
rows as u8,
454+
STORE_TEMP_NAME,
400455
depth,
401-
index: i,
402-
ty: TypeResolution::Value(ty_inner),
403-
};
404-
self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
405-
self.temp_access_chain.pop();
456+
)?;
457+
self.write_store_value(module, &value, func_ctx)?;
458+
writeln!(self.out, ";")?;
459+
460+
// then iterate the stores
461+
for i in 0..columns as u32 {
462+
self.temp_access_chain
463+
.push(SubAccess::Offset(i * row_stride));
464+
let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
465+
let sv = StoreValue::TempIndex {
466+
depth,
467+
index: i,
468+
ty: TypeResolution::Value(ty_inner),
469+
};
470+
self.write_storage_store(
471+
module,
472+
var_handle,
473+
sv,
474+
func_ctx,
475+
level.next(),
476+
None,
477+
)?;
478+
self.temp_access_chain.pop();
479+
}
406480
}
407481
// done
408482
writeln!(self.out, "{level}}}")?;
@@ -415,7 +489,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
415489
// first, assign the value to a temporary
416490
writeln!(self.out, "{level}{{")?;
417491
write!(self.out, "{}", level.next())?;
418-
self.write_value_type(module, &module.types[base].inner)?;
492+
self.write_type(module, base)?;
419493
let depth = level.next().0;
420494
write!(self.out, " {STORE_TEMP_NAME}{depth}")?;
421495
self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
@@ -430,7 +504,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
430504
index: i,
431505
ty: TypeResolution::Handle(base),
432506
};
433-
self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
507+
self.write_storage_store(module, var_handle, sv, func_ctx, level.next(), None)?;
434508
self.temp_access_chain.pop();
435509
}
436510
// done
@@ -461,7 +535,14 @@ impl<W: fmt::Write> super::Writer<'_, W> {
461535
base: struct_ty,
462536
member_index: i as u32,
463537
};
464-
self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
538+
self.write_storage_store(
539+
module,
540+
var_handle,
541+
sv,
542+
func_ctx,
543+
level.next(),
544+
Some(struct_ty),
545+
)?;
465546
self.temp_access_chain.pop();
466547
}
467548
// done

naga/src/back/hlsl/writer.rs

+77-3
Original file line numberDiff line numberDiff line change
@@ -1894,6 +1894,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
18941894
StoreValue::Expression(value),
18951895
func_ctx,
18961896
level,
1897+
None,
18971898
)?;
18981899
} else {
18991900
// We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
@@ -2878,12 +2879,33 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
28782879
//
28792880
// Note that this only works for `Load`s and we handle
28802881
// `Store`s differently in `Statement::Store`.
2882+
let cx2_columns;
28812883
if let Some(MatrixType {
28822884
columns,
28832885
rows: crate::VectorSize::Bi,
28842886
width: 4,
28852887
}) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
28862888
{
2889+
cx2_columns = Some(columns);
2890+
} else {
2891+
let base_tr = func_ctx
2892+
.resolve_type(base, &module.types)
2893+
.pointer_base_type();
2894+
let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
2895+
match (&func_ctx.expressions[base], base_ty) {
2896+
(
2897+
&Expression::GlobalVariable(handle),
2898+
Some(&TypeInner::Matrix { columns, .. }),
2899+
) if module.global_variables[handle].space
2900+
== crate::AddressSpace::Uniform =>
2901+
{
2902+
cx2_columns = Some(columns);
2903+
}
2904+
_ => cx2_columns = None,
2905+
}
2906+
}
2907+
2908+
if let Some(columns) = cx2_columns {
28872909
write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
28882910
self.write_expr(module, base, func_ctx)?;
28892911
write!(self.out, ", ")?;
@@ -2997,12 +3019,33 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
29973019
} else {
29983020
// We write the matrix column access in a special way since
29993021
// the type of `base` is our special __matCx2 struct.
3022+
let is_cx2;
30003023
if let Some(MatrixType {
30013024
rows: crate::VectorSize::Bi,
30023025
width: 4,
30033026
..
30043027
}) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
30053028
{
3029+
is_cx2 = true;
3030+
} else {
3031+
let base_tr = func_ctx
3032+
.resolve_type(base, &module.types)
3033+
.pointer_base_type();
3034+
let base_ty = base_tr.as_ref().map(|tr| tr.inner_with(&module.types));
3035+
match (&func_ctx.expressions[base], base_ty) {
3036+
(
3037+
&Expression::GlobalVariable(handle),
3038+
Some(&TypeInner::Matrix { .. }),
3039+
) if module.global_variables[handle].space
3040+
== crate::AddressSpace::Uniform =>
3041+
{
3042+
is_cx2 = true;
3043+
}
3044+
_ => is_cx2 = false,
3045+
}
3046+
}
3047+
3048+
if is_cx2 {
30063049
self.write_expr(module, base, func_ctx)?;
30073050
write!(self.out, "._{index}")?;
30083051
return Ok(());
@@ -3275,8 +3318,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
32753318
.or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
32763319
{
32773320
let mut resolved = func_ctx.resolve_type(pointer, &module.types);
3278-
if let TypeInner::Pointer { base, .. } = *resolved {
3279-
resolved = &module.types[base].inner;
3321+
let ptr_tr = resolved.pointer_base_type();
3322+
if let Some(ptr_ty) =
3323+
ptr_tr.as_ref().map(|tr| tr.inner_with(&module.types))
3324+
{
3325+
resolved = ptr_ty;
32803326
}
32813327

32823328
write!(self.out, "((")?;
@@ -4162,6 +4208,34 @@ pub(super) fn get_inner_matrix_data(
41624208
}
41634209
}
41644210

4211+
fn find_matrix_in_access_chain(
4212+
module: &Module,
4213+
base: Handle<crate::Expression>,
4214+
func_ctx: &back::FunctionCtx<'_>,
4215+
) -> Option<Handle<crate::Expression>> {
4216+
let mut current_base = base;
4217+
loop {
4218+
let resolved_tr = func_ctx
4219+
.resolve_type(current_base, &module.types)
4220+
.pointer_base_type();
4221+
let resolved = resolved_tr
4222+
.as_ref()?
4223+
.inner_with(&module.types);
4224+
4225+
match *resolved {
4226+
TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
4227+
TypeInner::Matrix { .. } => return Some(current_base),
4228+
_ => return None,
4229+
}
4230+
4231+
current_base = match func_ctx.expressions[current_base] {
4232+
crate::Expression::Access { base, .. } => base,
4233+
crate::Expression::AccessIndex { base, .. } => base,
4234+
_ => return None,
4235+
}
4236+
}
4237+
}
4238+
41654239
/// Returns the matrix data if the access chain starting at `base`:
41664240
/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
41674241
/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
@@ -4229,10 +4303,10 @@ fn get_inner_matrix_of_global_uniform(
42294303
base: Handle<crate::Expression>,
42304304
func_ctx: &back::FunctionCtx<'_>,
42314305
) -> Option<MatrixType> {
4306+
let mut current_base = find_matrix_in_access_chain(module, base, func_ctx)?;
42324307
let mut mat_data = None;
42334308
let mut array_base = None;
42344309

4235-
let mut current_base = base;
42364310
loop {
42374311
let mut resolved = func_ctx.resolve_type(current_base, &module.types);
42384312
if let TypeInner::Pointer { base, .. } = *resolved {

0 commit comments

Comments
 (0)