Skip to content

Commit e7f5fb6

Browse files
CST: Handle named function declarations (#230)
Reapplication of #225, looks like it was merged into the assignment branch and I accidentally force pushed over it --- This PR extends the CST support to handle named function declarations (function name() or local function name()). As part of this PR, we also correct the comma positioning that separates arguments in the function definition, by using the comma positions stored on the CST node.
1 parent 278d1d8 commit e7f5fb6

File tree

7 files changed

+84
-12
lines changed

7 files changed

+84
-12
lines changed

batteries/syntax/ast_types.luau

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,21 @@ export type AstStatCompoundAssign = {
266266
value: AstExpr,
267267
}
268268

269+
export type AstStatFunction = {
270+
tag: "function",
271+
["function"]: Token<"function">,
272+
name: AstExpr,
273+
body: AstFunctionBody,
274+
}
275+
276+
export type AstStatLocalFunction = {
277+
tag: "localfunction",
278+
["local"]: Token<"local">,
279+
["function"]: Token<"function">,
280+
name: AstLocal,
281+
body: AstFunctionBody,
282+
}
283+
269284
export type AstStat =
270285
| AstStatBlock
271286
| AstStatIf
@@ -280,5 +295,7 @@ export type AstStat =
280295
| AstStatForIn
281296
| AstStatAssign
282297
| AstStatCompoundAssign
298+
| AstStatFunction
299+
| AstStatLocalFunction
283300

284301
return {}

batteries/syntax/visitor.luau

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ type Visitor = {
1313
visitForIn: (T.AstStatForIn) -> boolean,
1414
visitAssign: (T.AstStatAssign) -> boolean,
1515
visitCompoundAssign: (T.AstStatCompoundAssign) -> boolean,
16+
visitFunction: (T.AstStatFunction) -> boolean,
17+
visitLocalFunction: (T.AstStatLocalFunction) -> boolean,
1618

1719
visitLocalReference: (T.AstExprLocal) -> boolean,
1820
visitGlobal: (T.AstExprGlobal) -> boolean,
@@ -50,6 +52,8 @@ local defaultVisitor: Visitor = {
5052
visitForIn = alwaysVisit :: any,
5153
visitAssign = alwaysVisit :: any,
5254
visitCompoundAssign = alwaysVisit :: any,
55+
visitFunction = alwaysVisit :: any,
56+
visitLocalFunction = alwaysVisit :: any,
5357

5458
visitLocalReference = alwaysVisit :: any,
5559
visitGlobal = alwaysVisit :: any,
@@ -290,6 +294,23 @@ local function visitAnonymousFunction(node: T.AstExprAnonymousFunction, visitor:
290294
end
291295
end
292296

297+
local function visitFunction(node: T.AstStatFunction, visitor: Visitor)
298+
if visitor.visitFunction(node) then
299+
visitToken(node["function"], visitor)
300+
visitExpression(node.name, visitor)
301+
visitFunctionBody(node.body, visitor)
302+
end
303+
end
304+
305+
local function visitLocalFunction(node: T.AstStatLocalFunction, visitor: Visitor)
306+
if visitor.visitLocalFunction(node) then
307+
visitToken(node["local"], visitor)
308+
visitToken(node["function"], visitor)
309+
visitLocal(node.name, visitor)
310+
visitFunctionBody(node.body, visitor)
311+
end
312+
end
313+
293314
local function visitTableItem(node: T.AstExprTableItem, visitor: Visitor)
294315
if visitor.visitTableItem(node) then
295316
if node.kind == "list" then
@@ -412,6 +433,10 @@ function visitStatement(statement: T.AstStat, visitor: Visitor)
412433
visitAssign(statement, visitor)
413434
elseif statement.tag == "compoundassign" then
414435
visitCompoundAssign(statement, visitor)
436+
elseif statement.tag == "function" then
437+
visitFunction(statement, visitor)
438+
elseif statement.tag == "localfunction" then
439+
visitLocalFunction(statement, visitor)
415440
else
416441
exhaustiveMatch(statement.tag)
417442
end

luau/src/luau.cpp

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ struct AstSerialize : public Luau::AstVisitor
307307
lua_pushstring(L, name.value);
308308
}
309309

310-
void serialize(Luau::AstLocal* local)
310+
void serialize(Luau::AstLocal* local, bool createToken = true)
311311
{
312312
lua_rawcheckstack(L, 2);
313313

@@ -324,8 +324,11 @@ struct AstSerialize : public Luau::AstVisitor
324324
lua_pushvalue(L, -2);
325325
lua_settable(L, localTableIndex);
326326

327-
serializeToken(local->location.begin, local->name.value);
328-
lua_setfield(L, -2, "name");
327+
if (createToken)
328+
{
329+
serializeToken(local->location.begin, local->name.value);
330+
lua_setfield(L, -2, "name");
331+
}
329332

330333
if (local->shadow)
331334
serialize(local->shadow);
@@ -789,11 +792,13 @@ struct AstSerialize : public Luau::AstVisitor
789792

790793
void serializeFunctionBody(Luau::AstExprFunction* node)
791794
{
795+
const auto* cstNode = lookupCstNode<Luau::CstExprFunction>(node);
796+
792797
lua_rawcheckstack(L, 3);
793798
lua_createtable(L, 0, 7);
794799

795800
if (node->self)
796-
serialize(node->self);
801+
serialize(node->self, /* createToken= */ false);
797802
else
798803
lua_pushnil(L);
799804
lua_setfield(L, -2, "self");
@@ -804,8 +809,7 @@ struct AstSerialize : public Luau::AstVisitor
804809
lua_setfield(L, -2, "openParens");
805810
}
806811

807-
// TODO: separators
808-
serializePunctuated(node->args, {}, ",");
812+
serializePunctuated(node->args, cstNode ? cstNode->argsCommaPositions : Luau::AstArray<Luau::Position>{}, ",");
809813
lua_setfield(L, -2, "parameters");
810814

811815
// TODO: generics, return types, etc.
@@ -1283,29 +1287,41 @@ struct AstSerialize : public Luau::AstVisitor
12831287
void serializeStat(Luau::AstStatFunction* node)
12841288
{
12851289
lua_rawcheckstack(L, 2);
1286-
lua_createtable(L, 0, preambleSize + 2);
1290+
lua_createtable(L, 0, preambleSize + 3);
12871291

12881292
serializeNodePreamble(node, "function");
12891293

1294+
serializeToken(node->location.begin, "function");
1295+
lua_setfield(L, -2, "function");
1296+
12901297
node->name->visit(this);
12911298
lua_setfield(L, -2, "name");
12921299

1293-
node->func->visit(this);
1294-
lua_setfield(L, -2, "function");
1300+
serializeFunctionBody(node->func);
1301+
lua_setfield(L, -2, "body");
12951302
}
12961303

12971304
void serializeStat(Luau::AstStatLocalFunction* node)
12981305
{
12991306
lua_rawcheckstack(L, 2);
1300-
lua_createtable(L, 0, preambleSize + 2);
1307+
lua_createtable(L, 0, preambleSize + 4);
13011308

13021309
serializeNodePreamble(node, "localfunction");
13031310

1311+
serializeToken(node->location.begin, "local");
1312+
lua_setfield(L, -2, "local");
1313+
1314+
if (const auto cstNode = lookupCstNode<Luau::CstStatLocalFunction>(node))
1315+
{
1316+
serializeToken(cstNode->functionKeywordPosition, "function");
1317+
lua_setfield(L, -2, "function");
1318+
}
1319+
13041320
serialize(node->name);
13051321
lua_setfield(L, -2, "name");
13061322

1307-
node->func->visit(this);
1308-
lua_setfield(L, -2, "function");
1323+
serializeFunctionBody(node->func);
1324+
lua_setfield(L, -2, "body");
13091325
}
13101326

13111327
void serializeStat(Luau::AstStatTypeAlias* node)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
function x()
2+
call()
3+
end
4+
5+
function x.y:z() end
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
-- stylua: ignore
2+
function x(x , y, z)
3+
end
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
local function x()
2+
call(1)
3+
end

tests/testAstSerializer.spec.luau

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,10 @@ local function test_roundtrippableAst()
132132
"tests/astSerializerTests/assignment-1.luau",
133133
"tests/astSerializerTests/break-continue-1.luau",
134134
"tests/astSerializerTests/compound-assignment-1.luau",
135+
"tests/astSerializerTests/function-declaration-1.luau",
136+
"tests/astSerializerTests/function-declaration-2.luau",
135137
"tests/astSerializerTests/generic-for-loop-1.luau",
138+
"tests/astSerializerTests/local-function-declaration-1.luau",
136139
"tests/astSerializerTests/numeric-for-loop-1.luau",
137140
"tests/astSerializerTests/while-1.luau",
138141
"tests/astSerializerTests/repeat-until-1.luau",

0 commit comments

Comments
 (0)