diff --git a/src/kirin/dialects/py/indexing.py b/src/kirin/dialects/py/indexing.py index a1d88243b..564f25915 100644 --- a/src/kirin/dialects/py/indexing.py +++ b/src/kirin/dialects/py/indexing.py @@ -98,7 +98,17 @@ class Concrete(interp.MethodTable): @interp.impl(GetItem) def getindex(self, interp, frame: interp.Frame, stmt: GetItem): - return (frame.get(stmt.obj)[frame.get(stmt.index)],) + from kirin.dialects.py.slice import SliceAttribute + + index = frame.get(stmt.index) + + # need to handle special case of slice attribute + if isinstance(index, SliceAttribute): + index_value = index.unwrap() + else: + index_value = index + + return (frame.get(stmt.obj)[index_value],) @dialect.register(key="typeinfer") @@ -208,7 +218,16 @@ def getitem( return (const.Unknown(),) if isinstance(obj, const.Value): - return (const.Value(obj.data[index.data]),) + from kirin.dialects.py.slice import SliceAttribute + + # need to handle special case of slice attribute + if isinstance(index.data, SliceAttribute): + index_value = index.data.unwrap() + else: + index_value = index.data + + return (const.Value(obj.data[index_value]),) + elif isinstance(obj, const.PartialTuple): obj = obj.data if isinstance(index.data, int) and 0 <= index.data < len(obj): diff --git a/src/kirin/dialects/py/slice.py b/src/kirin/dialects/py/slice.py index e7a2adde1..93d9e8936 100644 --- a/src/kirin/dialects/py/slice.py +++ b/src/kirin/dialects/py/slice.py @@ -13,6 +13,7 @@ from kirin import ir, types, interp, lowering from kirin.decl import info, statement +from kirin.print.printer import Printer from kirin.dialects.py.constant import Constant dialect = ir.Dialect("py.slice") @@ -62,6 +63,33 @@ def __init__( ) +@dataclass +class SliceAttribute(ir.Data[slice]): + + start: int | None + stop: int | None + step: int | None + + def __post_init__(self) -> None: + if self.start is None and self.step is None: + self.type = types.Slice[types.Literal(self.stop)] + else: + self.type = types.Slice3[ + types.Literal(self.start), + types.Literal(self.stop), + types.Literal(self.step), + ] + + def unwrap(self): + return slice(self.start, self.stop, self.step) + + def __hash__(self): + return hash((type(self), slice, self.start, self.stop, self.step)) + + def print_impl(self, printer: Printer) -> None: + return printer.plain_print(f"slice({self.start}, {self.stop}, {self.step})") + + @dialect.register class Concrete(interp.MethodTable): @@ -69,11 +97,11 @@ class Concrete(interp.MethodTable): def _slice(self, interp, frame: interp.Frame, stmt: Slice): start, stop, step = frame.get_values(stmt.args) if start is None and step is None: - return (slice(stop),) + return (SliceAttribute(None, stop, None),) elif step is None: - return (slice(start, stop),) + return (SliceAttribute(start, stop, None),) else: - return (slice(start, stop, step),) + return (SliceAttribute(start, stop, step),) @dialect.register diff --git a/src/kirin/types.py b/src/kirin/types.py index 2fbd7a104..1e4721353 100644 --- a/src/kirin/types.py +++ b/src/kirin/types.py @@ -25,6 +25,7 @@ NoneType = PyClass(type(None)) List = Generic(list, TypeVar("T")) Slice = Generic(slice, TypeVar("T")) +Slice3 = Generic(slice, TypeVar("T1"), TypeVar("T2"), TypeVar("T3")) Tuple = Generic(tuple, Vararg(TypeVar("T"))) Dict = Generic(dict, TypeVar("K"), TypeVar("V")) Set = Generic(set, TypeVar("T"))