diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index 271557fbf2..d794976602 100644 --- a/src/back/msl/mod.rs +++ b/src/back/msl/mod.rs @@ -27,10 +27,7 @@ holding the result. */ use crate::{arena::Handle, proc::index, valid::ModuleInfo}; -use std::{ - fmt::{Error as FmtError, Write}, - ops, -}; +use std::fmt::{Error as FmtError, Write}; mod keywords; pub mod sampler; @@ -69,7 +66,7 @@ pub type BindingMap = std::collections::BTreeMap, @@ -80,26 +77,7 @@ pub struct PerStageResources { pub sizes_buffer: Option, } -#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)] -#[cfg_attr(feature = "serialize", derive(serde::Serialize))] -#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] -#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))] -pub struct PerStageMap { - pub vs: PerStageResources, - pub fs: PerStageResources, - pub cs: PerStageResources, -} - -impl ops::Index for PerStageMap { - type Output = PerStageResources; - fn index(&self, stage: crate::ShaderStage) -> &PerStageResources { - match stage { - crate::ShaderStage::Vertex => &self.vs, - crate::ShaderStage::Fragment => &self.fs, - crate::ShaderStage::Compute => &self.cs, - } - } -} +pub type EntryPointResourceMap = std::collections::BTreeMap; enum ResolvedBinding { BuiltIn(crate::BuiltIn), @@ -198,8 +176,8 @@ enum LocationMode { pub struct Options { /// (Major, Minor) target version of the Metal Shading Language. pub lang_version: (u8, u8), - /// Map of per-stage resources to slots. - pub per_stage_map: PerStageMap, + /// Map of entry-point resources, indexed by entry point function name, to slots. + pub per_entry_point_map: EntryPointResourceMap, /// Samplers to be inlined into the code. pub inline_samplers: Vec, /// Make it possible to link different stages via SPIRV-Cross. @@ -217,7 +195,7 @@ impl Default for Options { fn default() -> Self { Options { lang_version: (2, 0), - per_stage_map: PerStageMap::default(), + per_entry_point_map: EntryPointResourceMap::default(), inline_samplers: Vec::new(), spirv_cross_compatibility: false, fake_missing_bindings: true, @@ -296,12 +274,26 @@ impl Options { } } + fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> { + self.per_entry_point_map.get(&ep.name) + } + + fn get_resource_binding_target( + &self, + ep: &crate::EntryPoint, + res_binding: &crate::ResourceBinding, + ) -> Option<&BindTarget> { + self.get_entry_point_resources(ep) + .and_then(|res| res.resources.get(res_binding)) + } + fn resolve_resource_binding( &self, - stage: crate::ShaderStage, + ep: &crate::EntryPoint, res_binding: &crate::ResourceBinding, ) -> Result { - match self.per_stage_map[stage].resources.get(res_binding) { + let target = self.get_resource_binding_target(ep, res_binding); + match target { Some(target) => Ok(ResolvedBinding::Resource(target.clone())), None if self.fake_missing_bindings => Ok(ResolvedBinding::User { prefix: "fake", @@ -312,15 +304,13 @@ impl Options { } } - const fn resolve_push_constants( + fn resolve_push_constants( &self, - stage: crate::ShaderStage, + ep: &crate::EntryPoint, ) -> Result { - let slot = match stage { - crate::ShaderStage::Vertex => self.per_stage_map.vs.push_constant_buffer, - crate::ShaderStage::Fragment => self.per_stage_map.fs.push_constant_buffer, - crate::ShaderStage::Compute => self.per_stage_map.cs.push_constant_buffer, - }; + let slot = self + .get_entry_point_resources(ep) + .and_then(|res| res.push_constant_buffer); match slot { Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { buffer: Some(slot), @@ -340,9 +330,11 @@ impl Options { fn resolve_sizes_buffer( &self, - stage: crate::ShaderStage, + ep: &crate::EntryPoint, ) -> Result { - let slot = self.per_stage_map[stage].sizes_buffer; + let slot = self + .get_entry_point_resources(ep) + .and_then(|res| res.sizes_buffer); match slot { Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { buffer: Some(slot), diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index d424c11b20..f53879f5bb 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -3368,7 +3368,8 @@ impl Writer { break; } }; - let good = match options.per_stage_map[ep.stage].resources.get(br) { + let target = options.get_resource_binding_target(ep, br); + let good = match target { Some(target) => { let binding_ty = match module.types[var.ty].inner { crate::TypeInner::BindingArray { base, .. } => { @@ -3393,7 +3394,7 @@ impl Writer { } } crate::AddressSpace::PushConstant => { - if let Err(e) = options.resolve_push_constants(ep.stage) { + if let Err(e) = options.resolve_push_constants(ep) { ep_error = Some(e); break; } @@ -3404,7 +3405,7 @@ impl Writer { } } if supports_array_length { - if let Err(err) = options.resolve_sizes_buffer(ep.stage) { + if let Err(err) = options.resolve_sizes_buffer(ep) { ep_error = Some(err); } } @@ -3673,15 +3674,13 @@ impl Writer { } // the resolves have already been checked for `!fake_missing_bindings` case let resolved = match var.space { - crate::AddressSpace::PushConstant => { - options.resolve_push_constants(ep.stage).ok() - } + crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(), crate::AddressSpace::WorkGroup => None, crate::AddressSpace::Storage { .. } if options.lang_version < (2, 0) => { return Err(Error::UnsupportedAddressSpace(var.space)) } _ => options - .resolve_resource_binding(ep.stage, var.binding.as_ref().unwrap()) + .resolve_resource_binding(ep, var.binding.as_ref().unwrap()) .ok(), }; if let Some(ref resolved) = resolved { @@ -3726,7 +3725,7 @@ impl Writer { // passed as a final struct-typed argument. if supports_array_length { // this is checked earlier - let resolved = options.resolve_sizes_buffer(ep.stage).unwrap(); + let resolved = options.resolve_sizes_buffer(ep).unwrap(); let separator = if module.global_variables.is_empty() { ' ' } else { @@ -3786,7 +3785,7 @@ impl Writer { }; } else if let Some(ref binding) = var.binding { // write an inline sampler - let resolved = options.resolve_resource_binding(ep.stage, binding).unwrap(); + let resolved = options.resolve_resource_binding(ep, binding).unwrap(); if let Some(sampler) = resolved.as_inline_sampler(options) { let name = &self.names[&NameKey::GlobalVariable(handle)]; writeln!( diff --git a/tests/in/access.param.ron b/tests/in/access.param.ron index af45e4f970..5cd8e79a48 100644 --- a/tests/in/access.param.ron +++ b/tests/in/access.param.ron @@ -6,8 +6,8 @@ ), msl: ( lang_version: (2, 0), - per_stage_map: ( - vs: ( + per_entry_point_map: { + "foo_vert": ( resources: { (group: 0, binding: 0): (buffer: Some(0), mutable: false), (group: 0, binding: 1): (buffer: Some(1), mutable: false), @@ -16,20 +16,20 @@ }, sizes_buffer: Some(24), ), - fs: ( + "foo_frag": ( resources: { (group: 0, binding: 0): (buffer: Some(0), mutable: true), (group: 0, binding: 2): (buffer: Some(2), mutable: true), }, sizes_buffer: Some(24), ), - cs: ( + "atomics": ( resources: { (group: 0, binding: 0): (buffer: Some(0), mutable: true), }, sizes_buffer: Some(24), ), - ), + }, inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, diff --git a/tests/in/binding-arrays.param.ron b/tests/in/binding-arrays.param.ron index ad5b0ee319..1f2ac686b8 100644 --- a/tests/in/binding-arrays.param.ron +++ b/tests/in/binding-arrays.param.ron @@ -19,14 +19,14 @@ ), msl: ( lang_version: (2, 0), - per_stage_map: ( - fs: ( + per_entry_point_map: { + "main": ( resources: { (group: 0, binding: 0): (texture: Some(0), binding_array_size: Some(10), mutable: false), }, sizes_buffer: None, ) - ), + }, inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: true, diff --git a/tests/in/bitcast.params.ron b/tests/in/bitcast.params.ron index 324cd4a518..b40cf9fa08 100644 --- a/tests/in/bitcast.params.ron +++ b/tests/in/bitcast.params.ron @@ -1,13 +1,13 @@ ( msl: ( lang_version: (1, 2), - per_stage_map: ( - cs: ( + per_entry_point_map: { + "main": ( resources: { }, sizes_buffer: Some(0), ) - ), + }, inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, diff --git a/tests/in/bits.param.ron b/tests/in/bits.param.ron index 324cd4a518..b40cf9fa08 100644 --- a/tests/in/bits.param.ron +++ b/tests/in/bits.param.ron @@ -1,13 +1,13 @@ ( msl: ( lang_version: (1, 2), - per_stage_map: ( - cs: ( + per_entry_point_map: { + "main": ( resources: { }, sizes_buffer: Some(0), ) - ), + }, inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, diff --git a/tests/in/boids.param.ron b/tests/in/boids.param.ron index e6d752fca8..976b457ede 100644 --- a/tests/in/boids.param.ron +++ b/tests/in/boids.param.ron @@ -6,8 +6,8 @@ ), msl: ( lang_version: (2, 0), - per_stage_map: ( - cs: ( + per_entry_point_map: { + "main": ( resources: { (group: 0, binding: 0): (buffer: Some(0), mutable: false), (group: 0, binding: 1): (buffer: Some(1), mutable: true), @@ -15,7 +15,7 @@ }, sizes_buffer: Some(3), ) - ), + }, inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, diff --git a/tests/in/extra.param.ron b/tests/in/extra.param.ron index 051dee6432..581c6d6bdb 100644 --- a/tests/in/extra.param.ron +++ b/tests/in/extra.param.ron @@ -5,11 +5,11 @@ ), msl: ( lang_version: (2, 2), - per_stage_map: ( - fs: ( + per_entry_point_map: { + "main": ( push_constant_buffer: Some(1), ), - ), + }, inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, diff --git a/tests/in/interface.param.ron b/tests/in/interface.param.ron index 5df3e3b9c3..1c910d2b86 100644 --- a/tests/in/interface.param.ron +++ b/tests/in/interface.param.ron @@ -19,7 +19,7 @@ ), msl: ( lang_version: (2, 1), - per_stage_map: (), + per_entry_point_map: {}, inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, diff --git a/tests/in/padding.param.ron b/tests/in/padding.param.ron index 41af2916a9..14859bba3e 100644 --- a/tests/in/padding.param.ron +++ b/tests/in/padding.param.ron @@ -6,15 +6,15 @@ ), msl: ( lang_version: (2, 0), - per_stage_map: ( - vs: ( + per_entry_point_map: { + "vertex": ( resources: { (group: 0, binding: 0): (buffer: Some(0), mutable: false), (group: 0, binding: 1): (buffer: Some(1), mutable: false), (group: 0, binding: 2): (buffer: Some(2), mutable: false), }, ), - ), + }, inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, diff --git a/tests/in/resource-binding-map.param.ron b/tests/in/resource-binding-map.param.ron new file mode 100644 index 0000000000..6b0b43c2a7 --- /dev/null +++ b/tests/in/resource-binding-map.param.ron @@ -0,0 +1,53 @@ +( + god_mode: true, + msl: ( + lang_version: (2, 0), + per_entry_point_map: { + "entry_point_one": ( + resources: { + (group: 0, binding: 0): (texture: Some(0)), + (group: 0, binding: 2): (sampler: Some(Inline(0))), + (group: 0, binding: 4): (buffer: Some(0)), + } + ), + "entry_point_two": ( + resources: { + (group: 0, binding: 0): (texture: Some(0)), + (group: 0, binding: 2): (sampler: Some(Resource(0))), + (group: 0, binding: 4): (buffer: Some(0)), + } + ), + "entry_point_three": ( + resources: { + (group: 0, binding: 0): (texture: Some(0)), + (group: 0, binding: 1): (texture: Some(1)), + (group: 0, binding: 2): (sampler: Some(Inline(0))), + (group: 0, binding: 3): (sampler: Some(Resource(1))), + (group: 0, binding: 4): (buffer: Some(0)), + (group: 1, binding: 0): (buffer: Some(1)), + } + ) + }, + inline_samplers: [ + ( + coord: Normalized, + address: (ClampToEdge, ClampToEdge, ClampToEdge), + mag_filter: Linear, + min_filter: Linear, + mip_filter: None, + border_color: TransparentBlack, + compare_func: Never, + lod_clamp: Some((start: 0.5, end: 10.0)), + max_anisotropy: Some(8), + ), + ], + spirv_cross_compatibility: false, + fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, + ), + bounds_check_policies: ( + index: ReadZeroSkipWrite, + buffer: ReadZeroSkipWrite, + image: ReadZeroSkipWrite, + ) +) diff --git a/tests/in/resource-binding-map.wgsl b/tests/in/resource-binding-map.wgsl new file mode 100644 index 0000000000..fa5fce4ee1 --- /dev/null +++ b/tests/in/resource-binding-map.wgsl @@ -0,0 +1,23 @@ +@group(0) @binding(0) var t1: texture_2d; +@group(0) @binding(1) var t2: texture_2d; +@group(0) @binding(2) var s1: sampler; +@group(0) @binding(3) var s2: sampler; + +@group(0) @binding(4) var uniformOne: vec2; +@group(1) @binding(0) var uniformTwo: vec2; + +@fragment +fn entry_point_one(@builtin(position) pos: vec4) -> @location(0) vec4 { + return textureSample(t1, s1, pos.xy); +} + +@fragment +fn entry_point_two() -> @location(0) vec4 { + return textureSample(t1, s1, uniformOne); +} + +@fragment +fn entry_point_three() -> @location(0) vec4 { + return textureSample(t1, s1, uniformTwo + uniformOne) + + textureSample(t2, s2, uniformOne); +} diff --git a/tests/in/skybox.param.ron b/tests/in/skybox.param.ron index 905721c914..6bca28fc66 100644 --- a/tests/in/skybox.param.ron +++ b/tests/in/skybox.param.ron @@ -7,19 +7,19 @@ ), msl: ( lang_version: (2, 1), - per_stage_map: ( - vs: ( + per_entry_point_map: { + "vs_main": ( resources: { (group: 0, binding: 0): (buffer: Some(0)), }, ), - fs: ( + "fs_main": ( resources: { (group: 0, binding: 1): (texture: Some(0)), (group: 0, binding: 2): (sampler: Some(Inline(0))), }, ), - ), + }, inline_samplers: [ ( coord: Normalized, diff --git a/tests/in/workgroup-var-init.param.ron b/tests/in/workgroup-var-init.param.ron index fd10e95d60..be5302284b 100644 --- a/tests/in/workgroup-var-init.param.ron +++ b/tests/in/workgroup-var-init.param.ron @@ -6,17 +6,17 @@ ), msl: ( lang_version: (2, 0), - per_stage_map: ( - cs: ( + per_entry_point_map: { + "main": ( resources: { (group: 0, binding: 0): (buffer: Some(0), mutable: true), }, sizes_buffer: None, ), - ), + }, inline_samplers: [], spirv_cross_compatibility: false, fake_missing_bindings: false, zero_initialize_workgroup_memory: true, ), -) \ No newline at end of file +) diff --git a/tests/out/msl/resource-binding-map.msl b/tests/out/msl/resource-binding-map.msl new file mode 100644 index 0000000000..4e0b601320 --- /dev/null +++ b/tests/out/msl/resource-binding-map.msl @@ -0,0 +1,74 @@ +// language: metal2.0 +#include +#include + +using metal::uint; + +struct DefaultConstructible { + template + operator T() && { + return T {}; + } +}; + +struct entry_point_oneInput { +}; +struct entry_point_oneOutput { + metal::float4 member [[color(0)]]; +}; +fragment entry_point_oneOutput entry_point_one( + metal::float4 pos [[position]] +, metal::texture2d t1_ [[texture(0)]] +) { + constexpr metal::sampler s1_( + metal::s_address::clamp_to_edge, + metal::t_address::clamp_to_edge, + metal::r_address::clamp_to_edge, + metal::mag_filter::linear, + metal::min_filter::linear, + metal::coord::normalized + ); + metal::float4 _e4 = t1_.sample(s1_, pos.xy); + return entry_point_oneOutput { _e4 }; +} + + +struct entry_point_twoOutput { + metal::float4 member_1 [[color(0)]]; +}; +fragment entry_point_twoOutput entry_point_two( + metal::texture2d t1_ [[texture(0)]] +, metal::sampler s1_ [[sampler(0)]] +, constant metal::float2& uniformOne [[buffer(0)]] +) { + metal::float2 _e3 = uniformOne; + metal::float4 _e4 = t1_.sample(s1_, _e3); + return entry_point_twoOutput { _e4 }; +} + + +struct entry_point_threeOutput { + metal::float4 member_2 [[color(0)]]; +}; +fragment entry_point_threeOutput entry_point_three( + metal::texture2d t1_ [[texture(0)]] +, metal::texture2d t2_ [[texture(1)]] +, metal::sampler s2_ [[sampler(1)]] +, constant metal::float2& uniformOne [[buffer(0)]] +, constant metal::float2& uniformTwo [[buffer(1)]] +) { + constexpr metal::sampler s1_( + metal::s_address::clamp_to_edge, + metal::t_address::clamp_to_edge, + metal::r_address::clamp_to_edge, + metal::mag_filter::linear, + metal::min_filter::linear, + metal::coord::normalized + ); + metal::float2 _e3 = uniformTwo; + metal::float2 _e5 = uniformOne; + metal::float4 _e7 = t1_.sample(s1_, _e3 + _e5); + metal::float2 _e11 = uniformOne; + metal::float4 _e12 = t2_.sample(s2_, _e11); + return entry_point_threeOutput { _e7 + _e12 }; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 25f307c110..ed4a9bddec 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -90,8 +90,11 @@ struct Parameters { #[allow(unused_variables)] fn check_targets(module: &naga::Module, name: &str, targets: Targets) { let root = env!("CARGO_MANIFEST_DIR"); - let params = match fs::read_to_string(format!("{root}/{BASE_DIR_IN}/{name}.param.ron")) { - Ok(string) => ron::de::from_str(&string).expect("Couldn't parse param file"), + let filepath = format!("{root}/{BASE_DIR_IN}/{name}.param.ron"); + let params = match fs::read_to_string(&filepath) { + Ok(string) => { + ron::de::from_str(&string).expect(&format!("Couldn't parse param file: {}", filepath)) + } Err(_) => Parameters::default(), }; @@ -543,6 +546,7 @@ fn convert_wgsl() { "binding-arrays", Targets::WGSL | Targets::HLSL | Targets::METAL | Targets::SPIRV, ), + ("resource-binding-map", Targets::METAL), ("multiview", Targets::SPIRV | Targets::GLSL | Targets::WGSL), ("multiview_webgl", Targets::GLSL), (