Skip to content
This repository was archived by the owner on Jan 29, 2025. It is now read-only.

Commit 2c56c35

Browse files
committed
[hlsl-out] Implicitly transpose all matrices
1 parent c30146c commit 2c56c35

File tree

6 files changed

+35
-18
lines changed

6 files changed

+35
-18
lines changed

src/back/hlsl/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,17 @@
55
//! - 5.1
66
//! - 6.0
77
//!
8+
//! All matrix construction/deconstruction is row based in HLSL. This means that when
9+
//! we construct a matrix from column vectors, our matrix will be implicitly transposed.
10+
//! The inverse transposition happens when we call `[0]` to get the zeroth column vector.
11+
//!
12+
//! Because all of our matrices are implicitly transposed, we flip arguments to `mul`. `mat * vec`
13+
//! becomes `vec * mat`, etc. This acts as the inverse transpose making the results identical.
14+
//!
15+
//! The only time we don't get this implicit transposition is when reading matrices from Uniforms/Push Constants.
16+
//! To deal with this, we add `row_major` to all declarations of matrices in Uniforms/Push Constants.
17+
//!
18+
//! Finally because all of our matrices are transposed, if you use `mat3x4`, it'll become `float4x3` in HLSL.
819
920
mod conv;
1021
mod help;

src/back/hlsl/storage.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
//! Logic related to `ByteAddressBuffer` operations.
22
//!
33
//! HLSL backend uses byte address buffers for all storage buffers in IR.
4-
//! Matrices have to be transposed, because HLSL syntax implies row majority.
54
65
use super::{
76
super::{FunctionCtx, INDENT},
@@ -122,7 +121,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
122121
} => {
123122
write!(
124123
self.out,
125-
"transpose({}{}x{}(",
124+
"{}{}x{}(",
126125
crate::ScalarKind::Float.to_hlsl_str(width)?,
127126
rows as u8,
128127
columns as u8,
@@ -144,7 +143,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
144143
(TypeResolution::Value(ty_inner), i * row_stride)
145144
});
146145
self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
147-
write!(self.out, "))")?;
146+
write!(self.out, ")")?;
148147
}
149148
crate::TypeInner::Array {
150149
base,
@@ -267,7 +266,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
267266
let depth = indent + 1;
268267
write!(
269268
self.out,
270-
"{}{}{}x{} {}{} = transpose(",
269+
"{}{}{}x{} {}{} = ",
271270
INDENT.repeat(indent + 1),
272271
crate::ScalarKind::Float.to_hlsl_str(width)?,
273272
rows as u8,
@@ -276,7 +275,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
276275
depth,
277276
)?;
278277
self.write_store_value(module, &value, func_ctx)?;
279-
writeln!(self.out, ");")?;
278+
writeln!(self.out, ";")?;
280279
// then iterate the stores
281280
let row_stride = width as u32 * columns as u32;
282281
for i in 0..rows as u32 {

src/back/hlsl/writer.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
640640
}
641641
}
642642

643+
if let TypeInner::Matrix { .. } = module.types[ty].inner {
644+
write!(self.out, "row_major ")?;
645+
}
646+
643647
// Write the member type and name
644648
self.write_type(module, ty)?;
645649
write!(
@@ -700,12 +704,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
700704
} => {
701705
// The IR supports only float matrix
702706
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
707+
708+
// Because of the implicit transpose all matrices have in HLSL, we need to tranpose the size as well.
703709
write!(
704710
self.out,
705711
"{}{}x{}",
706712
crate::ScalarKind::Float.to_hlsl_str(width)?,
707-
back::vector_size_str(columns),
708713
back::vector_size_str(rows),
714+
back::vector_size_str(columns),
709715
)?;
710716
}
711717
TypeInner::Image {
@@ -1302,10 +1308,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
13021308
.inner_with(&module.types)
13031309
.is_matrix() =>
13041310
{
1311+
// We intentionally flip the order of multiplication as our matrices are implicitly transposed.
13051312
write!(self.out, "mul(")?;
1306-
self.write_expr(module, left, func_ctx)?;
1307-
write!(self.out, ", ")?;
13081313
self.write_expr(module, right, func_ctx)?;
1314+
write!(self.out, ", ")?;
1315+
self.write_expr(module, left, func_ctx)?;
13091316
write!(self.out, ")")?;
13101317
}
13111318
Expression::Binary { op, left, right } => {

tests/out/hlsl/access.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ float4 foo(VertexInput_foo vertexinput_foo) : SV_Position
1919

2020
float baz = foo1;
2121
foo1 = 1.0;
22-
float4x4 matrix1 = transpose(float4x4(asfloat(bar.Load4(0+0)), asfloat(bar.Load4(0+16)), asfloat(bar.Load4(0+32)), asfloat(bar.Load4(0+48))));
22+
float4x4 matrix1 = float4x4(asfloat(bar.Load4(0+0)), asfloat(bar.Load4(0+16)), asfloat(bar.Load4(0+32)), asfloat(bar.Load4(0+48)));
2323
uint2 arr[2] = {asuint(bar.Load2(72+0)), asuint(bar.Load2(72+8))};
2424
float4 _expr13 = asfloat(bar.Load4(48+0));
2525
float b = _expr13.x;
2626
int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 88) / 4) - 2u)*4+88));
2727
bar.Store(8+16+0, asuint(1.0));
2828
{
29-
float4x4 _value2 = transpose(float4x4(float4(0.0.xxxx), float4(1.0.xxxx), float4(2.0.xxxx), float4(3.0.xxxx)));
29+
float4x4 _value2 = float4x4(float4(0.0.xxxx), float4(1.0.xxxx), float4(2.0.xxxx), float4(3.0.xxxx));
3030
bar.Store4(0+0, asuint(_value2[0]));
3131
bar.Store4(0+16, asuint(_value2[1]));
3232
bar.Store4(0+32, asuint(_value2[2]));
@@ -43,7 +43,7 @@ float4 foo(VertexInput_foo vertexinput_foo) : SV_Position
4343
}
4444
c[(vertexinput_foo.vi1 + 1u)] = 42;
4545
int value = c[vertexinput_foo.vi1];
46-
return mul(matrix1, float4(int4(value.xxxx)));
46+
return mul(float4(int4(value.xxxx)), matrix1);
4747
}
4848

4949
[numthreads(1, 1, 1)]

tests/out/hlsl/shadow.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ struct Globals {
66
};
77

88
struct Light {
9-
float4x4 proj;
9+
row_major float4x4 proj;
1010
float4 pos;
1111
float4 color;
1212
};
@@ -51,9 +51,9 @@ float4 fs_main(FragmentInput_fs_main fragmentinput_fs_main) : SV_Target0
5151
break;
5252
}
5353
uint _expr19 = i;
54-
Light light = {transpose(float4x4(asfloat(s_lights.Load4(_expr19*96+0+0+0)), asfloat(s_lights.Load4(_expr19*96+0+0+16)), asfloat(s_lights.Load4(_expr19*96+0+0+32)), asfloat(s_lights.Load4(_expr19*96+0+0+48)))), asfloat(s_lights.Load4(_expr19*96+0+64)), asfloat(s_lights.Load4(_expr19*96+0+80))};
54+
Light light = {float4x4(asfloat(s_lights.Load4(_expr19*96+0+0+0)), asfloat(s_lights.Load4(_expr19*96+0+0+16)), asfloat(s_lights.Load4(_expr19*96+0+0+32)), asfloat(s_lights.Load4(_expr19*96+0+0+48))), asfloat(s_lights.Load4(_expr19*96+0+64)), asfloat(s_lights.Load4(_expr19*96+0+80))};
5555
uint _expr22 = i;
56-
const float _e25 = fetch_shadow(_expr22, mul(light.proj, fragmentinput_fs_main.position1));
56+
const float _e25 = fetch_shadow(_expr22, mul(fragmentinput_fs_main.position1, light.proj));
5757
float3 light_dir = normalize((light.pos.xyz - fragmentinput_fs_main.position1.xyz));
5858
float diffuse = max(0.0, dot(normal, light_dir));
5959
float3 _expr34 = color;

tests/out/hlsl/skybox.hlsl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ struct VertexOutput {
1010
};
1111

1212
struct Data {
13-
float4x4 proj_inv;
14-
float4x4 view;
13+
row_major float4x4 proj_inv;
14+
row_major float4x4 view;
1515
};
1616

1717
cbuffer r_data : register(b0) { Data r_data; }
@@ -41,8 +41,8 @@ VertexOutput vs_main(VertexInput_vs_main vertexinput_vs_main)
4141
float4 _expr35 = r_data.view[2];
4242
float3x3 inv_model_view = transpose(float3x3(_expr27.xyz, _expr31.xyz, _expr35.xyz));
4343
float4x4 _expr40 = r_data.proj_inv;
44-
float4 unprojected = mul(_expr40, pos);
45-
const VertexOutput vertexoutput1 = { pos, mul(inv_model_view, unprojected.xyz) };
44+
float4 unprojected = mul(pos, _expr40);
45+
const VertexOutput vertexoutput1 = { pos, mul(unprojected.xyz, inv_model_view) };
4646
return vertexoutput1;
4747
}
4848

0 commit comments

Comments
 (0)