Skip to content

codegen refactor #379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added src/kirin/codegen/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions src/kirin/codegen/abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from dataclasses import field, dataclass

from kirin import ir, interp

ValueType = TypeVar("ValueType")


@dataclass
class Frame(interp.FrameABC[ir.SSAValue | ir.Block, ValueType]):

Check failure on line 11 in src/kirin/codegen/abc.py

View workflow job for this annotation

GitHub Actions / pyright

Too many type arguments provided for "FrameABC"; expected 1 but received 2 (reportInvalidTypeArguments)
ssa: dict[ir.SSAValue, ValueType] = field(default_factory=dict, kw_only=True)

Check failure on line 12 in src/kirin/codegen/abc.py

View workflow job for this annotation

GitHub Actions / pyright

Type variable "ValueType" has no meaning in this context (reportGeneralTypeIssues)
block: dict[ir.Block, ValueType] = field(default_factory=dict, kw_only=True)

Check failure on line 13 in src/kirin/codegen/abc.py

View workflow job for this annotation

GitHub Actions / pyright

Type variable "ValueType" has no meaning in this context (reportGeneralTypeIssues)

@abstractmethod
def new(self, key: ir.SSAValue | ir.Block) -> ValueType: ...

Check warning on line 16 in src/kirin/codegen/abc.py

View workflow job for this annotation

GitHub Actions / pyright

TypeVar "ValueType" appears only once in generic function signature   Use "object" instead (reportInvalidTypeVarUse)

def get(self, key: ir.SSAValue | ir.Block) -> ValueType:

Check failure on line 18 in src/kirin/codegen/abc.py

View workflow job for this annotation

GitHub Actions / pyright

Method "get" overrides class "FrameABC" in an incompatible manner   Return type mismatch: base method returns type "SSAValue | Block", override returns type "ValueType@get"     Type "ValueType@get" is not assignable to type "SSAValue | Block"       "object*" is not assignable to "SSAValue"       "object*" is not assignable to "Block" (reportIncompatibleMethodOverride)

Check warning on line 18 in src/kirin/codegen/abc.py

View workflow job for this annotation

GitHub Actions / pyright

TypeVar "ValueType" appears only once in generic function signature   Use "object" instead (reportInvalidTypeVarUse)
if isinstance(key, ir.Block):
return self.__get_item(self.block, key)

Check failure on line 20 in src/kirin/codegen/abc.py

View workflow job for this annotation

GitHub Actions / pyright

Type "ValueType" is not assignable to return type "ValueType@get"   Type "ValueType" is not assignable to type "ValueType@get" (reportReturnType)
else:
return self.__get_item(self.ssa, key)

Check failure on line 22 in src/kirin/codegen/abc.py

View workflow job for this annotation

GitHub Actions / pyright

Type "ValueType" is not assignable to return type "ValueType@get"   Type "ValueType" is not assignable to type "ValueType@get" (reportReturnType)

KeyType = TypeVar("KeyType", bound=ir.SSAValue | ir.Block)

def __get_item(self, entries: dict[KeyType, ValueType], key: KeyType) -> ValueType:
value = entries.get(key, interp.Undefined)

Check failure on line 27 in src/kirin/codegen/abc.py

View workflow job for this annotation

GitHub Actions / pyright

"Undefined" is not a known attribute of module "kirin.interp" (reportAttributeAccessIssue)
if interp.is_undefined(value):

Check failure on line 28 in src/kirin/codegen/abc.py

View workflow job for this annotation

GitHub Actions / pyright

"is_undefined" is not a known attribute of module "kirin.interp" (reportAttributeAccessIssue)
value = self.new(key)
entries[key] = value
return value
return value

Check failure on line 32 in src/kirin/codegen/abc.py

View workflow job for this annotation

GitHub Actions / pyright

Type "ValueType@__get_item | None" is not assignable to return type "ValueType@__get_item"   Type "ValueType@__get_item | None" is not assignable to type "ValueType@__get_item" (reportReturnType)

def set(self, key: ir.SSAValue | ir.Block, value: ValueType) -> None:

Check warning on line 34 in src/kirin/codegen/abc.py

View workflow job for this annotation

GitHub Actions / pyright

TypeVar "ValueType" appears only once in generic function signature   Use "object" instead (reportInvalidTypeVarUse)
if isinstance(key, ir.Block):
self.block[key] = value
else:
self.ssa[key] = value


FrameType = TypeVar("FrameType", bound=Frame)


@dataclass
class CodegenABC(interp.BaseInterpreter[FrameType, ValueType], ABC):
pass
8 changes: 4 additions & 4 deletions src/kirin/dialects/cf/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Branch
):
interp.writeln(frame, f"@goto {interp.block_id[stmt.successor]};")
frame.worklist.append(

Check failure on line 21 in src/kirin/dialects/cf/emit.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot access attribute "worklist" for class "EmitStrFrame"   Attribute "worklist" is unknown (reportAttributeAccessIssue)
Successor(stmt.successor, frame.get_values(stmt.arguments))
)
return ()
Expand All @@ -29,24 +29,24 @@
):
cond = frame.get(stmt.cond)
interp.writeln(frame, f"if {cond}")
frame.indent += 1
frame.set_indent += 1
values = frame.get_values(stmt.then_arguments)
block_values = tuple(interp.ssa_id[x] for x in stmt.then_successor.args)
frame.set_values(stmt.then_successor.args, block_values)
for x, y in zip(block_values, values):
interp.writeln(frame, f"{x} = {y};")
interp.writeln(frame, f"@goto {interp.block_id[stmt.then_successor]};")
frame.indent -= 1
frame.set_indent -= 1
interp.writeln(frame, "else")
frame.indent += 1
frame.set_indent += 1

values = frame.get_values(stmt.else_arguments)
block_values = tuple(interp.ssa_id[x] for x in stmt.else_successor.args)
frame.set_values(stmt.else_successor.args, block_values)
for x, y in zip(block_values, values):
interp.writeln(frame, f"{x} = {y};")
interp.writeln(frame, f"@goto {interp.block_id[stmt.else_successor]};")
frame.indent -= 1
frame.set_indent -= 1
interp.writeln(frame, "end")

frame.worklist.append(
Expand Down
32 changes: 22 additions & 10 deletions src/kirin/dialects/func/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,29 @@ class JuliaMethodTable(MethodTable):

@impl(Function)
def emit_function(
self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Function
self, emit: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Function
):
fn_args = stmt.body.blocks[0].args[1:]
argnames = tuple(interp.ssa_id[arg] for arg in fn_args)
argtypes = tuple(interp.emit_attribute(x.type) for x in fn_args)
argnames = tuple(emit.ssa_id[arg] for arg in fn_args)
argtypes = tuple(emit.emit_attribute(x.type) for x in fn_args)
args = [f"{name}::{type}" for name, type in zip(argnames, argtypes)]
interp.write(f"function {stmt.sym_name}({', '.join(args)})")
frame.indent += 1
interp.run_ssacfg_region(frame, stmt.body, (stmt.sym_name,) + argnames)
frame.indent -= 1
interp.writeln(frame, "end")
emit.write(f"function {stmt.sym_name}({', '.join(args)})")
with frame.set_indent(1):
for block in stmt.body.blocks:
block_id = emit.block_id[block]
frame.block_ref[block] = block_id
emit.newline(frame)
emit.write(f"@label {block_id};")

for each_stmt in block.stmts:
results = emit.eval_stmt(frame, each_stmt)
if isinstance(results, tuple):
frame.set_values(each_stmt.results, results)
elif results is not None:
raise InterpreterError(
f"Unexpected result {results} from statement {each_stmt.name}"
)
emit.writeln(frame, "end")
return ()

@impl(Return)
Expand Down Expand Up @@ -66,9 +78,9 @@ def emit_lambda(
frame.set_values(stmt.body.blocks[0].args, (stmt.sym_name,) + args)
frame.captured[stmt.body.blocks[0].args[0]] = frame.get_values(stmt.captured)
interp.writeln(frame, f"function {stmt.sym_name}({', '.join(args[1:])})")
frame.indent += 1
frame.set_indent += 1
interp.run_ssacfg_region(frame, stmt.body, args)
frame.indent -= 1
frame.set_indent -= 1
interp.writeln(frame, "end")
return (stmt.sym_name,)

Expand Down
78 changes: 17 additions & 61 deletions src/kirin/emit/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from dataclasses import field, dataclass

from kirin import ir, interp
from kirin.worklist import WorkList

ValueType = TypeVar("ValueType")


@dataclass
class EmitFrame(interp.Frame[ValueType]):
worklist: WorkList[interp.Successor] = field(default_factory=WorkList)
block_ref: dict[ir.Block, ValueType] = field(default_factory=dict)


Expand All @@ -20,32 +18,26 @@ class EmitFrame(interp.Frame[ValueType]):
@dataclass
class EmitABC(interp.BaseInterpreter[FrameType, ValueType], ABC):

def run_callable_region(
self,
frame: FrameType,
code: ir.Statement,
region: ir.Region,
args: tuple[ValueType, ...],
) -> ValueType:
results = self.eval_stmt(frame, code)
if isinstance(results, tuple):
if len(results) == 0:
return self.void
elif len(results) == 1:
return results[0]
raise interp.InterpreterError(f"Unexpected results {results}")
def emit(self, code: ir.Statement | ir.Method) -> ValueType:
if isinstance(code, ir.Method):
code = code.code

def run_ssacfg_region(
self, frame: FrameType, region: ir.Region, args: tuple[ValueType, ...]
) -> tuple[ValueType, ...]:
frame.worklist.append(interp.Successor(region.blocks[0], *args))
while (succ := frame.worklist.pop()) is not None:
frame.set_values(succ.block.args, succ.block_args)
block_header = self.emit_block(frame, succ.block)
frame.block_ref[succ.block] = block_header
return ()
with self.new_frame(code) as frame:
result = self.eval_stmt(frame, code)
if result is None:
return self.void
elif isinstance(result, tuple) and len(result) == 1:
return result[0]
raise interp.InterpreterError(
f"Unexpected result {result} from statement {code.name}"
)

def emit_attribute(self, attr: ir.Attribute) -> ValueType:
if attr.dialect not in self.dialects:
raise interp.InterpreterError(
f"Attribute {attr} not in dialects {self.dialects}"
)

return getattr(
self, f"emit_type_{type(attr).__name__}", self.emit_attribute_fallback
)(attr)
Expand All @@ -54,39 +46,3 @@ def emit_attribute_fallback(self, attr: ir.Attribute) -> ValueType:
if (method := self.registry.attributes.get(type(attr))) is not None:
return method(self, attr)
raise NotImplementedError(f"Attribute {type(attr)} not implemented")

def emit_stmt_begin(self, frame: FrameType, stmt: ir.Statement) -> None:
return

def emit_stmt_end(self, frame: FrameType, stmt: ir.Statement) -> None:
return

def emit_block_begin(self, frame: FrameType, block: ir.Block) -> None:
return

def emit_block_end(self, frame: FrameType, block: ir.Block) -> None:
return

def emit_block(self, frame: FrameType, block: ir.Block) -> ValueType:
self.emit_block_begin(frame, block)
stmt = block.first_stmt
while stmt is not None:
if self.consume_fuel() == self.FuelResult.Stop:
raise interp.FuelExhaustedError("fuel exhausted")

self.emit_stmt_begin(frame, stmt)
stmt_results = self.eval_stmt(frame, stmt)
self.emit_stmt_end(frame, stmt)

match stmt_results:
case tuple(values):
frame.set_values(stmt._results, values)
case interp.ReturnValue(_) | interp.YieldValue(_):
pass
case _:
raise ValueError(f"Unexpected result {stmt_results}")

stmt = stmt.next_stmt

self.emit_block_end(frame, block)
return frame.block_ref[block]
18 changes: 1 addition & 17 deletions src/kirin/emit/abc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,17 @@ from typing import TypeVar
from dataclasses import field, dataclass

from kirin import ir, types, interp
from kirin.worklist import WorkList

ValueType = TypeVar("ValueType")

@dataclass
class EmitFrame(interp.Frame[ValueType]):
worklist: WorkList[interp.Successor] = field(default_factory=WorkList)
block_ref: dict[ir.Block, ValueType] = field(default_factory=dict)

FrameType = TypeVar("FrameType", bound=EmitFrame)

class EmitABC(interp.BaseInterpreter[FrameType, ValueType]):
def run_callable_region(
self,
frame: FrameType,
code: ir.Statement,
region: ir.Region,
args: tuple[ValueType, ...],
) -> ValueType: ...
def run_ssacfg_region(
self, frame: FrameType, region: ir.Region, args: tuple[ValueType, ...]
) -> tuple[ValueType, ...]: ...
def emit(self, code: ir.Statement) -> ValueType: ...
def emit_attribute(self, attr: ir.Attribute) -> ValueType: ...
def emit_type_Any(self, attr: types.AnyType) -> ValueType: ...
def emit_type_Bottom(self, attr: types.BottomType) -> ValueType: ...
Expand All @@ -35,8 +24,3 @@ class EmitABC(interp.BaseInterpreter[FrameType, ValueType]):
def emit_type_PyClass(self, attr: types.PyClass) -> ValueType: ...
def emit_type_PyAttr(self, attr: ir.PyAttr) -> ValueType: ...
def emit_attribute_fallback(self, attr: ir.Attribute) -> ValueType: ...
def emit_stmt_begin(self, frame: FrameType, stmt: ir.Statement) -> None: ...
def emit_stmt_end(self, frame: FrameType, stmt: ir.Statement) -> None: ...
def emit_block_begin(self, frame: FrameType, block: ir.Block) -> None: ...
def emit_block_end(self, frame: FrameType, block: ir.Block) -> None: ...
def emit_block(self, frame: FrameType, block: ir.Block) -> ValueType: ...
41 changes: 40 additions & 1 deletion src/kirin/emit/str.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,36 @@
from abc import ABC
from typing import IO, Generic, TypeVar
from contextlib import contextmanager
from dataclasses import field, dataclass

from kirin import ir, interp, idtable
from kirin.emit.abc import EmitABC, EmitFrame

from .exceptions import EmitError

IO_t = TypeVar("IO_t", bound=IO)


@dataclass
class EmitStrFrame(EmitFrame[str]):
indent: int = 0
ssa_id: idtable.IdTable[ir.SSAValue] = field(
default_factory=lambda: idtable.IdTable(prefix="", prefix_if_none="var_")
)
captured: dict[ir.SSAValue, tuple[str, ...]] = field(default_factory=dict)

@contextmanager
def set_indent(self, indent: int):
self.indent += indent
try:
yield
finally:
self.indent -= indent


@dataclass
class EmitStr(EmitABC[EmitStrFrame, str], ABC, Generic[IO_t]):
void = ""
file: IO_t
prefix: str = field(default="", kw_only=True)
prefix_if_none: str = field(default="var_", kw_only=True)

Expand All @@ -41,6 +54,32 @@ def run_method(
raise interp.InterpreterError("maximum recursion depth exceeded")
return self.run_callable(method.code, (method.sym_name,) + args)

def run_callable_region(
self,
frame: EmitStrFrame,
code: ir.Statement,
region: ir.Region,
args: tuple[str, ...],
) -> str:
lines = []
for block in region.blocks:
block_id = self.block_id[block]
frame.block_ref[block] = block_id
with frame.set_indent(1):
self.run_succ(
frame, interp.Successor(block, frame.get_values(block.args))
)
self.write(f"@label {block_id};")

for each_stmt in block.stmts:
results = self.eval_stmt(frame, each_stmt)
if isinstance(results, tuple):
frame.set_values(each_stmt.results, results)
elif results is not None:
raise EmitError(
f"Unexpected result {results} from statement {each_stmt.name}"
)

def write(self, *args):
for arg in args:
self.file.write(arg)
Expand Down
12 changes: 12 additions & 0 deletions src/kirin/emit/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import TypeVar

from kirin import ir, interp

from .abc import EmitABC, EmitFrame

ValueType = TypeVar("ValueType")
FrameType = TypeVar("FrameType", bound=interp.Frame)


class Transform(EmitABC[FrameType, ir.IRNode]):
pass
Loading