Skip to content

Commit 50ce3f9

Browse files
adonovancopybara-github
authored andcommitted
starlark: allow lambda expressions
This change introduces lambda expressions, following Python, as a shorthand for declaring anonymous functions whose body is a single expression. RELNOTES: Starlark now supports lambda (anonymous function) expressions. PiperOrigin-RevId: 346583352
1 parent 337e717 commit 50ce3f9

File tree

13 files changed

+250
-47
lines changed

13 files changed

+250
-47
lines changed

src/main/java/com/google/devtools/build/lib/packages/PackageFactory.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
import net.starlark.java.syntax.Identifier;
7676
import net.starlark.java.syntax.IfStatement;
7777
import net.starlark.java.syntax.IntLiteral;
78+
import net.starlark.java.syntax.LambdaExpression;
7879
import net.starlark.java.syntax.ListExpression;
7980
import net.starlark.java.syntax.Location;
8081
import net.starlark.java.syntax.NodeVisitor;
@@ -1003,7 +1004,15 @@ void recordGeneratorName(CallExpression call) {
10031004
public void visit(DefStatement node) {
10041005
error(
10051006
node.getStartLocation(),
1006-
"function definitions are not allowed in BUILD files. You may move the function to "
1007+
"functions may not be defined in BUILD files. You may move the function to "
1008+
+ "a .bzl file and load it.");
1009+
}
1010+
1011+
@Override
1012+
public void visit(LambdaExpression node) {
1013+
error(
1014+
node.getStartLocation(),
1015+
"functions may not be defined in BUILD files. You may move the function to "
10071016
+ "a .bzl file and load it.");
10081017
}
10091018

src/main/java/net/starlark/java/eval/Eval.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import net.starlark.java.syntax.IfStatement;
4242
import net.starlark.java.syntax.IndexExpression;
4343
import net.starlark.java.syntax.IntLiteral;
44+
import net.starlark.java.syntax.LambdaExpression;
4445
import net.starlark.java.syntax.ListExpression;
4546
import net.starlark.java.syntax.LoadStatement;
4647
import net.starlark.java.syntax.Location;
@@ -148,10 +149,8 @@ private static TokenKind execFor(StarlarkThread.Frame fr, ForStatement node)
148149
return TokenKind.PASS;
149150
}
150151

151-
private static void execDef(StarlarkThread.Frame fr, DefStatement node)
152+
private static StarlarkFunction newFunction(StarlarkThread.Frame fr, Resolver.Function rfn)
152153
throws EvalException, InterruptedException {
153-
Resolver.Function rfn = node.getResolvedFunction();
154-
155154
// Evaluate default value expressions of optional parameters.
156155
// We use MANDATORY to indicate a required parameter
157156
// (not null, because defaults must be a legal tuple value, as
@@ -196,11 +195,8 @@ private static void execDef(StarlarkThread.Frame fr, DefStatement node)
196195
// Nested functions use the same globalIndex as their enclosing function,
197196
// since both were compiled from the same Program.
198197
StarlarkFunction fn = fn(fr);
199-
assignIdentifier(
200-
fr,
201-
node.getIdentifier(),
202-
new StarlarkFunction(
203-
rfn, fn.getModule(), fn.globalIndex, Tuple.wrap(defaults), Tuple.wrap(freevars)));
198+
return new StarlarkFunction(
199+
rfn, fn.getModule(), fn.globalIndex, Tuple.wrap(defaults), Tuple.wrap(freevars));
204200
}
205201

206202
private static TokenKind execIf(StarlarkThread.Frame fr, IfStatement node)
@@ -289,7 +285,9 @@ private static TokenKind exec(StarlarkThread.Frame fr, Statement st)
289285
case FOR:
290286
return execFor(fr, (ForStatement) st);
291287
case DEF:
292-
execDef(fr, (DefStatement) st);
288+
DefStatement def = (DefStatement) st;
289+
StarlarkFunction fn = newFunction(fr, def.getResolvedFunction());
290+
assignIdentifier(fr, def.getIdentifier(), fn);
293291
return TokenKind.PASS;
294292
case IF:
295293
return execIf(fr, (IfStatement) st);
@@ -481,6 +479,8 @@ private static Object eval(StarlarkThread.Frame fr, Expression expr)
481479
}
482480
case FLOAT_LITERAL:
483481
return StarlarkFloat.of(((FloatLiteral) expr).getValue());
482+
case LAMBDA:
483+
return newFunction(fr, ((LambdaExpression) expr).getResolvedFunction());
484484
case LIST_EXPR:
485485
return evalList(fr, (ListExpression) expr);
486486
case SLICE:

src/main/java/net/starlark/java/eval/StarlarkFunction.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ public Location getLocation() {
139139
}
140140

141141
/**
142-
* Returns the name of the function. Implicit functions (those not created by a def statement),
143-
* may have names such as "<toplevel>" or "<expr>".
142+
* Returns the name of the function, or "lambda" if anonymous. Implicit functions (those not
143+
* created by a def statement), may have names such as "<toplevel>" or "<expr>".
144144
*/
145145
@Override
146146
public String getName() {

src/main/java/net/starlark/java/syntax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ java_library(
3535
"IfStatement.java",
3636
"IndexExpression.java",
3737
"IntLiteral.java",
38+
"LambdaExpression.java",
3839
"Lexer.java",
3940
"ListExpression.java",
4041
"LoadStatement.java",

src/main/java/net/starlark/java/syntax/Expression.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ public enum Kind {
3939
IDENTIFIER,
4040
INDEX,
4141
INT_LITERAL,
42+
LAMBDA,
4243
LIST_EXPR,
4344
SLICE,
4445
STRING_LITERAL,
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright 2020 The Bazel Authors. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
package net.starlark.java.syntax;
15+
16+
import com.google.common.base.Preconditions;
17+
import com.google.common.collect.ImmutableList;
18+
import javax.annotation.Nullable;
19+
20+
/** A LambdaExpression ({@code lambda params: body}) denotes an anonymous function. */
21+
public final class LambdaExpression extends Expression {
22+
23+
private final int lambdaOffset; // offset of 'lambda' token
24+
private final ImmutableList<Parameter> parameters;
25+
private final Expression body;
26+
27+
// set by resolver
28+
@Nullable private Resolver.Function resolved;
29+
30+
LambdaExpression(
31+
FileLocations locs, int lambdaOffset, ImmutableList<Parameter> parameters, Expression body) {
32+
super(locs);
33+
this.lambdaOffset = lambdaOffset;
34+
this.parameters = Preconditions.checkNotNull(parameters);
35+
this.body = Preconditions.checkNotNull(body);
36+
}
37+
38+
public ImmutableList<Parameter> getParameters() {
39+
return parameters;
40+
}
41+
42+
public Expression getBody() {
43+
return body;
44+
}
45+
46+
/** Returns information about the resolved function. Set by the resolver. */
47+
@Nullable
48+
public Resolver.Function getResolvedFunction() {
49+
return resolved;
50+
}
51+
52+
void setResolvedFunction(Resolver.Function resolved) {
53+
this.resolved = resolved;
54+
}
55+
56+
@Override
57+
public int getStartOffset() {
58+
return lambdaOffset;
59+
}
60+
61+
@Override
62+
public int getEndOffset() {
63+
return body.getEndOffset();
64+
}
65+
66+
@Override
67+
public void accept(NodeVisitor visitor) {
68+
visitor.visit(this);
69+
}
70+
71+
@Override
72+
public Kind kind() {
73+
return Kind.LAMBDA;
74+
}
75+
}

src/main/java/net/starlark/java/syntax/NodePrinter.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,21 @@ private void printExpr(Expression expr) {
368368
break;
369369
}
370370

371+
case LAMBDA:
372+
{
373+
LambdaExpression lambda = (LambdaExpression) expr;
374+
buf.append("lambda");
375+
String sep = " ";
376+
for (Parameter param : lambda.getParameters()) {
377+
buf.append(sep);
378+
sep = ", ";
379+
printParameter(param);
380+
}
381+
buf.append(": ");
382+
printExpr(lambda.getBody());
383+
break;
384+
}
385+
371386
case LIST_EXPR:
372387
{
373388
ListExpression list = (ListExpression) expr;

src/main/java/net/starlark/java/syntax/NodeVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ public void visit(IndexExpression node) {
170170
visit(node.getKey());
171171
}
172172

173+
public void visit(LambdaExpression node) {
174+
visitAll(node.getParameters());
175+
visit(node.getBody());
176+
}
177+
173178
public void visit(SliceExpression node) {
174179
visit(node.getObject());
175180
if (node.getStart() != null) {

src/main/java/net/starlark/java/syntax/Parser.java

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,6 @@ private int syncTo(EnumSet<TokenKind> terminatingTokens) {
340340
TokenKind.GLOBAL,
341341
TokenKind.IMPORT,
342342
TokenKind.IS,
343-
TokenKind.LAMBDA,
344343
TokenKind.NONLOCAL,
345344
TokenKind.RAISE,
346345
TokenKind.TRY,
@@ -360,7 +359,6 @@ private void checkForbiddenKeywords() {
360359
break;
361360
case IMPORT: error = "'import' not supported, use 'load' instead"; break;
362361
case IS: error = "'is' not supported, use '==' instead"; break;
363-
case LAMBDA: error = "'lambda' not supported, declare a function instead"; break;
364362
case RAISE: error = "'raise' not supported, use 'fail' instead"; break;
365363
case TRY: error = "'try' not supported, all exceptions are fatal"; break;
366364
case WHILE: error = "'while' not supported, use 'for' instead"; break;
@@ -432,7 +430,7 @@ private Argument parseArgument() {
432430

433431
// arg = IDENTIFIER '=' test
434432
// | IDENTIFIER
435-
private Parameter parseFunctionParameter() {
433+
private Parameter parseParameter() {
436434
// **kwargs
437435
if (token.kind == TokenKind.STAR_STAR) {
438436
int starStarOffset = nextToken();
@@ -752,7 +750,7 @@ private Expression parseComprehensionSuffix(int loffset, Node body, TokenKind cl
752750
int ifOffset = nextToken();
753751
// [x for x in li if 1, 2] # parse error
754752
// [x for x in li if (1, 2)] # ok
755-
Expression cond = parseTest(0);
753+
Expression cond = parseTestNoCond();
756754
clauses.add(new Comprehension.If(locs, ifOffset, cond));
757755
} else if (token.kind == closingBracket) {
758756
break;
@@ -928,6 +926,10 @@ private Expression optimizeBinOpExpression(
928926
// Parses a non-tuple expression ("test" in Python terminology).
929927
private Expression parseTest() {
930928
int start = token.start;
929+
if (token.kind == TokenKind.LAMBDA) {
930+
return parseLambda(/*allowCond=*/ true);
931+
}
932+
931933
Expression expr = parseTest(0);
932934
if (token.kind == TokenKind.IF) {
933935
nextToken();
@@ -954,6 +956,25 @@ private Expression parseTest(int prec) {
954956
return parseBinOpExpression(prec);
955957
}
956958

959+
// parseLambda parses a lambda expression.
960+
// The allowCond flag allows the body to be an 'a if b else c' conditional.
961+
private LambdaExpression parseLambda(boolean allowCond) {
962+
int lambdaOffset = expect(TokenKind.LAMBDA);
963+
ImmutableList<Parameter> params = parseParameters();
964+
expect(TokenKind.COLON);
965+
Expression body = allowCond ? parseTest() : parseTestNoCond();
966+
return new LambdaExpression(locs, lambdaOffset, params, body);
967+
}
968+
969+
// parseTestNoCond parses a a single-component expression without
970+
// consuming a trailing 'if expr else expr'.
971+
private Expression parseTestNoCond() {
972+
if (token.kind == TokenKind.LAMBDA) {
973+
return parseLambda(/*allowCond=*/ false);
974+
}
975+
return parseTest(0);
976+
}
977+
957978
// not_expr = 'not' expr
958979
private Expression parseNotExpression(int prec) {
959980
int notOffset = expect(TokenKind.NOT);
@@ -1184,15 +1205,17 @@ private ImmutableList<Parameter> parseParameters() {
11841205
boolean hasParam = false;
11851206
ImmutableList.Builder<Parameter> list = ImmutableList.builder();
11861207

1187-
while (token.kind != TokenKind.RPAREN && token.kind != TokenKind.EOF) {
1208+
while (token.kind != TokenKind.RPAREN
1209+
&& token.kind != TokenKind.COLON
1210+
&& token.kind != TokenKind.EOF) {
11881211
if (hasParam) {
11891212
expect(TokenKind.COMMA);
11901213
// The list may end with a comma.
11911214
if (token.kind == TokenKind.RPAREN) {
11921215
break;
11931216
}
11941217
}
1195-
Parameter param = parseFunctionParameter();
1218+
Parameter param = parseParameter();
11961219
hasParam = true;
11971220
list.add(param);
11981221
}

src/main/java/net/starlark/java/syntax/Resolver.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ public static Module moduleWithPredeclared(String... names) {
326326

327327
private static class Block {
328328
@Nullable private final Block parent; // enclosing block, or null for tail of list
329-
@Nullable Node syntax; // Comprehension, DefStatement, StarlarkFile, or null
329+
@Nullable Node syntax; // Comprehension, DefStatement/LambdaExpression, StarlarkFile, or null
330330
private final ArrayList<Binding> frame; // accumulated locals of enclosing function
331331
// Accumulated CELL/FREE bindings of the enclosing function that will provide
332332
// the values for the free variables of this function; see Function.getFreeVars.
@@ -554,7 +554,8 @@ private static Binding lookupLexical(String name, Block b) {
554554
// This step may occur many times if the lookupLexical
555555
// recursion returns through many functions.
556556
// TODO(adonovan): make this 'DEF or LAMBDA' when we have lambda.
557-
if (bind != null && b.syntax instanceof DefStatement) {
557+
if (bind != null
558+
&& (b.syntax instanceof DefStatement || b.syntax instanceof LambdaExpression)) {
558559
Scope scope = bind.getScope();
559560
if (scope == Scope.LOCAL || scope == Scope.FREE || scope == Scope.CELL) {
560561
if (scope == Scope.LOCAL) {
@@ -717,8 +718,20 @@ public void visit(DefStatement node) {
717718
node.getBody()));
718719
}
719720

721+
@Override
722+
public void visit(LambdaExpression expr) {
723+
expr.setResolvedFunction(
724+
resolveFunction(
725+
expr,
726+
"lambda",
727+
expr.getStartLocation(),
728+
expr.getParameters(),
729+
ImmutableList.of(ReturnStatement.make(expr.getBody()))));
730+
}
731+
732+
// Common code for def, lambda.
720733
private Function resolveFunction(
721-
DefStatement def,
734+
Node syntax, // DefStatement or LambdaExpression
722735
String name,
723736
Location loc,
724737
ImmutableList<Parameter> parameters,
@@ -734,7 +747,7 @@ private Function resolveFunction(
734747
// Enter function block.
735748
ArrayList<Binding> frame = new ArrayList<>();
736749
ArrayList<Binding> freevars = new ArrayList<>();
737-
pushLocalBlock(def, frame, freevars);
750+
pushLocalBlock(syntax, frame, freevars);
738751

739752
// Check parameter order and convert to run-time order:
740753
// positionals, keyword-only, *args, **kwargs.

src/test/java/com/google/devtools/build/lib/packages/PackageFactoryTest.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1154,7 +1154,14 @@ public void testGlobPatternExtractor() {
11541154
public void testDefInBuild() throws Exception {
11551155
checkBuildDialectError(
11561156
"def func(): pass", //
1157-
"function definitions are not allowed in BUILD files");
1157+
"functions may not be defined in BUILD files");
1158+
}
1159+
1160+
@Test
1161+
public void testLambdaInBuild() throws Exception {
1162+
checkBuildDialectError(
1163+
"lambda: None", //
1164+
"functions may not be defined in BUILD files");
11581165
}
11591166

11601167
@Test

0 commit comments

Comments
 (0)