From 78966205536b6e94cc61308b699117847f10b2b7 Mon Sep 17 00:00:00 2001 From: Roger-luo Date: Mon, 21 Apr 2025 18:41:07 -0400 Subject: [PATCH 1/2] feat: refactor emit --- src/kirin/emit/abc.py | 78 +++++++++--------------------------------- src/kirin/emit/abc.pyi | 16 +-------- 2 files changed, 18 insertions(+), 76 deletions(-) diff --git a/src/kirin/emit/abc.py b/src/kirin/emit/abc.py index 40cc58dc1..f86397535 100644 --- a/src/kirin/emit/abc.py +++ b/src/kirin/emit/abc.py @@ -3,14 +3,12 @@ from dataclasses import field, dataclass from kirin import ir, interp -from kirin.worklist import WorkList ValueType = TypeVar("ValueType") @dataclass class EmitFrame(interp.Frame[ValueType]): - worklist: WorkList[interp.Successor] = field(default_factory=WorkList) block_ref: dict[ir.Block, ValueType] = field(default_factory=dict) @@ -20,32 +18,26 @@ class EmitFrame(interp.Frame[ValueType]): @dataclass class EmitABC(interp.BaseInterpreter[FrameType, ValueType], ABC): - def run_callable_region( - self, - frame: FrameType, - code: ir.Statement, - region: ir.Region, - args: tuple[ValueType, ...], - ) -> ValueType: - results = self.eval_stmt(frame, code) - if isinstance(results, tuple): - if len(results) == 0: - return self.void - elif len(results) == 1: - return results[0] - raise interp.InterpreterError(f"Unexpected results {results}") + def emit(self, code: ir.Statement | ir.Method) -> ValueType: + if isinstance(code, ir.Method): + code = code.code - def run_ssacfg_region( - self, frame: FrameType, region: ir.Region, args: tuple[ValueType, ...] - ) -> tuple[ValueType, ...]: - frame.worklist.append(interp.Successor(region.blocks[0], *args)) - while (succ := frame.worklist.pop()) is not None: - frame.set_values(succ.block.args, succ.block_args) - block_header = self.emit_block(frame, succ.block) - frame.block_ref[succ.block] = block_header - return () + with self.new_frame(code) as frame: + result = self.eval_stmt(frame, code) + if result is None: + return self.void + elif isinstance(result, tuple) and len(result) == 1: + return result[0] + raise interp.InterpreterError( + f"Unexpected result {result} from statement {code.name}" + ) def emit_attribute(self, attr: ir.Attribute) -> ValueType: + if attr.dialect not in self.dialects: + raise interp.InterpreterError( + f"Attribute {attr} not in dialects {self.dialects}" + ) + return getattr( self, f"emit_type_{type(attr).__name__}", self.emit_attribute_fallback )(attr) @@ -54,39 +46,3 @@ def emit_attribute_fallback(self, attr: ir.Attribute) -> ValueType: if (method := self.registry.attributes.get(type(attr))) is not None: return method(self, attr) raise NotImplementedError(f"Attribute {type(attr)} not implemented") - - def emit_stmt_begin(self, frame: FrameType, stmt: ir.Statement) -> None: - return - - def emit_stmt_end(self, frame: FrameType, stmt: ir.Statement) -> None: - return - - def emit_block_begin(self, frame: FrameType, block: ir.Block) -> None: - return - - def emit_block_end(self, frame: FrameType, block: ir.Block) -> None: - return - - def emit_block(self, frame: FrameType, block: ir.Block) -> ValueType: - self.emit_block_begin(frame, block) - stmt = block.first_stmt - while stmt is not None: - if self.consume_fuel() == self.FuelResult.Stop: - raise interp.FuelExhaustedError("fuel exhausted") - - self.emit_stmt_begin(frame, stmt) - stmt_results = self.eval_stmt(frame, stmt) - self.emit_stmt_end(frame, stmt) - - match stmt_results: - case tuple(values): - frame.set_values(stmt._results, values) - case interp.ReturnValue(_) | interp.YieldValue(_): - pass - case _: - raise ValueError(f"Unexpected result {stmt_results}") - - stmt = stmt.next_stmt - - self.emit_block_end(frame, block) - return frame.block_ref[block] diff --git a/src/kirin/emit/abc.pyi b/src/kirin/emit/abc.pyi index 8dd321feb..dbfdda2a9 100644 --- a/src/kirin/emit/abc.pyi +++ b/src/kirin/emit/abc.pyi @@ -14,16 +14,7 @@ class EmitFrame(interp.Frame[ValueType]): FrameType = TypeVar("FrameType", bound=EmitFrame) class EmitABC(interp.BaseInterpreter[FrameType, ValueType]): - def run_callable_region( - self, - frame: FrameType, - code: ir.Statement, - region: ir.Region, - args: tuple[ValueType, ...], - ) -> ValueType: ... - def run_ssacfg_region( - self, frame: FrameType, region: ir.Region, args: tuple[ValueType, ...] - ) -> tuple[ValueType, ...]: ... + def emit(self, code: ir.Statement) -> ValueType: ... def emit_attribute(self, attr: ir.Attribute) -> ValueType: ... def emit_type_Any(self, attr: types.AnyType) -> ValueType: ... def emit_type_Bottom(self, attr: types.BottomType) -> ValueType: ... @@ -35,8 +26,3 @@ class EmitABC(interp.BaseInterpreter[FrameType, ValueType]): def emit_type_PyClass(self, attr: types.PyClass) -> ValueType: ... def emit_type_PyAttr(self, attr: ir.PyAttr) -> ValueType: ... def emit_attribute_fallback(self, attr: ir.Attribute) -> ValueType: ... - def emit_stmt_begin(self, frame: FrameType, stmt: ir.Statement) -> None: ... - def emit_stmt_end(self, frame: FrameType, stmt: ir.Statement) -> None: ... - def emit_block_begin(self, frame: FrameType, block: ir.Block) -> None: ... - def emit_block_end(self, frame: FrameType, block: ir.Block) -> None: ... - def emit_block(self, frame: FrameType, block: ir.Block) -> ValueType: ... From 5f4e3fa3a15b15a93fdb15ca43e3a86a4c6a6c31 Mon Sep 17 00:00:00 2001 From: Roger-luo Date: Tue, 22 Apr 2025 15:34:46 -0400 Subject: [PATCH 2/2] rework codegen --- src/kirin/codegen/__init__.py | 0 src/kirin/codegen/abc.py | 46 +++++++++++++++++++++++++++++++++ src/kirin/dialects/cf/emit.py | 8 +++--- src/kirin/dialects/func/emit.py | 32 ++++++++++++++++------- src/kirin/emit/abc.pyi | 2 -- src/kirin/emit/str.py | 41 ++++++++++++++++++++++++++++- src/kirin/emit/transform.py | 12 +++++++++ 7 files changed, 124 insertions(+), 17 deletions(-) create mode 100644 src/kirin/codegen/__init__.py create mode 100644 src/kirin/codegen/abc.py create mode 100644 src/kirin/emit/transform.py diff --git a/src/kirin/codegen/__init__.py b/src/kirin/codegen/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/kirin/codegen/abc.py b/src/kirin/codegen/abc.py new file mode 100644 index 000000000..73b8ecdda --- /dev/null +++ b/src/kirin/codegen/abc.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod +from typing import Generic, TypeVar +from dataclasses import field, dataclass + +from kirin import ir, interp + +ValueType = TypeVar("ValueType") + + +@dataclass +class Frame(interp.FrameABC[ir.SSAValue | ir.Block, ValueType]): + ssa: dict[ir.SSAValue, ValueType] = field(default_factory=dict, kw_only=True) + block: dict[ir.Block, ValueType] = field(default_factory=dict, kw_only=True) + + @abstractmethod + def new(self, key: ir.SSAValue | ir.Block) -> ValueType: ... + + def get(self, key: ir.SSAValue | ir.Block) -> ValueType: + if isinstance(key, ir.Block): + return self.__get_item(self.block, key) + else: + return self.__get_item(self.ssa, key) + + KeyType = TypeVar("KeyType", bound=ir.SSAValue | ir.Block) + + def __get_item(self, entries: dict[KeyType, ValueType], key: KeyType) -> ValueType: + value = entries.get(key, interp.Undefined) + if interp.is_undefined(value): + value = self.new(key) + entries[key] = value + return value + return value + + def set(self, key: ir.SSAValue | ir.Block, value: ValueType) -> None: + if isinstance(key, ir.Block): + self.block[key] = value + else: + self.ssa[key] = value + + +FrameType = TypeVar("FrameType", bound=Frame) + + +@dataclass +class CodegenABC(interp.BaseInterpreter[FrameType, ValueType], ABC): + pass diff --git a/src/kirin/dialects/cf/emit.py b/src/kirin/dialects/cf/emit.py index 618c68faa..734a115d0 100644 --- a/src/kirin/dialects/cf/emit.py +++ b/src/kirin/dialects/cf/emit.py @@ -29,16 +29,16 @@ def emit_cbr( ): cond = frame.get(stmt.cond) interp.writeln(frame, f"if {cond}") - frame.indent += 1 + frame.set_indent += 1 values = frame.get_values(stmt.then_arguments) block_values = tuple(interp.ssa_id[x] for x in stmt.then_successor.args) frame.set_values(stmt.then_successor.args, block_values) for x, y in zip(block_values, values): interp.writeln(frame, f"{x} = {y};") interp.writeln(frame, f"@goto {interp.block_id[stmt.then_successor]};") - frame.indent -= 1 + frame.set_indent -= 1 interp.writeln(frame, "else") - frame.indent += 1 + frame.set_indent += 1 values = frame.get_values(stmt.else_arguments) block_values = tuple(interp.ssa_id[x] for x in stmt.else_successor.args) @@ -46,7 +46,7 @@ def emit_cbr( for x, y in zip(block_values, values): interp.writeln(frame, f"{x} = {y};") interp.writeln(frame, f"@goto {interp.block_id[stmt.else_successor]};") - frame.indent -= 1 + frame.set_indent -= 1 interp.writeln(frame, "end") frame.worklist.append( diff --git a/src/kirin/dialects/func/emit.py b/src/kirin/dialects/func/emit.py index 20f672a0f..b8e237a63 100644 --- a/src/kirin/dialects/func/emit.py +++ b/src/kirin/dialects/func/emit.py @@ -15,17 +15,29 @@ class JuliaMethodTable(MethodTable): @impl(Function) def emit_function( - self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Function + self, emit: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Function ): fn_args = stmt.body.blocks[0].args[1:] - argnames = tuple(interp.ssa_id[arg] for arg in fn_args) - argtypes = tuple(interp.emit_attribute(x.type) for x in fn_args) + argnames = tuple(emit.ssa_id[arg] for arg in fn_args) + argtypes = tuple(emit.emit_attribute(x.type) for x in fn_args) args = [f"{name}::{type}" for name, type in zip(argnames, argtypes)] - interp.write(f"function {stmt.sym_name}({', '.join(args)})") - frame.indent += 1 - interp.run_ssacfg_region(frame, stmt.body, (stmt.sym_name,) + argnames) - frame.indent -= 1 - interp.writeln(frame, "end") + emit.write(f"function {stmt.sym_name}({', '.join(args)})") + with frame.set_indent(1): + for block in stmt.body.blocks: + block_id = emit.block_id[block] + frame.block_ref[block] = block_id + emit.newline(frame) + emit.write(f"@label {block_id};") + + for each_stmt in block.stmts: + results = emit.eval_stmt(frame, each_stmt) + if isinstance(results, tuple): + frame.set_values(each_stmt.results, results) + elif results is not None: + raise InterpreterError( + f"Unexpected result {results} from statement {each_stmt.name}" + ) + emit.writeln(frame, "end") return () @impl(Return) @@ -66,9 +78,9 @@ def emit_lambda( frame.set_values(stmt.body.blocks[0].args, (stmt.sym_name,) + args) frame.captured[stmt.body.blocks[0].args[0]] = frame.get_values(stmt.captured) interp.writeln(frame, f"function {stmt.sym_name}({', '.join(args[1:])})") - frame.indent += 1 + frame.set_indent += 1 interp.run_ssacfg_region(frame, stmt.body, args) - frame.indent -= 1 + frame.set_indent -= 1 interp.writeln(frame, "end") return (stmt.sym_name,) diff --git a/src/kirin/emit/abc.pyi b/src/kirin/emit/abc.pyi index dbfdda2a9..d728c1681 100644 --- a/src/kirin/emit/abc.pyi +++ b/src/kirin/emit/abc.pyi @@ -2,13 +2,11 @@ from typing import TypeVar from dataclasses import field, dataclass from kirin import ir, types, interp -from kirin.worklist import WorkList ValueType = TypeVar("ValueType") @dataclass class EmitFrame(interp.Frame[ValueType]): - worklist: WorkList[interp.Successor] = field(default_factory=WorkList) block_ref: dict[ir.Block, ValueType] = field(default_factory=dict) FrameType = TypeVar("FrameType", bound=EmitFrame) diff --git a/src/kirin/emit/str.py b/src/kirin/emit/str.py index ad82d52e4..0156a85f2 100644 --- a/src/kirin/emit/str.py +++ b/src/kirin/emit/str.py @@ -1,23 +1,36 @@ from abc import ABC from typing import IO, Generic, TypeVar +from contextlib import contextmanager from dataclasses import field, dataclass from kirin import ir, interp, idtable from kirin.emit.abc import EmitABC, EmitFrame +from .exceptions import EmitError + IO_t = TypeVar("IO_t", bound=IO) @dataclass class EmitStrFrame(EmitFrame[str]): indent: int = 0 + ssa_id: idtable.IdTable[ir.SSAValue] = field( + default_factory=lambda: idtable.IdTable(prefix="", prefix_if_none="var_") + ) captured: dict[ir.SSAValue, tuple[str, ...]] = field(default_factory=dict) + @contextmanager + def set_indent(self, indent: int): + self.indent += indent + try: + yield + finally: + self.indent -= indent + @dataclass class EmitStr(EmitABC[EmitStrFrame, str], ABC, Generic[IO_t]): void = "" - file: IO_t prefix: str = field(default="", kw_only=True) prefix_if_none: str = field(default="var_", kw_only=True) @@ -41,6 +54,32 @@ def run_method( raise interp.InterpreterError("maximum recursion depth exceeded") return self.run_callable(method.code, (method.sym_name,) + args) + def run_callable_region( + self, + frame: EmitStrFrame, + code: ir.Statement, + region: ir.Region, + args: tuple[str, ...], + ) -> str: + lines = [] + for block in region.blocks: + block_id = self.block_id[block] + frame.block_ref[block] = block_id + with frame.set_indent(1): + self.run_succ( + frame, interp.Successor(block, frame.get_values(block.args)) + ) + self.write(f"@label {block_id};") + + for each_stmt in block.stmts: + results = self.eval_stmt(frame, each_stmt) + if isinstance(results, tuple): + frame.set_values(each_stmt.results, results) + elif results is not None: + raise EmitError( + f"Unexpected result {results} from statement {each_stmt.name}" + ) + def write(self, *args): for arg in args: self.file.write(arg) diff --git a/src/kirin/emit/transform.py b/src/kirin/emit/transform.py new file mode 100644 index 000000000..4d4e04305 --- /dev/null +++ b/src/kirin/emit/transform.py @@ -0,0 +1,12 @@ +from typing import TypeVar + +from kirin import ir, interp + +from .abc import EmitABC, EmitFrame + +ValueType = TypeVar("ValueType") +FrameType = TypeVar("FrameType", bound=interp.Frame) + + +class Transform(EmitABC[FrameType, ir.IRNode]): + pass