Skip to content

Commit 12b2ed4

Browse files
authored
fix(forge): support preproc with try contract creation (#10498)
* fix(forge): support preproc with try contract creation * visit nested vars and statements of try stmt
1 parent e0ad278 commit 12b2ed4

File tree

2 files changed

+273
-40
lines changed

2 files changed

+273
-40
lines changed

crates/common/src/preprocessor/deps.rs

Lines changed: 96 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use foundry_compilers::Updates;
66
use itertools::Itertools;
77
use solar_parse::interface::Session;
88
use solar_sema::{
9-
hir::{ContractId, Expr, ExprKind, Hir, NamedArg, TypeKind, Visit},
9+
hir::{CallArgs, ContractId, Expr, ExprKind, Hir, NamedArg, Stmt, StmtKind, TypeKind, Visit},
1010
interface::{data_structures::Never, source_map::FileName, SourceMap},
1111
};
1212
use std::{
@@ -105,6 +105,8 @@ enum BytecodeDependencyKind {
105105
value: Option<String>,
106106
/// `salt` (if any) used when creating contract.
107107
salt: Option<String>,
108+
/// Whether it's a try contract creation statement.
109+
try_stmt: bool,
108110
},
109111
}
110112

@@ -182,42 +184,17 @@ impl<'hir> Visit<'hir> for BytecodeDependencyCollector<'hir> {
182184

183185
fn visit_expr(&mut self, expr: &'hir Expr<'hir>) -> ControlFlow<Self::BreakValue> {
184186
match &expr.kind {
185-
ExprKind::Call(ty, call_args, named_args) => {
186-
if let ExprKind::New(ty_new) = &ty.kind {
187-
if let TypeKind::Custom(item_id) = ty_new.kind {
188-
if let Some(contract_id) = item_id.as_contract() {
189-
let name_loc = span_to_range(self.source_map, ty_new.span);
190-
let name = &self.src[name_loc];
191-
192-
// Calculate offset to remove named args, e.g. for an expression like
193-
// `new Counter {value: 333} ( address(this))`
194-
// the offset will be used to replace `{value: 333} ( ` with `(`
195-
let call_args_offset = if named_args.is_some() && !call_args.is_empty()
196-
{
197-
(call_args.span().lo() - ty_new.span.hi()).to_usize()
198-
} else {
199-
0
200-
};
201-
202-
let args_len = expr.span.hi() - ty_new.span.hi();
203-
self.collect_dependency(BytecodeDependency {
204-
kind: BytecodeDependencyKind::New {
205-
name: name.to_string(),
206-
args_length: args_len.to_usize(),
207-
call_args_offset,
208-
value: named_arg(
209-
self.src,
210-
named_args,
211-
"value",
212-
self.source_map,
213-
),
214-
salt: named_arg(self.src, named_args, "salt", self.source_map),
215-
},
216-
loc: span_to_range(self.source_map, ty.span),
217-
referenced_contract: contract_id,
218-
});
219-
}
220-
}
187+
ExprKind::Call(call_expr, call_args, named_args) => {
188+
if let Some(dependency) = handle_call_expr(
189+
self.src,
190+
self.source_map,
191+
expr,
192+
call_expr,
193+
call_args,
194+
named_args,
195+
false,
196+
) {
197+
self.collect_dependency(dependency);
221198
}
222199
}
223200
ExprKind::Member(member_expr, ident) => {
@@ -239,6 +216,78 @@ impl<'hir> Visit<'hir> for BytecodeDependencyCollector<'hir> {
239216
}
240217
self.walk_expr(expr)
241218
}
219+
220+
fn visit_stmt(&mut self, stmt: &'hir Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
221+
if let StmtKind::Try(stmt_try) = stmt.kind {
222+
if let ExprKind::Call(call_expr, call_args, named_args) = stmt_try.expr.kind {
223+
if let Some(dependency) = handle_call_expr(
224+
self.src,
225+
self.source_map,
226+
&stmt_try.expr,
227+
call_expr,
228+
&call_args,
229+
&named_args,
230+
true,
231+
) {
232+
self.collect_dependency(dependency);
233+
for clause in stmt_try.clauses {
234+
for &var in clause.args {
235+
self.visit_nested_var(var)?;
236+
}
237+
for stmt in clause.block {
238+
self.visit_stmt(stmt)?;
239+
}
240+
}
241+
return ControlFlow::Continue(());
242+
}
243+
}
244+
}
245+
self.walk_stmt(stmt)
246+
}
247+
}
248+
249+
/// Helper function to analyze and extract bytecode dependency from a given call expression.
250+
fn handle_call_expr(
251+
src: &str,
252+
source_map: &SourceMap,
253+
parent_expr: &Expr<'_>,
254+
call_expr: &Expr<'_>,
255+
call_args: &CallArgs<'_>,
256+
named_args: &Option<&[NamedArg<'_>]>,
257+
try_stmt: bool,
258+
) -> Option<BytecodeDependency> {
259+
if let ExprKind::New(ty_new) = &call_expr.kind {
260+
if let TypeKind::Custom(item_id) = ty_new.kind {
261+
if let Some(contract_id) = item_id.as_contract() {
262+
let name_loc = span_to_range(source_map, ty_new.span);
263+
let name = &src[name_loc];
264+
265+
// Calculate offset to remove named args, e.g. for an expression like
266+
// `new Counter {value: 333} ( address(this))`
267+
// the offset will be used to replace `{value: 333} ( ` with `(`
268+
let call_args_offset = if named_args.is_some() && !call_args.is_empty() {
269+
(call_args.span().lo() - ty_new.span.hi()).to_usize()
270+
} else {
271+
0
272+
};
273+
274+
let args_len = parent_expr.span.hi() - ty_new.span.hi();
275+
return Some(BytecodeDependency {
276+
kind: BytecodeDependencyKind::New {
277+
name: name.to_string(),
278+
args_length: args_len.to_usize(),
279+
call_args_offset,
280+
value: named_arg(src, named_args, "value", source_map),
281+
salt: named_arg(src, named_args, "salt", source_map),
282+
try_stmt,
283+
},
284+
loc: span_to_range(source_map, call_expr.span),
285+
referenced_contract: contract_id,
286+
})
287+
}
288+
}
289+
}
290+
None
242291
}
243292

244293
/// Helper function to extract value of a given named arg.
@@ -300,8 +349,14 @@ pub(crate) fn remove_bytecode_dependencies(
300349
call_args_offset,
301350
value,
302351
salt,
352+
try_stmt,
303353
} => {
304-
let mut update = format!("{name}(payable({vm}.deployCode({{");
354+
let (mut update, closing_seq) = if *try_stmt {
355+
(String::new(), "})")
356+
} else {
357+
(format!("{name}(payable("), "})))")
358+
};
359+
update.push_str(&format!("{vm}.deployCode({{"));
305360
update.push_str(&format!("_artifact: \"{artifact}\""));
306361

307362
if let Some(value) = value {
@@ -327,13 +382,14 @@ pub(crate) fn remove_bytecode_dependencies(
327382
update.push('(');
328383
}
329384
updates.insert((dep.loc.start, dep.loc.end + call_args_offset, update));
385+
330386
updates.insert((
331387
dep.loc.end + args_length,
332388
dep.loc.end + args_length,
333-
")})))".to_string(),
389+
format!("){closing_seq}"),
334390
));
335391
} else {
336-
update.push_str("})))");
392+
update.push_str(closing_seq);
337393
updates.insert((dep.loc.start, dep.loc.end + args_length, update));
338394
}
339395
}

crates/forge/tests/cli/test_optimizer.rs

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1412,3 +1412,180 @@ Ran 1 test suite [ELAPSED]: 1 tests passed, 0 failed, 0 skipped (1 total tests)
14121412
14131413
"#]]);
14141414
});
1415+
1416+
// <https://github.com/foundry-rs/foundry/issues/10492>
1417+
// Preprocess test contracts with try constructor statements.
1418+
forgetest_init!(preprocess_contract_with_try_ctor_stmt, |prj, cmd| {
1419+
prj.wipe_contracts();
1420+
prj.update_config(|config| {
1421+
config.dynamic_test_linking = true;
1422+
});
1423+
1424+
prj.add_source(
1425+
"CounterA.sol",
1426+
r#"
1427+
contract CounterA {
1428+
uint256 number;
1429+
}
1430+
"#,
1431+
)
1432+
.unwrap();
1433+
prj.add_source(
1434+
"CounterB.sol",
1435+
r#"
1436+
contract CounterB {
1437+
uint256 number;
1438+
constructor(uint256 a) payable {
1439+
require(a > 0, "ctor failure");
1440+
number = a;
1441+
}
1442+
}
1443+
"#,
1444+
)
1445+
.unwrap();
1446+
prj.add_source(
1447+
"CounterC.sol",
1448+
r#"
1449+
contract CounterC {
1450+
uint256 number;
1451+
constructor(uint256 a) {
1452+
require(a > 0, "ctor failure");
1453+
number = a;
1454+
}
1455+
}
1456+
"#,
1457+
)
1458+
.unwrap();
1459+
1460+
prj.add_test(
1461+
"Counter.t.sol",
1462+
r#"
1463+
import {Test} from "forge-std/Test.sol";
1464+
import {CounterA} from "../src/CounterA.sol";
1465+
import {CounterB} from "../src/CounterB.sol";
1466+
import {CounterC} from "../src/CounterC.sol";
1467+
1468+
contract CounterTest is Test {
1469+
function test_try_counterA_creation() public {
1470+
try new CounterA() {} catch {
1471+
revert();
1472+
}
1473+
}
1474+
1475+
function test_try_counterB_creation() public {
1476+
try new CounterB(1) {} catch {
1477+
revert();
1478+
}
1479+
}
1480+
1481+
function test_try_counterB_creation_with_salt() public {
1482+
try new CounterB{value: 111, salt: bytes32("preprocess_counter_with_salt")}(1) {} catch {
1483+
revert();
1484+
}
1485+
}
1486+
1487+
function test_try_counterC_creation() public {
1488+
try new CounterC(2) {
1489+
new CounterC(1);
1490+
} catch {
1491+
revert();
1492+
}
1493+
}
1494+
}
1495+
"#,
1496+
)
1497+
.unwrap();
1498+
// All 23 files should properly compile, tests pass.
1499+
cmd.args(["test"]).with_no_redact().assert_success().stdout_eq(str![[r#"
1500+
...
1501+
Compiling 23 files with [..]
1502+
...
1503+
[PASS] test_try_counterA_creation() (gas: [..])
1504+
[PASS] test_try_counterB_creation() (gas: [..])
1505+
[PASS] test_try_counterB_creation_with_salt() (gas: [..])
1506+
[PASS] test_try_counterC_creation() (gas: [..])
1507+
...
1508+
1509+
"#]]);
1510+
1511+
// Change CounterB to fail test.
1512+
prj.add_source(
1513+
"CounterB.sol",
1514+
r#"
1515+
contract CounterB {
1516+
uint256 number;
1517+
constructor(uint256 a) payable {
1518+
require(a > 11, "ctor failure");
1519+
number = a;
1520+
}
1521+
}
1522+
"#,
1523+
)
1524+
.unwrap();
1525+
// Only CounterB should compile.
1526+
cmd.assert_failure().stdout_eq(str![[r#"
1527+
...
1528+
Compiling 1 files with [..]
1529+
...
1530+
[PASS] test_try_counterA_creation() (gas: [..])
1531+
[FAIL: EvmError: Revert] test_try_counterB_creation() (gas: [..])
1532+
[FAIL: EvmError: Revert] test_try_counterB_creation_with_salt() (gas: [..])
1533+
[PASS] test_try_counterC_creation() (gas: [..])
1534+
...
1535+
1536+
"#]]);
1537+
1538+
// Change CounterC to fail test in try statement.
1539+
prj.add_source(
1540+
"CounterC.sol",
1541+
r#"
1542+
contract CounterC {
1543+
uint256 number;
1544+
constructor(uint256 a) {
1545+
require(a > 1, "ctor failure");
1546+
number = a;
1547+
}
1548+
}
1549+
"#,
1550+
)
1551+
.unwrap();
1552+
// Only CounterC should compile.
1553+
cmd.assert_failure().stdout_eq(str![[r#"
1554+
...
1555+
Compiling 1 files with [..]
1556+
...
1557+
[PASS] test_try_counterA_creation() (gas: [..])
1558+
[FAIL: EvmError: Revert] test_try_counterB_creation() (gas: [..])
1559+
[FAIL: EvmError: Revert] test_try_counterB_creation_with_salt() (gas: [..])
1560+
[FAIL: ctor failure] test_try_counterC_creation() (gas: [..])
1561+
...
1562+
1563+
"#]]);
1564+
1565+
// Change CounterC to fail test in try statement.
1566+
prj.add_source(
1567+
"CounterC.sol",
1568+
r#"
1569+
contract CounterC {
1570+
uint256 number;
1571+
constructor(uint256 a) {
1572+
require(a > 2, "ctor failure");
1573+
number = a;
1574+
}
1575+
}
1576+
"#,
1577+
)
1578+
.unwrap();
1579+
// Only CounterC should compile and revert.
1580+
cmd.assert_failure().stdout_eq(str![[r#"
1581+
...
1582+
Compiling 1 files with [..]
1583+
...
1584+
[PASS] test_try_counterA_creation() (gas: [..])
1585+
[FAIL: EvmError: Revert] test_try_counterB_creation() (gas: [..])
1586+
[FAIL: EvmError: Revert] test_try_counterB_creation_with_salt() (gas: [..])
1587+
[FAIL: EvmError: Revert] test_try_counterC_creation() (gas: [..])
1588+
...
1589+
1590+
"#]]);
1591+
});

0 commit comments

Comments
 (0)