Skip to content

Commit f83ffed

Browse files
committed
[naga] Ensure that FooResult expressions are correctly populated.
Make Naga module validation require that `CallResult` and `AtomicResult` expressions are indeed visited by exactly one `Call` / `Atomic` statement.
1 parent badcaee commit f83ffed

File tree

3 files changed

+145
-55
lines changed

3 files changed

+145
-55
lines changed

naga/src/valid/function.rs

+39-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ pub enum CallError {
2222
},
2323
#[error("Result expression {0:?} has already been introduced earlier")]
2424
ResultAlreadyInScope(Handle<crate::Expression>),
25+
#[error("Result expression {0:?} is populated by multiple `Call` statements")]
26+
ResultAlreadyPopulated(Handle<crate::Expression>),
2527
#[error("Result value is invalid")]
2628
ResultValue(#[source] ExpressionError),
2729
#[error("Requires {required} arguments, but {seen} are provided")]
@@ -45,6 +47,8 @@ pub enum AtomicError {
4547
InvalidOperand(Handle<crate::Expression>),
4648
#[error("Result type for {0:?} doesn't match the statement")]
4749
ResultTypeMismatch(Handle<crate::Expression>),
50+
#[error("Result expression {0:?} is populated by multiple `Atomic` statements")]
51+
ResultAlreadyPopulated(Handle<crate::Expression>),
4852
}
4953

5054
#[derive(Clone, Debug, thiserror::Error)]
@@ -174,6 +178,8 @@ pub enum FunctionError {
174178
InvalidSubgroup(#[from] SubgroupError),
175179
#[error("Emit statement should not cover \"result\" expressions like {0:?}")]
176180
EmitResult(Handle<crate::Expression>),
181+
#[error("Expression not visited by the appropriate statement")]
182+
UnvisitedExpression(Handle<crate::Expression>),
177183
}
178184

179185
bitflags::bitflags! {
@@ -305,7 +311,13 @@ impl super::Validator {
305311
}
306312
match context.expressions[expr] {
307313
crate::Expression::CallResult(callee)
308-
if fun.result.is_some() && callee == function => {}
314+
if fun.result.is_some() && callee == function =>
315+
{
316+
if !self.needs_visit.remove(expr.index()) {
317+
return Err(CallError::ResultAlreadyPopulated(expr)
318+
.with_span_handle(expr, context.expressions));
319+
}
320+
}
309321
_ => {
310322
return Err(CallError::ExpressionMismatch(result)
311323
.with_span_handle(expr, context.expressions))
@@ -397,7 +409,14 @@ impl super::Validator {
397409
}
398410
_ => false,
399411
}
400-
} => {}
412+
} =>
413+
{
414+
if !self.needs_visit.remove(result.index()) {
415+
return Err(AtomicError::ResultAlreadyPopulated(result)
416+
.with_span_handle(result, context.expressions)
417+
.into_other());
418+
}
419+
}
401420
_ => {
402421
return Err(AtomicError::ResultTypeMismatch(result)
403422
.with_span_handle(result, context.expressions)
@@ -1290,11 +1309,20 @@ impl super::Validator {
12901309

12911310
self.valid_expression_set.clear();
12921311
self.valid_expression_list.clear();
1312+
self.needs_visit.clear();
12931313
for (handle, expr) in fun.expressions.iter() {
12941314
if expr.needs_pre_emit() {
12951315
self.valid_expression_set.insert(handle.index());
12961316
}
12971317
if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1318+
// Mark expressions that need to be visited by a particular kind of
1319+
// statement.
1320+
if let crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } =
1321+
*expr
1322+
{
1323+
self.needs_visit.insert(handle.index());
1324+
}
1325+
12981326
match self.validate_expression(
12991327
handle,
13001328
expr,
@@ -1321,6 +1349,15 @@ impl super::Validator {
13211349
)?
13221350
.stages;
13231351
info.available_stages &= stages;
1352+
1353+
if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1354+
if let Some(unvisited) = self.needs_visit.iter().next() {
1355+
let index = std::num::NonZeroU32::new(unvisited as u32 + 1).unwrap();
1356+
let handle = Handle::new(index);
1357+
return Err(FunctionError::UnvisitedExpression(handle)
1358+
.with_span_handle(handle, &fun.expressions));
1359+
}
1360+
}
13241361
}
13251362
Ok(info)
13261363
}

naga/src/valid/mod.rs

+21
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,26 @@ pub struct Validator {
246246
valid_expression_set: BitSet,
247247
override_ids: FastHashSet<u16>,
248248
allow_overrides: bool,
249+
250+
/// A checklist of expressions that must be visited by a specific kind of
251+
/// statement.
252+
///
253+
/// For example:
254+
///
255+
/// - [`CallResult`] expressions must be visited by a [`Call`] statement.
256+
/// - [`AtomicResult`] expressions must be visited by an [`Atomic`] statement.
257+
///
258+
/// Be sure not to remove any [`Expression`] handle from this set unless
259+
/// you've explicitly checked that it is the right kind of expression for
260+
/// the visiting [`Statement`].
261+
///
262+
/// [`CallResult`]: crate::Expression::CallResult
263+
/// [`Call`]: crate::Statement::Call
264+
/// [`AtomicResult`]: crate::Expression::AtomicResult
265+
/// [`Atomic`]: crate::Statement::Atomic
266+
/// [`Expression`]: crate::Expression
267+
/// [`Statement`]: crate::Statement
268+
needs_visit: BitSet,
249269
}
250270

251271
#[derive(Clone, Debug, thiserror::Error)]
@@ -398,6 +418,7 @@ impl Validator {
398418
valid_expression_set: BitSet::new(),
399419
override_ids: FastHashSet::default(),
400420
allow_overrides: true,
421+
needs_visit: BitSet::new(),
401422
}
402423
}
403424

naga/tests/validation.rs

+85-53
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,30 @@
11
use naga::{valid, Expression, Function, Scalar};
22

3+
/// Validation should fail if `AtomicResult` expressions are not
4+
/// populated by `Atomic` statements.
35
#[test]
4-
fn emit_atomic_result() {
6+
fn populate_atomic_result() {
57
use naga::{Module, Type, TypeInner};
68

7-
// We want to ensure that the *only* problem with the code is the
8-
// use of an `Emit` statement instead of an `Atomic` statement. So
9-
// validate two versions of the module varying only in that
10-
// aspect.
11-
//
12-
// Looking at uses of the `atomic` makes it easy to identify the
13-
// differences between the two variants.
14-
fn variant(
15-
atomic: bool,
9+
/// Different variants of the test case that we want to exercise.
10+
enum Variant {
11+
/// An `AtomicResult` expression with an `Atomic` statement
12+
/// that populates it: valid.
13+
Atomic,
14+
15+
/// An `AtomicResult` expression visited by an `Emit`
16+
/// statement: invalid.
17+
Emit,
18+
19+
/// An `AtomicResult` expression visited by no statement at
20+
/// all: invalid
21+
None,
22+
}
23+
24+
// Looking at uses of `variant` should make it easy to identify
25+
// the differences between the test cases.
26+
fn try_variant(
27+
variant: Variant,
1628
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
1729
let span = naga::Span::default();
1830
let mut module = Module::default();
@@ -56,21 +68,25 @@ fn emit_atomic_result() {
5668
span,
5769
);
5870

59-
if atomic {
60-
fun.body.push(
61-
naga::Statement::Atomic {
62-
pointer: ex_global,
63-
fun: naga::AtomicFunction::Add,
64-
value: ex_42,
65-
result: ex_result,
66-
},
67-
span,
68-
);
69-
} else {
70-
fun.body.push(
71-
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
72-
span,
73-
);
71+
match variant {
72+
Variant::Atomic => {
73+
fun.body.push(
74+
naga::Statement::Atomic {
75+
pointer: ex_global,
76+
fun: naga::AtomicFunction::Add,
77+
value: ex_42,
78+
result: ex_result,
79+
},
80+
span,
81+
);
82+
}
83+
Variant::Emit => {
84+
fun.body.push(
85+
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
86+
span,
87+
);
88+
}
89+
Variant::None => {}
7490
}
7591

7692
module.functions.append(fun, span);
@@ -82,23 +98,34 @@ fn emit_atomic_result() {
8298
.validate(&module)
8399
}
84100

85-
variant(true).expect("module should validate");
86-
assert!(variant(false).is_err());
101+
try_variant(Variant::Atomic).expect("module should validate");
102+
assert!(try_variant(Variant::Emit).is_err());
103+
assert!(try_variant(Variant::None).is_err());
87104
}
88105

89106
#[test]
90-
fn emit_call_result() {
107+
fn populate_call_result() {
91108
use naga::{Module, Type, TypeInner};
92109

93-
// We want to ensure that the *only* problem with the code is the
94-
// use of an `Emit` statement instead of a `Call` statement. So
95-
// validate two versions of the module varying only in that
96-
// aspect.
97-
//
98-
// Looking at uses of the `call` makes it easy to identify the
99-
// differences between the two variants.
100-
fn variant(
101-
call: bool,
110+
/// Different variants of the test case that we want to exercise.
111+
enum Variant {
112+
/// A `CallResult` expression with an `Call` statement that
113+
/// populates it: valid.
114+
Call,
115+
116+
/// A `CallResult` expression visited by an `Emit` statement:
117+
/// invalid.
118+
Emit,
119+
120+
/// A `CallResult` expression visited by no statement at all:
121+
/// invalid
122+
None,
123+
}
124+
125+
// Looking at uses of `variant` should make it easy to identify
126+
// the differences between the test cases.
127+
fn try_variant(
128+
variant: Variant,
102129
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
103130
let span = naga::Span::default();
104131
let mut module = Module::default();
@@ -130,20 +157,24 @@ fn emit_call_result() {
130157
.expressions
131158
.append(Expression::CallResult(fun_callee), span);
132159

133-
if call {
134-
fun_caller.body.push(
135-
naga::Statement::Call {
136-
function: fun_callee,
137-
arguments: vec![],
138-
result: Some(ex_result),
139-
},
140-
span,
141-
);
142-
} else {
143-
fun_caller.body.push(
144-
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
145-
span,
146-
);
160+
match variant {
161+
Variant::Call => {
162+
fun_caller.body.push(
163+
naga::Statement::Call {
164+
function: fun_callee,
165+
arguments: vec![],
166+
result: Some(ex_result),
167+
},
168+
span,
169+
);
170+
}
171+
Variant::Emit => {
172+
fun_caller.body.push(
173+
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
174+
span,
175+
);
176+
}
177+
Variant::None => {}
147178
}
148179

149180
module.functions.append(fun_caller, span);
@@ -155,8 +186,9 @@ fn emit_call_result() {
155186
.validate(&module)
156187
}
157188

158-
variant(true).expect("should validate");
159-
assert!(variant(false).is_err());
189+
try_variant(Variant::Call).expect("should validate");
190+
assert!(try_variant(Variant::Emit).is_err());
191+
assert!(try_variant(Variant::None).is_err());
160192
}
161193

162194
#[test]

0 commit comments

Comments
 (0)