Skip to content

Commit b9a6730

Browse files
committed
[naga] Ensure that FooResult expressions are correctly populated.
Make Naga module validation require that `CallResult` and `AtomicResult` expressions are indeed visited by exactly one `Call` / `Atomic` statement.
1 parent c745863 commit b9a6730

File tree

3 files changed

+145
-55
lines changed

3 files changed

+145
-55
lines changed

naga/src/valid/function.rs

+39-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ pub enum CallError {
2222
},
2323
#[error("Result expression {0:?} has already been introduced earlier")]
2424
ResultAlreadyInScope(Handle<crate::Expression>),
25+
#[error("Result expression {0:?} is populated by multiple `Call` statements")]
26+
ResultAlreadyPopulated(Handle<crate::Expression>),
2527
#[error("Result value is invalid")]
2628
ResultValue(#[source] ExpressionError),
2729
#[error("Requires {required} arguments, but {seen} are provided")]
@@ -45,6 +47,8 @@ pub enum AtomicError {
4547
InvalidOperand(Handle<crate::Expression>),
4648
#[error("Result type for {0:?} doesn't match the statement")]
4749
ResultTypeMismatch(Handle<crate::Expression>),
50+
#[error("Result expression {0:?} is populated by multiple `Atomic` statements")]
51+
ResultAlreadyPopulated(Handle<crate::Expression>),
4852
}
4953

5054
#[derive(Clone, Debug, thiserror::Error)]
@@ -174,6 +178,8 @@ pub enum FunctionError {
174178
InvalidSubgroup(#[from] SubgroupError),
175179
#[error("Emit statement should not cover \"result\" expressions like {0:?}")]
176180
EmitResult(Handle<crate::Expression>),
181+
#[error("Expression not visited by the appropriate statement")]
182+
UnvisitedExpression(Handle<crate::Expression>),
177183
}
178184

179185
bitflags::bitflags! {
@@ -317,7 +323,13 @@ impl super::Validator {
317323
}
318324
match context.expressions[expr] {
319325
crate::Expression::CallResult(callee)
320-
if fun.result.is_some() && callee == function => {}
326+
if fun.result.is_some() && callee == function =>
327+
{
328+
if !self.needs_visit.remove(expr.index()) {
329+
return Err(CallError::ResultAlreadyPopulated(expr)
330+
.with_span_handle(expr, context.expressions));
331+
}
332+
}
321333
_ => {
322334
return Err(CallError::ExpressionMismatch(result)
323335
.with_span_handle(expr, context.expressions))
@@ -409,7 +421,14 @@ impl super::Validator {
409421
}
410422
_ => false,
411423
}
412-
} => {}
424+
} =>
425+
{
426+
if !self.needs_visit.remove(result.index()) {
427+
return Err(AtomicError::ResultAlreadyPopulated(result)
428+
.with_span_handle(result, context.expressions)
429+
.into_other());
430+
}
431+
}
413432
_ => {
414433
return Err(AtomicError::ResultTypeMismatch(result)
415434
.with_span_handle(result, context.expressions)
@@ -1307,11 +1326,20 @@ impl super::Validator {
13071326

13081327
self.valid_expression_set.clear();
13091328
self.valid_expression_list.clear();
1329+
self.needs_visit.clear();
13101330
for (handle, expr) in fun.expressions.iter() {
13111331
if expr.needs_pre_emit() {
13121332
self.valid_expression_set.insert(handle.index());
13131333
}
13141334
if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1335+
// Mark expressions that need to be visited by a particular kind of
1336+
// statement.
1337+
if let crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } =
1338+
*expr
1339+
{
1340+
self.needs_visit.insert(handle.index());
1341+
}
1342+
13151343
match self.validate_expression(
13161344
handle,
13171345
expr,
@@ -1338,6 +1366,15 @@ impl super::Validator {
13381366
)?
13391367
.stages;
13401368
info.available_stages &= stages;
1369+
1370+
if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1371+
if let Some(unvisited) = self.needs_visit.iter().next() {
1372+
let index = std::num::NonZeroU32::new(unvisited as u32 + 1).unwrap();
1373+
let handle = Handle::new(index);
1374+
return Err(FunctionError::UnvisitedExpression(handle)
1375+
.with_span_handle(handle, &fun.expressions));
1376+
}
1377+
}
13411378
}
13421379
Ok(info)
13431380
}

naga/src/valid/mod.rs

+21
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,26 @@ pub struct Validator {
233233
valid_expression_set: BitSet,
234234
override_ids: FastHashSet<u16>,
235235
allow_overrides: bool,
236+
237+
/// A checklist of expressions that must be visited by a specific kind of
238+
/// statement.
239+
///
240+
/// For example:
241+
///
242+
/// - [`CallResult`] expressions must be visited by a [`Call`] statement.
243+
/// - [`AtomicResult`] expressions must be visited by an [`Atomic`] statement.
244+
///
245+
/// Be sure not to remove any [`Expression`] handle from this set unless
246+
/// you've explicitly checked that it is the right kind of expression for
247+
/// the visiting [`Statement`].
248+
///
249+
/// [`CallResult`]: crate::Expression::CallResult
250+
/// [`Call`]: crate::Statement::Call
251+
/// [`AtomicResult`]: crate::Expression::AtomicResult
252+
/// [`Atomic`]: crate::Statement::Atomic
253+
/// [`Expression`]: crate::Expression
254+
/// [`Statement`]: crate::Statement
255+
needs_visit: BitSet,
236256
}
237257

238258
#[derive(Clone, Debug, thiserror::Error)]
@@ -385,6 +405,7 @@ impl Validator {
385405
valid_expression_set: BitSet::new(),
386406
override_ids: FastHashSet::default(),
387407
allow_overrides: true,
408+
needs_visit: BitSet::new(),
388409
}
389410
}
390411

naga/tests/validation.rs

+85-53
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,30 @@
11
use naga::{valid, Expression, Function, Scalar};
22

3+
/// Validation should fail if `AtomicResult` expressions are not
4+
/// populated by `Atomic` statements.
35
#[test]
4-
fn emit_atomic_result() {
6+
fn populate_atomic_result() {
57
use naga::{Module, Type, TypeInner};
68

7-
// We want to ensure that the *only* problem with the code is the
8-
// use of an `Emit` statement instead of an `Atomic` statement. So
9-
// validate two versions of the module varying only in that
10-
// aspect.
11-
//
12-
// Looking at uses of the `atomic` makes it easy to identify the
13-
// differences between the two variants.
14-
fn variant(
15-
atomic: bool,
9+
/// Different variants of the test case that we want to exercise.
10+
enum Variant {
11+
/// An `AtomicResult` expression with an `Atomic` statement
12+
/// that populates it: valid.
13+
Atomic,
14+
15+
/// An `AtomicResult` expression visited by an `Emit`
16+
/// statement: invalid.
17+
Emit,
18+
19+
/// An `AtomicResult` expression visited by no statement at
20+
/// all: invalid
21+
None,
22+
}
23+
24+
// Looking at uses of `variant` should make it easy to identify
25+
// the differences between the test cases.
26+
fn try_variant(
27+
variant: Variant,
1628
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
1729
let span = naga::Span::default();
1830
let mut module = Module::default();
@@ -56,21 +68,25 @@ fn emit_atomic_result() {
5668
span,
5769
);
5870

59-
if atomic {
60-
fun.body.push(
61-
naga::Statement::Atomic {
62-
pointer: ex_global,
63-
fun: naga::AtomicFunction::Add,
64-
value: ex_42,
65-
result: ex_result,
66-
},
67-
span,
68-
);
69-
} else {
70-
fun.body.push(
71-
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
72-
span,
73-
);
71+
match variant {
72+
Variant::Atomic => {
73+
fun.body.push(
74+
naga::Statement::Atomic {
75+
pointer: ex_global,
76+
fun: naga::AtomicFunction::Add,
77+
value: ex_42,
78+
result: ex_result,
79+
},
80+
span,
81+
);
82+
}
83+
Variant::Emit => {
84+
fun.body.push(
85+
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
86+
span,
87+
);
88+
}
89+
Variant::None => {}
7490
}
7591

7692
module.functions.append(fun, span);
@@ -82,23 +98,34 @@ fn emit_atomic_result() {
8298
.validate(&module)
8399
}
84100

85-
variant(true).expect("module should validate");
86-
assert!(variant(false).is_err());
101+
try_variant(Variant::Atomic).expect("module should validate");
102+
assert!(try_variant(Variant::Emit).is_err());
103+
assert!(try_variant(Variant::None).is_err());
87104
}
88105

89106
#[test]
90-
fn emit_call_result() {
107+
fn populate_call_result() {
91108
use naga::{Module, Type, TypeInner};
92109

93-
// We want to ensure that the *only* problem with the code is the
94-
// use of an `Emit` statement instead of a `Call` statement. So
95-
// validate two versions of the module varying only in that
96-
// aspect.
97-
//
98-
// Looking at uses of the `call` makes it easy to identify the
99-
// differences between the two variants.
100-
fn variant(
101-
call: bool,
110+
/// Different variants of the test case that we want to exercise.
111+
enum Variant {
112+
/// A `CallResult` expression with an `Call` statement that
113+
/// populates it: valid.
114+
Call,
115+
116+
/// A `CallResult` expression visited by an `Emit` statement:
117+
/// invalid.
118+
Emit,
119+
120+
/// A `CallResult` expression visited by no statement at all:
121+
/// invalid
122+
None,
123+
}
124+
125+
// Looking at uses of `variant` should make it easy to identify
126+
// the differences between the test cases.
127+
fn try_variant(
128+
variant: Variant,
102129
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
103130
let span = naga::Span::default();
104131
let mut module = Module::default();
@@ -130,20 +157,24 @@ fn emit_call_result() {
130157
.expressions
131158
.append(Expression::CallResult(fun_callee), span);
132159

133-
if call {
134-
fun_caller.body.push(
135-
naga::Statement::Call {
136-
function: fun_callee,
137-
arguments: vec![],
138-
result: Some(ex_result),
139-
},
140-
span,
141-
);
142-
} else {
143-
fun_caller.body.push(
144-
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
145-
span,
146-
);
160+
match variant {
161+
Variant::Call => {
162+
fun_caller.body.push(
163+
naga::Statement::Call {
164+
function: fun_callee,
165+
arguments: vec![],
166+
result: Some(ex_result),
167+
},
168+
span,
169+
);
170+
}
171+
Variant::Emit => {
172+
fun_caller.body.push(
173+
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
174+
span,
175+
);
176+
}
177+
Variant::None => {}
147178
}
148179

149180
module.functions.append(fun_caller, span);
@@ -155,8 +186,9 @@ fn emit_call_result() {
155186
.validate(&module)
156187
}
157188

158-
variant(true).expect("should validate");
159-
assert!(variant(false).is_err());
189+
try_variant(Variant::Call).expect("should validate");
190+
assert!(try_variant(Variant::Emit).is_err());
191+
assert!(try_variant(Variant::None).is_err());
160192
}
161193

162194
#[test]

0 commit comments

Comments
 (0)