Skip to content
This repository was archived by the owner on Jan 29, 2025. It is now read-only.

Commit 00be08e

Browse files
authored
[msl-out] Replace per_stage_map with per_entry_point_map (#2237)
The existing `per_stage_map` field of MSL backend options specifies resource binding maps that apply to all entry points of each stage type. It is useful to have the ability to provide a separate binding index map for each entry point, especially when the same shader module defines multiple entry points of the same stage kind. This patch replaces `per_stage_map` with a new `per_entry_point_map` option where resources are keyed by the entry-point function name.
1 parent 9742f16 commit 00be08e

16 files changed

+227
-82
lines changed

src/back/msl/mod.rs

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,7 @@ holding the result.
2727
*/
2828

2929
use crate::{arena::Handle, proc::index, valid::ModuleInfo};
30-
use std::{
31-
fmt::{Error as FmtError, Write},
32-
ops,
33-
};
30+
use std::fmt::{Error as FmtError, Write};
3431

3532
mod keywords;
3633
pub mod sampler;
@@ -69,7 +66,7 @@ pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTar
6966
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
7067
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
7168
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
72-
pub struct PerStageResources {
69+
pub struct EntryPointResources {
7370
pub resources: BindingMap,
7471

7572
pub push_constant_buffer: Option<Slot>,
@@ -80,26 +77,7 @@ pub struct PerStageResources {
8077
pub sizes_buffer: Option<Slot>,
8178
}
8279

83-
#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
84-
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
85-
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
86-
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
87-
pub struct PerStageMap {
88-
pub vs: PerStageResources,
89-
pub fs: PerStageResources,
90-
pub cs: PerStageResources,
91-
}
92-
93-
impl ops::Index<crate::ShaderStage> for PerStageMap {
94-
type Output = PerStageResources;
95-
fn index(&self, stage: crate::ShaderStage) -> &PerStageResources {
96-
match stage {
97-
crate::ShaderStage::Vertex => &self.vs,
98-
crate::ShaderStage::Fragment => &self.fs,
99-
crate::ShaderStage::Compute => &self.cs,
100-
}
101-
}
102-
}
80+
pub type EntryPointResourceMap = std::collections::BTreeMap<String, EntryPointResources>;
10381

10482
enum ResolvedBinding {
10583
BuiltIn(crate::BuiltIn),
@@ -198,8 +176,8 @@ enum LocationMode {
198176
pub struct Options {
199177
/// (Major, Minor) target version of the Metal Shading Language.
200178
pub lang_version: (u8, u8),
201-
/// Map of per-stage resources to slots.
202-
pub per_stage_map: PerStageMap,
179+
/// Map of entry-point resources, indexed by entry point function name, to slots.
180+
pub per_entry_point_map: EntryPointResourceMap,
203181
/// Samplers to be inlined into the code.
204182
pub inline_samplers: Vec<sampler::InlineSampler>,
205183
/// Make it possible to link different stages via SPIRV-Cross.
@@ -217,7 +195,7 @@ impl Default for Options {
217195
fn default() -> Self {
218196
Options {
219197
lang_version: (2, 0),
220-
per_stage_map: PerStageMap::default(),
198+
per_entry_point_map: EntryPointResourceMap::default(),
221199
inline_samplers: Vec::new(),
222200
spirv_cross_compatibility: false,
223201
fake_missing_bindings: true,
@@ -296,12 +274,26 @@ impl Options {
296274
}
297275
}
298276

277+
fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> {
278+
self.per_entry_point_map.get(&ep.name)
279+
}
280+
281+
fn get_resource_binding_target(
282+
&self,
283+
ep: &crate::EntryPoint,
284+
res_binding: &crate::ResourceBinding,
285+
) -> Option<&BindTarget> {
286+
self.get_entry_point_resources(ep)
287+
.and_then(|res| res.resources.get(res_binding))
288+
}
289+
299290
fn resolve_resource_binding(
300291
&self,
301-
stage: crate::ShaderStage,
292+
ep: &crate::EntryPoint,
302293
res_binding: &crate::ResourceBinding,
303294
) -> Result<ResolvedBinding, EntryPointError> {
304-
match self.per_stage_map[stage].resources.get(res_binding) {
295+
let target = self.get_resource_binding_target(ep, res_binding);
296+
match target {
305297
Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
306298
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
307299
prefix: "fake",
@@ -312,15 +304,13 @@ impl Options {
312304
}
313305
}
314306

315-
const fn resolve_push_constants(
307+
fn resolve_push_constants(
316308
&self,
317-
stage: crate::ShaderStage,
309+
ep: &crate::EntryPoint,
318310
) -> Result<ResolvedBinding, EntryPointError> {
319-
let slot = match stage {
320-
crate::ShaderStage::Vertex => self.per_stage_map.vs.push_constant_buffer,
321-
crate::ShaderStage::Fragment => self.per_stage_map.fs.push_constant_buffer,
322-
crate::ShaderStage::Compute => self.per_stage_map.cs.push_constant_buffer,
323-
};
311+
let slot = self
312+
.get_entry_point_resources(ep)
313+
.and_then(|res| res.push_constant_buffer);
324314
match slot {
325315
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
326316
buffer: Some(slot),
@@ -340,9 +330,11 @@ impl Options {
340330

341331
fn resolve_sizes_buffer(
342332
&self,
343-
stage: crate::ShaderStage,
333+
ep: &crate::EntryPoint,
344334
) -> Result<ResolvedBinding, EntryPointError> {
345-
let slot = self.per_stage_map[stage].sizes_buffer;
335+
let slot = self
336+
.get_entry_point_resources(ep)
337+
.and_then(|res| res.sizes_buffer);
346338
match slot {
347339
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
348340
buffer: Some(slot),

src/back/msl/writer.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3406,7 +3406,8 @@ impl<W: Write> Writer<W> {
34063406
break;
34073407
}
34083408
};
3409-
let good = match options.per_stage_map[ep.stage].resources.get(br) {
3409+
let target = options.get_resource_binding_target(ep, br);
3410+
let good = match target {
34103411
Some(target) => {
34113412
let binding_ty = match module.types[var.ty].inner {
34123413
crate::TypeInner::BindingArray { base, .. } => {
@@ -3431,7 +3432,7 @@ impl<W: Write> Writer<W> {
34313432
}
34323433
}
34333434
crate::AddressSpace::PushConstant => {
3434-
if let Err(e) = options.resolve_push_constants(ep.stage) {
3435+
if let Err(e) = options.resolve_push_constants(ep) {
34353436
ep_error = Some(e);
34363437
break;
34373438
}
@@ -3442,7 +3443,7 @@ impl<W: Write> Writer<W> {
34423443
}
34433444
}
34443445
if supports_array_length {
3445-
if let Err(err) = options.resolve_sizes_buffer(ep.stage) {
3446+
if let Err(err) = options.resolve_sizes_buffer(ep) {
34463447
ep_error = Some(err);
34473448
}
34483449
}
@@ -3711,15 +3712,13 @@ impl<W: Write> Writer<W> {
37113712
}
37123713
// the resolves have already been checked for `!fake_missing_bindings` case
37133714
let resolved = match var.space {
3714-
crate::AddressSpace::PushConstant => {
3715-
options.resolve_push_constants(ep.stage).ok()
3716-
}
3715+
crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
37173716
crate::AddressSpace::WorkGroup => None,
37183717
crate::AddressSpace::Storage { .. } if options.lang_version < (2, 0) => {
37193718
return Err(Error::UnsupportedAddressSpace(var.space))
37203719
}
37213720
_ => options
3722-
.resolve_resource_binding(ep.stage, var.binding.as_ref().unwrap())
3721+
.resolve_resource_binding(ep, var.binding.as_ref().unwrap())
37233722
.ok(),
37243723
};
37253724
if let Some(ref resolved) = resolved {
@@ -3764,7 +3763,7 @@ impl<W: Write> Writer<W> {
37643763
// passed as a final struct-typed argument.
37653764
if supports_array_length {
37663765
// this is checked earlier
3767-
let resolved = options.resolve_sizes_buffer(ep.stage).unwrap();
3766+
let resolved = options.resolve_sizes_buffer(ep).unwrap();
37683767
let separator = if module.global_variables.is_empty() {
37693768
' '
37703769
} else {
@@ -3824,7 +3823,7 @@ impl<W: Write> Writer<W> {
38243823
};
38253824
} else if let Some(ref binding) = var.binding {
38263825
// write an inline sampler
3827-
let resolved = options.resolve_resource_binding(ep.stage, binding).unwrap();
3826+
let resolved = options.resolve_resource_binding(ep, binding).unwrap();
38283827
if let Some(sampler) = resolved.as_inline_sampler(options) {
38293828
let name = &self.names[&NameKey::GlobalVariable(handle)];
38303829
writeln!(

tests/in/access.param.ron

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
),
77
msl: (
88
lang_version: (2, 0),
9-
per_stage_map: (
10-
vs: (
9+
per_entry_point_map: {
10+
"foo_vert": (
1111
resources: {
1212
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
1313
(group: 0, binding: 1): (buffer: Some(1), mutable: false),
@@ -16,20 +16,20 @@
1616
},
1717
sizes_buffer: Some(24),
1818
),
19-
fs: (
19+
"foo_frag": (
2020
resources: {
2121
(group: 0, binding: 0): (buffer: Some(0), mutable: true),
2222
(group: 0, binding: 2): (buffer: Some(2), mutable: true),
2323
},
2424
sizes_buffer: Some(24),
2525
),
26-
cs: (
26+
"atomics": (
2727
resources: {
2828
(group: 0, binding: 0): (buffer: Some(0), mutable: true),
2929
},
3030
sizes_buffer: Some(24),
3131
),
32-
),
32+
},
3333
inline_samplers: [],
3434
spirv_cross_compatibility: false,
3535
fake_missing_bindings: false,

tests/in/binding-arrays.param.ron

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
),
2020
msl: (
2121
lang_version: (2, 0),
22-
per_stage_map: (
23-
fs: (
22+
per_entry_point_map: {
23+
"main": (
2424
resources: {
2525
(group: 0, binding: 0): (texture: Some(0), binding_array_size: Some(10), mutable: false),
2626
},
2727
sizes_buffer: None,
2828
)
29-
),
29+
},
3030
inline_samplers: [],
3131
spirv_cross_compatibility: false,
3232
fake_missing_bindings: true,

tests/in/bitcast.params.ron

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
(
22
msl: (
33
lang_version: (1, 2),
4-
per_stage_map: (
5-
cs: (
4+
per_entry_point_map: {
5+
"main": (
66
resources: {
77
},
88
sizes_buffer: Some(0),
99
)
10-
),
10+
},
1111
inline_samplers: [],
1212
spirv_cross_compatibility: false,
1313
fake_missing_bindings: false,

tests/in/bits.param.ron

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
(
22
msl: (
33
lang_version: (1, 2),
4-
per_stage_map: (
5-
cs: (
4+
per_entry_point_map: {
5+
"main": (
66
resources: {
77
},
88
sizes_buffer: Some(0),
99
)
10-
),
10+
},
1111
inline_samplers: [],
1212
spirv_cross_compatibility: false,
1313
fake_missing_bindings: false,

tests/in/boids.param.ron

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@
66
),
77
msl: (
88
lang_version: (2, 0),
9-
per_stage_map: (
10-
cs: (
9+
per_entry_point_map: {
10+
"main": (
1111
resources: {
1212
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
1313
(group: 0, binding: 1): (buffer: Some(1), mutable: true),
1414
(group: 0, binding: 2): (buffer: Some(2), mutable: true),
1515
},
1616
sizes_buffer: Some(3),
1717
)
18-
),
18+
},
1919
inline_samplers: [],
2020
spirv_cross_compatibility: false,
2121
fake_missing_bindings: false,

tests/in/extra.param.ron

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
),
66
msl: (
77
lang_version: (2, 2),
8-
per_stage_map: (
9-
fs: (
8+
per_entry_point_map: {
9+
"main": (
1010
push_constant_buffer: Some(1),
1111
),
12-
),
12+
},
1313
inline_samplers: [],
1414
spirv_cross_compatibility: false,
1515
fake_missing_bindings: false,

tests/in/interface.param.ron

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
),
2020
msl: (
2121
lang_version: (2, 1),
22-
per_stage_map: (),
22+
per_entry_point_map: {},
2323
inline_samplers: [],
2424
spirv_cross_compatibility: false,
2525
fake_missing_bindings: false,

tests/in/padding.param.ron

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
),
77
msl: (
88
lang_version: (2, 0),
9-
per_stage_map: (
10-
vs: (
9+
per_entry_point_map: {
10+
"vertex": (
1111
resources: {
1212
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
1313
(group: 0, binding: 1): (buffer: Some(1), mutable: false),
1414
(group: 0, binding: 2): (buffer: Some(2), mutable: false),
1515
},
1616
),
17-
),
17+
},
1818
inline_samplers: [],
1919
spirv_cross_compatibility: false,
2020
fake_missing_bindings: false,

0 commit comments

Comments
 (0)