Skip to content
This repository was archived by the owner on Jan 29, 2025. It is now read-only.

Commit afece6f

Browse files
committed
[msl-out] Replace per_stage_map option with per-entry-point_map
The existing `per_stage_map` field of MSL backend options specifies resource binding maps that apply to all entry points of each stage type. It is useful to have the ability to provide a separate binding index map for each entry point, especially when the same shader module defines multiple entry points of the same stage kind. This patch replaces `per_stage_map` with a new `per_entry_point_map` option where resources are keyed by the entry-point function name.
1 parent a5c2cf9 commit afece6f

16 files changed

+227
-82
lines changed

src/back/msl/mod.rs

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,7 @@ holding the result.
2727
*/
2828

2929
use crate::{arena::Handle, proc::index, valid::ModuleInfo};
30-
use std::{
31-
fmt::{Error as FmtError, Write},
32-
ops,
33-
};
30+
use std::fmt::{Error as FmtError, Write};
3431

3532
mod keywords;
3633
pub mod sampler;
@@ -69,7 +66,7 @@ pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTar
6966
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
7067
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
7168
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
72-
pub struct PerStageResources {
69+
pub struct EntryPointResources {
7370
pub resources: BindingMap,
7471

7572
pub push_constant_buffer: Option<Slot>,
@@ -80,26 +77,7 @@ pub struct PerStageResources {
8077
pub sizes_buffer: Option<Slot>,
8178
}
8279

83-
#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
84-
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
85-
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
86-
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
87-
pub struct PerStageMap {
88-
pub vs: PerStageResources,
89-
pub fs: PerStageResources,
90-
pub cs: PerStageResources,
91-
}
92-
93-
impl ops::Index<crate::ShaderStage> for PerStageMap {
94-
type Output = PerStageResources;
95-
fn index(&self, stage: crate::ShaderStage) -> &PerStageResources {
96-
match stage {
97-
crate::ShaderStage::Vertex => &self.vs,
98-
crate::ShaderStage::Fragment => &self.fs,
99-
crate::ShaderStage::Compute => &self.cs,
100-
}
101-
}
102-
}
80+
pub type EntryPointResourceMap = std::collections::BTreeMap<String, EntryPointResources>;
10381

10482
enum ResolvedBinding {
10583
BuiltIn(crate::BuiltIn),
@@ -198,8 +176,8 @@ enum LocationMode {
198176
pub struct Options {
199177
/// (Major, Minor) target version of the Metal Shading Language.
200178
pub lang_version: (u8, u8),
201-
/// Map of per-stage resources to slots.
202-
pub per_stage_map: PerStageMap,
179+
/// Map of entry-point resources, indexed by entry point function name, to slots.
180+
pub per_entry_point_map: EntryPointResourceMap,
203181
/// Samplers to be inlined into the code.
204182
pub inline_samplers: Vec<sampler::InlineSampler>,
205183
/// Make it possible to link different stages via SPIRV-Cross.
@@ -217,7 +195,7 @@ impl Default for Options {
217195
fn default() -> Self {
218196
Options {
219197
lang_version: (2, 0),
220-
per_stage_map: PerStageMap::default(),
198+
per_entry_point_map: EntryPointResourceMap::default(),
221199
inline_samplers: Vec::new(),
222200
spirv_cross_compatibility: false,
223201
fake_missing_bindings: true,
@@ -296,12 +274,26 @@ impl Options {
296274
}
297275
}
298276

277+
fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> {
278+
self.per_entry_point_map.get(&ep.name)
279+
}
280+
281+
fn get_resource_binding_target(
282+
&self,
283+
ep: &crate::EntryPoint,
284+
res_binding: &crate::ResourceBinding,
285+
) -> Option<&BindTarget> {
286+
self.get_entry_point_resources(ep)
287+
.and_then(|res| res.resources.get(res_binding))
288+
}
289+
299290
fn resolve_resource_binding(
300291
&self,
301-
stage: crate::ShaderStage,
292+
ep: &crate::EntryPoint,
302293
res_binding: &crate::ResourceBinding,
303294
) -> Result<ResolvedBinding, EntryPointError> {
304-
match self.per_stage_map[stage].resources.get(res_binding) {
295+
let target = self.get_resource_binding_target(ep, res_binding);
296+
match target {
305297
Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
306298
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
307299
prefix: "fake",
@@ -312,15 +304,13 @@ impl Options {
312304
}
313305
}
314306

315-
const fn resolve_push_constants(
307+
fn resolve_push_constants(
316308
&self,
317-
stage: crate::ShaderStage,
309+
ep: &crate::EntryPoint,
318310
) -> Result<ResolvedBinding, EntryPointError> {
319-
let slot = match stage {
320-
crate::ShaderStage::Vertex => self.per_stage_map.vs.push_constant_buffer,
321-
crate::ShaderStage::Fragment => self.per_stage_map.fs.push_constant_buffer,
322-
crate::ShaderStage::Compute => self.per_stage_map.cs.push_constant_buffer,
323-
};
311+
let slot = self
312+
.get_entry_point_resources(ep)
313+
.and_then(|res| res.push_constant_buffer);
324314
match slot {
325315
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
326316
buffer: Some(slot),
@@ -340,9 +330,11 @@ impl Options {
340330

341331
fn resolve_sizes_buffer(
342332
&self,
343-
stage: crate::ShaderStage,
333+
ep: &crate::EntryPoint,
344334
) -> Result<ResolvedBinding, EntryPointError> {
345-
let slot = self.per_stage_map[stage].sizes_buffer;
335+
let slot = self
336+
.get_entry_point_resources(ep)
337+
.and_then(|res| res.sizes_buffer);
346338
match slot {
347339
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
348340
buffer: Some(slot),

src/back/msl/writer.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3368,7 +3368,8 @@ impl<W: Write> Writer<W> {
33683368
break;
33693369
}
33703370
};
3371-
let good = match options.per_stage_map[ep.stage].resources.get(br) {
3371+
let target = options.get_resource_binding_target(&ep, &br);
3372+
let good = match target {
33723373
Some(target) => {
33733374
let binding_ty = match module.types[var.ty].inner {
33743375
crate::TypeInner::BindingArray { base, .. } => {
@@ -3393,7 +3394,7 @@ impl<W: Write> Writer<W> {
33933394
}
33943395
}
33953396
crate::AddressSpace::PushConstant => {
3396-
if let Err(e) = options.resolve_push_constants(ep.stage) {
3397+
if let Err(e) = options.resolve_push_constants(ep) {
33973398
ep_error = Some(e);
33983399
break;
33993400
}
@@ -3404,7 +3405,7 @@ impl<W: Write> Writer<W> {
34043405
}
34053406
}
34063407
if supports_array_length {
3407-
if let Err(err) = options.resolve_sizes_buffer(ep.stage) {
3408+
if let Err(err) = options.resolve_sizes_buffer(ep) {
34083409
ep_error = Some(err);
34093410
}
34103411
}
@@ -3673,15 +3674,13 @@ impl<W: Write> Writer<W> {
36733674
}
36743675
// the resolves have already been checked for `!fake_missing_bindings` case
36753676
let resolved = match var.space {
3676-
crate::AddressSpace::PushConstant => {
3677-
options.resolve_push_constants(ep.stage).ok()
3678-
}
3677+
crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
36793678
crate::AddressSpace::WorkGroup => None,
36803679
crate::AddressSpace::Storage { .. } if options.lang_version < (2, 0) => {
36813680
return Err(Error::UnsupportedAddressSpace(var.space))
36823681
}
36833682
_ => options
3684-
.resolve_resource_binding(ep.stage, var.binding.as_ref().unwrap())
3683+
.resolve_resource_binding(ep, var.binding.as_ref().unwrap())
36853684
.ok(),
36863685
};
36873686
if let Some(ref resolved) = resolved {
@@ -3726,7 +3725,7 @@ impl<W: Write> Writer<W> {
37263725
// passed as a final struct-typed argument.
37273726
if supports_array_length {
37283727
// this is checked earlier
3729-
let resolved = options.resolve_sizes_buffer(ep.stage).unwrap();
3728+
let resolved = options.resolve_sizes_buffer(ep).unwrap();
37303729
let separator = if module.global_variables.is_empty() {
37313730
' '
37323731
} else {
@@ -3786,7 +3785,7 @@ impl<W: Write> Writer<W> {
37863785
};
37873786
} else if let Some(ref binding) = var.binding {
37883787
// write an inline sampler
3789-
let resolved = options.resolve_resource_binding(ep.stage, binding).unwrap();
3788+
let resolved = options.resolve_resource_binding(ep, binding).unwrap();
37903789
if let Some(sampler) = resolved.as_inline_sampler(options) {
37913790
let name = &self.names[&NameKey::GlobalVariable(handle)];
37923791
writeln!(

tests/in/access.param.ron

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
),
77
msl: (
88
lang_version: (2, 0),
9-
per_stage_map: (
10-
vs: (
9+
per_entry_point_map: {
10+
"foo_vert": (
1111
resources: {
1212
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
1313
(group: 0, binding: 1): (buffer: Some(1), mutable: false),
@@ -16,20 +16,20 @@
1616
},
1717
sizes_buffer: Some(24),
1818
),
19-
fs: (
19+
"foo_frag": (
2020
resources: {
2121
(group: 0, binding: 0): (buffer: Some(0), mutable: true),
2222
(group: 0, binding: 2): (buffer: Some(2), mutable: true),
2323
},
2424
sizes_buffer: Some(24),
2525
),
26-
cs: (
26+
"atomics": (
2727
resources: {
2828
(group: 0, binding: 0): (buffer: Some(0), mutable: true),
2929
},
3030
sizes_buffer: Some(24),
3131
),
32-
),
32+
},
3333
inline_samplers: [],
3434
spirv_cross_compatibility: false,
3535
fake_missing_bindings: false,

tests/in/binding-arrays.param.ron

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
),
2020
msl: (
2121
lang_version: (2, 0),
22-
per_stage_map: (
23-
fs: (
22+
per_entry_point_map: {
23+
"main": (
2424
resources: {
2525
(group: 0, binding: 0): (texture: Some(0), binding_array_size: Some(10), mutable: false),
2626
},
2727
sizes_buffer: None,
2828
)
29-
),
29+
},
3030
inline_samplers: [],
3131
spirv_cross_compatibility: false,
3232
fake_missing_bindings: true,

tests/in/bitcast.params.ron

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
(
22
msl: (
33
lang_version: (1, 2),
4-
per_stage_map: (
5-
cs: (
4+
per_entry_point_map: {
5+
"main": (
66
resources: {
77
},
88
sizes_buffer: Some(0),
99
)
10-
),
10+
},
1111
inline_samplers: [],
1212
spirv_cross_compatibility: false,
1313
fake_missing_bindings: false,

tests/in/bits.param.ron

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
(
22
msl: (
33
lang_version: (1, 2),
4-
per_stage_map: (
5-
cs: (
4+
per_entry_point_map: {
5+
"main": (
66
resources: {
77
},
88
sizes_buffer: Some(0),
99
)
10-
),
10+
},
1111
inline_samplers: [],
1212
spirv_cross_compatibility: false,
1313
fake_missing_bindings: false,

tests/in/boids.param.ron

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@
66
),
77
msl: (
88
lang_version: (2, 0),
9-
per_stage_map: (
10-
cs: (
9+
per_entry_point_map: {
10+
"main": (
1111
resources: {
1212
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
1313
(group: 0, binding: 1): (buffer: Some(1), mutable: true),
1414
(group: 0, binding: 2): (buffer: Some(2), mutable: true),
1515
},
1616
sizes_buffer: Some(3),
1717
)
18-
),
18+
},
1919
inline_samplers: [],
2020
spirv_cross_compatibility: false,
2121
fake_missing_bindings: false,

tests/in/extra.param.ron

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
),
66
msl: (
77
lang_version: (2, 2),
8-
per_stage_map: (
9-
fs: (
8+
per_entry_point_map: {
9+
"main": (
1010
push_constant_buffer: Some(1),
1111
),
12-
),
12+
},
1313
inline_samplers: [],
1414
spirv_cross_compatibility: false,
1515
fake_missing_bindings: false,

tests/in/interface.param.ron

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
),
2020
msl: (
2121
lang_version: (2, 1),
22-
per_stage_map: (),
22+
per_entry_point_map: {},
2323
inline_samplers: [],
2424
spirv_cross_compatibility: false,
2525
fake_missing_bindings: false,

tests/in/padding.param.ron

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
),
77
msl: (
88
lang_version: (2, 0),
9-
per_stage_map: (
10-
vs: (
9+
per_entry_point_map: {
10+
"vertex": (
1111
resources: {
1212
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
1313
(group: 0, binding: 1): (buffer: Some(1), mutable: false),
1414
(group: 0, binding: 2): (buffer: Some(2), mutable: false),
1515
},
1616
),
17-
),
17+
},
1818
inline_samplers: [],
1919
spirv_cross_compatibility: false,
2020
fake_missing_bindings: false,

0 commit comments

Comments
 (0)