Skip to content

[PEx] Correct inlining for functions returning a value #840

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 8, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 73 additions & 13 deletions Src/PCompiler/CompilerCore/Backend/PEx/TransformASTPass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Plang.Compiler.TypeChecker.AST.Expressions;
using Plang.Compiler.TypeChecker.AST.Statements;
using Plang.Compiler.TypeChecker.AST.States;
using Plang.Compiler.TypeChecker.Types;

namespace Plang.Compiler.Backend.PEx;

Expand Down Expand Up @@ -145,7 +146,7 @@ private static IStateAction TransformAction(IStateAction action, IDictionary<Fun
}
}

private static void GenerateInline(Function caller, Function callee, IReadOnlyList<IPExpr> argsList,
private static Variable GenerateInline(Function caller, Function callee, IReadOnlyList<IPExpr> argsList,
List<IPStmt> body, ParserRuleContext sourceLocation)
{
var newVarMap = new Dictionary<Variable, Variable>();
Expand All @@ -168,36 +169,76 @@ private static void GenerateInline(Function caller, Function callee, IReadOnlyLi
caller.AddLocalVariable(newVar);
}

var inVar = new Variable($"in_{callNum}_{callee.Name}", sourceLocation, VariableRole.Temp);
inVar.Type = PrimitiveType.Bool;
caller.AddLocalVariable(inVar);

foreach (var funStmt in callee.Body.Statements)
body.Add(ReplaceVars(funStmt, newVarMap));
callNum++;

return inVar;
}

private static List<IPStmt> ReplaceReturn(IReadOnlyList<IPStmt> body, IPExpr location)
private static bool CanReturn(IPStmt stmt)
{
if (stmt == null) return false;
switch (stmt)
{
case CompoundStmt compoundStmt:
foreach (var inner in compoundStmt.Statements)
if (CanReturn(inner))
return true;
return false;
case IfStmt ifStmt:
return CanReturn(ifStmt.ThenBranch) || CanReturn(ifStmt.ElseBranch);
case WhileStmt whileStmt:
return CanReturn(whileStmt.Body);
case ReceiveStmt recv:
foreach (var c in recv.Cases)
{
if (c.Value.Body != null)
{
if (CanReturn(c.Value.Body))
return true;
}
}
return false;
case ReturnStmt:
return true;
default:
return false;
}
}

private static List<IPStmt> ReplaceReturn(IReadOnlyList<IPStmt> body, IPExpr location, Variable inVar, Function func)
{
var newBody = new List<IPStmt>();
foreach (var stmt in body)
switch (stmt)
{
case ReturnStmt returnStmt:
newBody.Add(new AssignStmt(returnStmt.SourceLocation, location, returnStmt.ReturnValue));
newBody.Add(new AssignStmt(returnStmt.SourceLocation,
new VariableAccessExpr(returnStmt.SourceLocation, inVar),
new BoolLiteralExpr(returnStmt.SourceLocation, false)));
break;
case CompoundStmt compoundStmt:
var replace = ReplaceReturn(compoundStmt.Statements, location);
var replace = ReplaceReturn(compoundStmt.Statements, location, inVar, func);
foreach (var statement in replace) newBody.Add(statement);
break;
case IfStmt ifStmt:
IPStmt thenStmt = null;
if (ifStmt.ThenBranch != null)
{
var replaceThen = ReplaceReturn(ifStmt.ThenBranch.Statements, location);
var replaceThen = ReplaceReturn(ifStmt.ThenBranch.Statements, location, inVar, func);
thenStmt = new CompoundStmt(ifStmt.ThenBranch.SourceLocation, replaceThen);
}

IPStmt elseStmt = null;
if (ifStmt.ElseBranch != null)
{
var replaceElse = ReplaceReturn(ifStmt.ElseBranch.Statements, location);
var replaceElse = ReplaceReturn(ifStmt.ElseBranch.Statements, location, inVar, func);
elseStmt = new CompoundStmt(ifStmt.ElseBranch.SourceLocation, replaceElse);
}

Expand All @@ -206,17 +247,27 @@ private static List<IPStmt> ReplaceReturn(IReadOnlyList<IPStmt> body, IPExpr loc
case ReceiveStmt receiveStmt:
foreach (var entry in receiveStmt.Cases)
{
if (CanReturn(entry.Value.Body))
{
throw new NotImplementedException($"Function with a return statement inside a receive-case isn't supported. Found in {func.Name}.");
}

entry.Value.Body = new CompoundStmt(entry.Value.Body.SourceLocation,
ReplaceReturn(entry.Value.Body.Statements, location));
ReplaceReturn(entry.Value.Body.Statements, location, inVar, func));
entry.Value.Signature.ReturnType = null;
}

newBody.Add(new ReceiveStmt(receiveStmt.SourceLocation, receiveStmt.Cases));
break;
case WhileStmt whileStmt:
if (CanReturn(whileStmt.Body))
{
throw new NotImplementedException($"Function with a return statement inside a loop isn't supported. Found in {func.Name}.");
}

var bodyList = new List<IPStmt>();
bodyList.Add(whileStmt.Body);
var replaceWhile = ReplaceReturn(bodyList, location);
var replaceWhile = ReplaceReturn(bodyList, location, inVar, func);
newBody.Add(new WhileStmt(whileStmt.SourceLocation, whileStmt.Condition,
new CompoundStmt(whileStmt.Body.SourceLocation, replaceWhile)));
break;
Expand All @@ -227,7 +278,7 @@ private static List<IPStmt> ReplaceReturn(IReadOnlyList<IPStmt> body, IPExpr loc

return newBody;
}

private static void InlineStmt(Function function, IPStmt stmt, List<IPStmt> body)
{
switch (stmt)
Expand All @@ -236,16 +287,25 @@ private static void InlineStmt(Function function, IPStmt stmt, List<IPStmt> body
if (assign.Value is FunCallExpr)
{
var rhsExpr = (FunCallExpr)assign.Value;
if (!rhsExpr.Function.IsForeign)
if (!rhsExpr.Function.IsForeign && (rhsExpr.Function.CanReceive || rhsExpr.Function.CanRaiseEvent || rhsExpr.Function.CanChangeState))
{
var inlined = InlineInFunction(rhsExpr.Function);
if (inlined)
function.RemoveCallee(rhsExpr.Function);
var appendToBody = new List<IPStmt>();
GenerateInline(function, rhsExpr.Function, rhsExpr.Arguments, appendToBody,
var inVar = GenerateInline(function, rhsExpr.Function, rhsExpr.Arguments, appendToBody,
assign.SourceLocation);
appendToBody = ReplaceReturn(appendToBody, assign.Location);
foreach (var statement in appendToBody) body.Add(statement);
body.Add(new AssignStmt(assign.SourceLocation,
new VariableAccessExpr(assign.SourceLocation, inVar),
new BoolLiteralExpr(assign.SourceLocation, true)));
appendToBody = ReplaceReturn(appendToBody, assign.Location, inVar, rhsExpr.Function);
foreach (var statement in appendToBody)
{
var inCond = new BinOpExpr(statement.SourceLocation, BinOpType.Eq,
new VariableAccessExpr(statement.SourceLocation, inVar),
new BoolLiteralExpr(statement.SourceLocation, true));
body.Add(new IfStmt(statement.SourceLocation, inCond, statement, null));
}
}
else
{
Expand All @@ -262,7 +322,7 @@ private static void InlineStmt(Function function, IPStmt stmt, List<IPStmt> body
foreach (var statement in compound.Statements) InlineStmt(function, statement, body);
break;
case FunCallStmt call:
if (!call.Function.IsForeign & (call.Function.CanReceive || call.Function.CanRaiseEvent || call.Function.CanChangeState))
if (!call.Function.IsForeign && (call.Function.CanReceive || call.Function.CanRaiseEvent || call.Function.CanChangeState))
{
var inlined = InlineInFunction(call.Function);
if (inlined)
Expand Down
Loading