Skip to content

Commit 8a4ef52

Browse files
authored
switch constant attr to explicit + make IList an attribute (#308)
this PR changes `IList` to an attribute allowing it to be a compile-time value. It also extends the runtime of `IList` so that one can keep track of the element type information at runtime. This solves the problem of capturing global values like below ```python alist = [1, 2, 3] # could be hard to know element type when list is long alist = IList([1, 2, 3], elem=types.Int) # allow user explicitly type the global value @basic def main() return alist ``` This PR is breaking becausae now accessing the `value` field of `Constant` will give either `PyAttr` or some other attribute type (e.g `IList`) instead of original `T` and store as `PyAttr`. This also partially fix #270 . A new attribute base class `Data` and method `unwrap` are provided for convenience to covnert between attribute and runtime Python value. so this is gonna be a tricky breaking change for downstream because it is hard to detect this behaviour change due to the fact that interpreter won't check its value type. But this will allow us using other attributes as a constant value as part of the python dialect.
1 parent 7c033fd commit 8a4ef52

File tree

13 files changed

+143
-49
lines changed

13 files changed

+143
-49
lines changed

src/kirin/dialects/func/typeinfer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@ def const_none(self, interp: TypeInference, frame: Frame, stmt: ConstantNone):
2424
return (types.NoneType,)
2525

2626
@impl(Return)
27-
def return_(self, interp: TypeInference, frame: Frame, stmt: Return) -> ReturnValue:
27+
def return_(
28+
self, interp: TypeInference, frame: Frame[types.TypeAttribute], stmt: Return
29+
) -> ReturnValue:
2830
if (
2931
isinstance(hint := stmt.value.hints.get("const"), const.Value)
3032
and hint.data is not None
3133
):
32-
return ReturnValue(types.Literal(hint.data))
34+
return ReturnValue(types.Literal(hint.data, frame.get(stmt.value)))
3335
return ReturnValue(frame.get(stmt.value))
3436

3537
@impl(Call)

src/kirin/dialects/ilist/interp.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,30 @@ class IListInterpreter(MethodTable):
1313

1414
@impl(Range)
1515
def _range(self, interp, frame: Frame, stmt: Range):
16-
return (IList(range(*frame.get_values(stmt.args))),)
16+
return (IList(range(*frame.get_values(stmt.args)), elem=types.Int),)
1717

1818
@impl(New)
1919
def new(self, interp, frame: Frame, stmt: New):
20-
return (IList(list(frame.get_values(stmt.values))),)
20+
elem_type = types.Any
21+
if stmt.values:
22+
elem_type = stmt.values[0].type
23+
for each in stmt.values[1:]:
24+
elem_type = elem_type.join(each.type)
25+
return (IList(list(frame.get_values(stmt.values)), elem=elem_type),)
2126

2227
@impl(Len, types.PyClass(IList))
2328
def len(self, interp, frame: Frame, stmt: Len):
2429
return (len(frame.get(stmt.value).data),)
2530

2631
@impl(Add, types.PyClass(IList), types.PyClass(IList))
2732
def add(self, interp, frame: Frame, stmt: Add):
28-
return (IList(frame.get(stmt.lhs).data + frame.get(stmt.rhs).data),)
33+
lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
34+
return (IList(lhs.data + rhs.data, elem=lhs.elem.join(rhs.elem)),)
2935

3036
@impl(Push)
3137
def push(self, interp, frame: Frame, stmt: Push):
32-
return (IList(frame.get(stmt.lst).data + [frame.get(stmt.value)]),)
38+
lst = frame.get(stmt.lst)
39+
return (IList(lst.data + [frame.get(stmt.value)], elem=lst.elem),)
3340

3441
@impl(Map)
3542
def map(self, interp: Interpreter, frame: Frame, stmt: Map):
@@ -40,7 +47,7 @@ def map(self, interp: Interpreter, frame: Frame, stmt: Map):
4047
# NOTE: assume fn has been type checked
4148
_, item = interp.run_method(fn, (elem,))
4249
ret.append(item)
43-
return (IList(ret),)
50+
return (IList(ret, elem=fn.return_type),)
4451

4552
@impl(Scan)
4653
def scan(self, interp: Interpreter, frame: Frame, stmt: Scan):
@@ -54,7 +61,15 @@ def scan(self, interp: Interpreter, frame: Frame, stmt: Scan):
5461
# NOTE: assume fn has been type checked
5562
_, (carry, y) = interp.run_method(fn, (carry, elem))
5663
ys.append(y)
57-
return ((carry, IList(ys)),)
64+
65+
if (
66+
isinstance(fn.return_type, types.Generic)
67+
and fn.return_type.is_subseteq(types.Tuple)
68+
and len(fn.return_type.vars) == 2
69+
):
70+
return ((carry, IList(ys, fn.return_type.vars[1])),)
71+
else:
72+
return ((carry, IList(ys, types.Any)),)
5873

5974
@impl(Foldr)
6075
def foldr(self, interp: Interpreter, frame: Frame, stmt: Foldr):

src/kirin/dialects/ilist/rewrite/const.py

+12
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from kirin.analysis import const
33
from kirin.rewrite.abc import RewriteRule
44
from kirin.rewrite.result import RewriteResult
5+
from kirin.dialects.py.constant import Constant
56

67
from ..stmts import IListType
78
from ..runtime import IList
@@ -14,6 +15,9 @@ class ConstList2IList(RewriteRule):
1415
"""
1516

1617
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
18+
if isinstance(node, Constant):
19+
return self.rewrite_Constant(node)
20+
1721
has_done_something = False
1822
for result in node.results:
1923
if not isinstance(hint := result.hints.get("const"), const.Value):
@@ -29,6 +33,14 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
2933
has_done_something = self._rewrite_IList_type(result, data)
3034
return RewriteResult(has_done_something=has_done_something)
3135

36+
def rewrite_Constant(self, node: Constant) -> RewriteResult:
37+
if isinstance(node.value, ir.PyAttr) and isinstance(node.value.data, list):
38+
stmt = Constant(value=IList(data=node.value.data))
39+
node.replace_by(stmt)
40+
self._rewrite_IList_type(stmt.result, node.value.data)
41+
return RewriteResult(has_done_something=True)
42+
return RewriteResult()
43+
3244
def _rewrite_IList_type(self, result: ir.SSAValue, data):
3345
if not isinstance(data, IList):
3446
return False

src/kirin/dialects/ilist/rewrite/list.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@ def rewrite_Block(self, node: ir.Block) -> RewriteResult:
1616

1717
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
1818
has_done_something = False
19+
if isinstance(node, constant.Constant) and isinstance(node.value, list):
20+
eltype = self._eltype(node.result.type)
21+
node.replace_by(
22+
constant.Constant(value=IList(data=node.value, elem=eltype))
23+
)
24+
1925
for result in node.results:
2026
has_done_something = self._rewrite_SSAValue_type(result)
2127

22-
if has_done_something and isinstance(node, constant.Constant):
23-
node.replace_by(constant.Constant(value=IList(data=node.value)))
24-
2528
return RewriteResult(has_done_something=has_done_something)
2629

2730
def _rewrite_SSAValue_type(self, value: ir.SSAValue):
@@ -36,3 +39,8 @@ def _rewrite_SSAValue_type(self, value: ir.SSAValue):
3639
value.type = IListType[types.Any, types.Any]
3740
return True
3841
return False
42+
43+
def _eltype(self, type: types.TypeAttribute):
44+
if isinstance(type, types.Generic) and issubclass(type.body.typ, (list, IList)):
45+
return type.vars[0]
46+
return types.Any

src/kirin/dialects/ilist/runtime.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,22 @@
33
from dataclasses import dataclass
44
from collections.abc import Sequence
55

6+
from kirin import ir, types
7+
from kirin.print.printer import Printer
8+
69
T = TypeVar("T")
710
L = TypeVar("L")
811

912

1013
@dataclass
11-
class IList(Generic[T, L]):
14+
class IList(ir.Data[Sequence[T]], Generic[T, L]):
1215
"""A simple immutable list."""
1316

1417
data: Sequence[T]
18+
elem: types.TypeAttribute = types.Any
19+
20+
def __post_init__(self):
21+
self.type = types.Generic(IList, self.elem, types.Literal(len(self.data)))
1522

1623
def __hash__(self) -> int:
1724
return id(self) # do not hash the data
@@ -26,7 +33,16 @@ def __add__(self, other: "IList[T, Any]") -> "IList[T, Any]": ...
2633
def __add__(self, other: list[T]) -> "IList[T, Any]": ...
2734

2835
def __add__(self, other):
29-
return IList(self.data + other)
36+
if isinstance(other, list):
37+
return IList(list(self.data) + other, elem=self.elem)
38+
elif isinstance(other, IList):
39+
return IList(
40+
list(self.data) + list(other.data), elem=self.elem.join(other.elem)
41+
)
42+
else:
43+
raise TypeError(
44+
f"unsupported operand type(s) for +: 'IList' and '{type(other)}'"
45+
)
3046

3147
@overload
3248
def __radd__(self, other: "IList[T, Any]") -> "IList[T, Any]": ...
@@ -61,3 +77,13 @@ def __eq__(self, value: object) -> bool:
6177
if not isinstance(value, IList):
6278
return False
6379
return self.data == value.data
80+
81+
def unwrap(self) -> "IList[T, L]":
82+
return self
83+
84+
def print_impl(self, printer: Printer) -> None:
85+
printer.plain_print("IList(")
86+
printer.print_seq(
87+
self.data, delim=", ", prefix="[", suffix="]", emit=printer.plain_print
88+
)
89+
printer.plain_print(")")

src/kirin/dialects/py/constant.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@
2727
class Constant(ir.Statement, Generic[T]):
2828
name = "constant"
2929
traits = frozenset({ir.Pure(), ir.ConstantLike(), ir.FromPythonCall()})
30-
value: T = info.attribute()
30+
value: ir.Data[T] = info.attribute()
3131
result: ir.ResultValue = info.result()
3232

3333
# NOTE: we allow py.Constant take data.PyAttr too
34-
def __init__(self, value: T | ir.PyAttr[T]) -> None:
35-
if not isinstance(value, ir.PyAttr):
34+
def __init__(self, value: T | ir.Data[T]) -> None:
35+
if not isinstance(value, ir.Data):
3636
value = ir.PyAttr(value)
3737
super().__init__(
3838
attributes={"value": value},
@@ -68,12 +68,12 @@ class Concrete(interp.MethodTable):
6868

6969
@interp.impl(Constant)
7070
def constant(self, interp, frame: interp.Frame, stmt: Constant):
71-
return (stmt.value,)
71+
return (stmt.value.unwrap(),)
7272

7373

7474
@dialect.register(key="emit.julia")
7575
class JuliaTable(interp.MethodTable):
7676

7777
@interp.impl(Constant)
7878
def emit_Constant(self, emit: EmitJulia, frame: EmitStrFrame, stmt: Constant):
79-
return (emit.emit_attribute(ir.PyAttr(stmt.value)),)
79+
return (emit.emit_attribute(stmt.value),)

src/kirin/ir/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,4 @@
4545
from kirin.ir.dialect import Dialect as Dialect
4646
from kirin.ir.attrs.py import PyAttr as PyAttr
4747
from kirin.ir.attrs.abc import Attribute as Attribute, AttributeMeta as AttributeMeta
48+
from kirin.ir.attrs.data import Data as Data

src/kirin/ir/attrs/data.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from abc import abstractmethod
2+
from typing import Generic, TypeVar
3+
from dataclasses import field, dataclass
4+
5+
from .abc import Attribute
6+
from .types import TypeAttribute
7+
8+
T = TypeVar("T", covariant=True)
9+
10+
11+
@dataclass(eq=False)
12+
class Data(Attribute, Generic[T]):
13+
"""Base class for data attributes.
14+
15+
Data attributes are compile-time constants that can be used to
16+
represent runtime data inside the IR.
17+
18+
This class is meant to be subclassed by specific data attributes.
19+
It provides a `type` attribute that should be set to the type of
20+
the data.
21+
"""
22+
23+
type: TypeAttribute = field(init=False, repr=False)
24+
25+
@abstractmethod
26+
def unwrap(self) -> T:
27+
"""Returns the underlying data value."""
28+
...

src/kirin/ir/attrs/py.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
from typing import Generic, TypeVar
1+
from typing import TypeVar
22
from dataclasses import dataclass
33

44
from kirin.print import Printer
55

6-
from .abc import Attribute
6+
from .data import Data
77
from .types import PyClass, TypeAttribute
88

99
T = TypeVar("T")
1010

1111

1212
@dataclass
13-
class PyAttr(Generic[T], Attribute):
13+
class PyAttr(Data[T]):
1414
"""Python attribute for compile-time values.
1515
This is a generic attribute that holds a Python value.
1616
@@ -25,7 +25,6 @@ class PyAttr(Generic[T], Attribute):
2525

2626
name = "PyAttr"
2727
data: T
28-
type: TypeAttribute
2928

3029
def __init__(self, data: T, pytype: TypeAttribute | None = None):
3130
self.data = data
@@ -43,3 +42,6 @@ def print_impl(self, printer: Printer) -> None:
4342
with printer.rich(style="comment"):
4443
printer.plain_print(" : ")
4544
printer.print(self.type)
45+
46+
def unwrap(self) -> T:
47+
return self.data

src/kirin/ir/attrs/types.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -179,16 +179,16 @@ def __init__(self, *args, **kwargs):
179179
super(LiteralMeta, self).__init__(*args, **kwargs)
180180
self._cache = {}
181181

182-
def __call__(self, data):
183-
if isinstance(data, Attribute):
184-
return data
182+
def __call__(self, data, datatype=None):
183+
if isinstance(data, TypeAttribute):
184+
return data # already a type
185185
elif not isinstance(data, Hashable):
186-
return PyClass(type(data))
187-
elif data in self._cache:
188-
return self._cache[data]
186+
raise ValueError("Literal data must be hashable")
187+
elif (data, datatype) in self._cache:
188+
return self._cache[(data, datatype)]
189189

190-
instance = super(LiteralMeta, self).__call__(data)
191-
self._cache[data] = instance
190+
instance = super(LiteralMeta, self).__call__(data, datatype)
191+
self._cache[(data, datatype)] = instance
192192
return instance
193193

194194

@@ -200,6 +200,14 @@ def __call__(self, data):
200200
class Literal(TypeAttribute, typing.Generic[LiteralType], metaclass=LiteralMeta):
201201
name = "Literal"
202202
data: LiteralType
203+
type: TypeAttribute
204+
"""type of the literal, this is useful when the Python type of
205+
data does not represent the type in IR, e.g Literal(1, types.Int32)
206+
"""
207+
208+
def __init__(self, data: LiteralType, datatype: TypeAttribute | None = None):
209+
self.data = data
210+
self.type = datatype or PyClass(type(data))
203211

204212
def is_equal(self, other: TypeAttribute) -> bool:
205213
return self is other
@@ -211,16 +219,16 @@ def is_subseteq_Union(self, other: "Union") -> bool:
211219
return any(self.is_subseteq(t) for t in other.types)
212220

213221
def is_subseteq_Literal(self, other: "Literal") -> bool:
214-
return self.data == other.data
222+
return self.data == other.data and self.type.is_subseteq(other.type)
215223

216224
def is_subseteq_fallback(self, other: TypeAttribute) -> bool:
217-
return PyClass(type(self.data)).is_subseteq(other)
225+
return self.type.is_subseteq(other)
218226

219227
def __hash__(self) -> int:
220228
return hash((Literal, self.data))
221229

222230
def print_impl(self, printer: Printer) -> None:
223-
printer.plain_print(repr(self.data))
231+
printer.plain_print("Literal(", repr(self.data), ",", self.type, ")")
224232

225233

226234
@typing.final

src/kirin/ir/nodes/stmt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ def print_impl(self, printer: Printer) -> None:
611611
if isinstance(values, SSAValue):
612612
printer.print(values)
613613
else:
614-
printer.print_seq(values, delim=", ")
614+
printer.print_seq(values, delim=", ", prefix="(", suffix=")")
615615

616616
if idx < len(self._name_args_slice) - 1:
617617
printer.plain_print(", ")

0 commit comments

Comments
 (0)