Skip to content

Commit a441927

Browse files
committed
Flatten hierarchy of primops
1 parent d3e490c commit a441927

File tree

10 files changed

+106
-313
lines changed

10 files changed

+106
-313
lines changed

src/lib/Builder.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ emit e = case e of
5959
{-# INLINE emit #-}
6060

6161
idExpr :: Atom n -> Expr n
62-
idExpr x = PrimOp (getType x) (UnOp Identity x)
62+
idExpr x = PrimOp (getType x) Identity [x]
6363

6464
declsToExpr :: RNest Decl n l -> Atom l -> Expr n
6565
declsToExpr (RNest ds (Let b e)) (Var v _) | v == binderName b = maybeBlock ds e

src/lib/ConcreteSyntax.hs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -460,11 +460,11 @@ leafGroup = leafGroup' >>= appendPostfixGroups
460460
'[' -> cBrackets
461461
'\"' -> toCLeaf CString <$> strLit
462462
'\'' -> toCLeaf CChar <$> charLit
463-
'%' -> do
464-
WithSrc sid name <- primName
465-
case strToPrimName name of
466-
Just prim -> WithSrcs sid [] <$> CPrim prim <$> argList
467-
Nothing -> fail $ "Unrecognized primitive: " ++ bs2str name
463+
'%' -> undefined
464+
-- WithSrc sid name <- primName
465+
-- case strToPrimName name of
466+
-- Just prim -> WithSrcs sid [] <$> CPrim prim <$> argList
467+
-- Nothing -> fail $ "Unrecognized primitive: " ++ bs2str name
468468
_ | isDigit next -> ( toCLeaf CNat <$> natLit
469469
<|> toCLeaf CFloat <$> doubleLit)
470470
'\\' -> withSrcs (cNullaryLam <|> cLam)

src/lib/Inference.hs

Lines changed: 3 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -541,9 +541,7 @@ bottomUpExplicit (WithSrcE sid expr) = case expr of
541541
-- UPrim UProjNewtype [x] -> do
542542
-- x' <- bottomUp x >>= unwrapNewtype
543543
-- return $ SigmaAtom Nothing x'
544-
UPrim prim xs -> do
545-
xs' <- mapM bottomUp xs
546-
liftM (SigmaAtom Nothing) $ matchPrimApp prim xs'
544+
UPrim prim xs -> throw sid $ MiscTypeErr "primitive ops must have result type annotations"
547545
-- UNatLit l -> liftM (SigmaAtom Nothing) $ fromNatLit sid l NatTy
548546
-- UIntLit l -> liftM (SigmaAtom Nothing) $ fromIntLit sid l (BaseTy $ Scalar Int32Type)
549547
UFloatLit x -> return $ SigmaAtom Nothing $ CLit $ Float32Lit $ realToFrac x
@@ -988,66 +986,8 @@ checkOrInferApp appSrcId funSrcId f' posArgs namedArgs reqTy = undefined
988986
-- when (not $ null unrecognizedNames) do
989987
-- throw sid $ UnrecognizedOptionalArgs (map pprint unrecognizedNames) (map pprint acceptedNames)
990988

991-
matchPrimApp :: PrimName -> [CAtom o] -> InfererM i o (CExpr o)
992-
matchPrimApp = \case
993-
-- UNat -> \case ~[] -> return $ toAtom $ NewtypeTyCon Nat
994-
-- UFin -> \case ~[n] -> return $ toAtom $ NewtypeTyCon (Fin n)
995-
-- UBaseType b -> \case ~[] -> return $ toAtomR $ BaseType b
996-
-- UNatCon -> \case ~[x] -> return $ toAtom $ NewtypeCon NatCon x
997-
-- UPrimTC tc -> case tc of
998-
-- S.ProdType -> \ts -> return $ toAtom $ ProdType $ map (fromJust . toMaybeType) ts
999-
-- S.SumType -> \ts -> return $ toAtom $ SumType $ map (fromJust . toMaybeType) ts
1000-
-- S.RefType -> \case ~[h, a] -> undefined -- return $ toAtom $ RefType h (fromJust $ toMaybeType a)
1001-
-- S.TypeKind -> \case ~[] -> return $ toAtom $ Kind $ TypeKind
1002-
-- UCon con -> case con of
1003-
-- S.ProdCon -> \xs -> return $ toAtom $ ProdCon xs
1004-
-- S.SumCon _ -> error "not supported"
1005-
MiscOp op -> \xs -> matchMiscOp op xs
1006-
-- MemOp op -> \xs -> CPrimOp <$> MemOp <$> matchGenericOp op xs
1007-
UnOp op () -> \case ~[x] -> return $ CPrimOp (getCType x) $ UnOp op x
1008-
BinOp op () () -> \case ~[x, y] -> return $ CPrimOp (getCType x) $ BinOp op x y
1009-
-- UMGet -> \case ~[r] -> emitRefOp r MGet
1010-
-- UMPut -> \case ~[r, x] -> emitRefOp r $ MPut x
1011-
-- UIndexRef -> \case ~[r, i] -> indexRef r i
1012-
-- UApplyMethod i -> \case ~(d:args) -> emit =<< mkApplyMethod (fromJust $ toMaybeDict d) i args
1013-
-- ULinearize -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Linearize f' x
1014-
-- UTranspose -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Transpose f' x
1015-
-- p -> \case xs -> throwInternal $ "Bad primitive application: " ++ show (p, xs)
1016-
1017-
cUnitTy :: CType n
1018-
cUnitTy = CTyCon $ CProdType []
1019-
1020-
matchMiscOp :: MiscOp () -> [CAtom o] -> InfererM i o (CExpr o)
1021-
matchMiscOp = \case
1022-
DebugPrintInt () -> \case ~[x] -> return $ CPrimOp cUnitTy $ MiscOp $ DebugPrintInt x
1023-
-- where
1024-
-- -- lam2 :: Fallible m => CAtom n -> m (LamExpr n)
1025-
-- -- lam2 x = do
1026-
-- -- ExplicitCoreLam (BinaryNest b1 b2) body <- return x
1027-
-- -- return $ BinaryLamExpr b1 b2 body
1028-
1029-
-- -- lam1 :: Fallible m => CAtom n -> m (LamExpr n)
1030-
-- -- lam1 x = do
1031-
-- -- ExplicitCoreLam (UnaryNest b) body <- return x
1032-
-- -- return $ UnaryLamExpr b body
1033-
1034-
-- matchGenericOp :: Functor op => op () -> [CAtom n] -> InfererM i n (op (CAtom n))
1035-
-- matchGenericOp op xs = undefined
1036-
-- do
1037-
-- (tyArgs, dataArgs) <- partitionEithers <$> forM xs \x -> do
1038-
-- case getType x of
1039-
-- TyCon (Kind TypeKind) -> do
1040-
-- Just x' <- return $ toMaybeType x
1041-
-- return $ Left x'
1042-
-- _ -> return $ Right x
1043-
-- let tyArgs' = case tyArgs of
1044-
-- [] -> Nothing
1045-
-- [t] -> Just t
1046-
-- _ -> error "Expected at most one type arg"
1047-
-- return $ fromJust $ toOp $ GenericOpRep op tyArgs' dataArgs
1048-
1049-
-- pattern ExplicitCoreLam :: Nest CBinder n l -> CExpr l -> CAtom n
1050-
-- pattern ExplicitCoreLam bs body <- Con (Lam (CoreLamExpr _ (LamExpr bs body)))
989+
matchPrimApp :: PrimOp -> [CAtom o] -> InfererM i o (CExpr o)
990+
matchPrimApp = undefined
1051991

1052992
-- -- === n-ary applications ===
1053993

src/lib/QueryTypePure.hs

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,6 @@ litType v = case v of
2828
Float32Lit _ -> Scalar Float32Type
2929
PtrLit ty _ -> PtrType ty
3030

31-
typeBinOp :: BinOp -> BaseType -> BaseType
32-
typeBinOp binop xTy = case binop of
33-
IAdd -> xTy; ISub -> xTy
34-
IMul -> xTy; IDiv -> xTy
35-
IRem -> xTy;
36-
ICmp _ -> Scalar Word8Type
37-
FAdd -> xTy; FSub -> xTy
38-
FMul -> xTy; FDiv -> xTy;
39-
FPow -> xTy
40-
FCmp _ -> Scalar Word8Type
41-
BAnd -> xTy; BOr -> xTy
42-
BXor -> xTy
43-
BShL -> xTy; BShR -> xTy
44-
45-
typeUnOp :: UnOp -> BaseType -> BaseType
46-
typeUnOp = const id -- All unary ops preserve the type of the input
47-
4831
getKind :: Type n -> Kind
4932
getKind = undefined
5033
-- getKind = \case
@@ -105,7 +88,7 @@ instance HasType Expr where
10588
getType = \case
10689
Block ty _ -> ty
10790
TopApp ty _ _ -> ty
108-
PrimOp ty _ -> ty
91+
PrimOp ty _ _ -> ty
10992
Case ty _ _ -> ty
11093
For _ _ -> undefined
11194
While _ -> undefined
@@ -130,7 +113,7 @@ instance HasCType CExpr where
130113
CBlock ty _ -> ty
131114
CVar _ ty -> ty
132115
CLit l -> CTyCon $ CBaseType $ litType l
133-
CPrimOp ty _ -> ty
116+
CPrimOp ty _ _ -> ty
134117
-- CTyCon (CTyCon n)
135118
-- Lam (CoreLamExpr n)
136119
-- NewtypeCon (NewtypeCon n) (CExpr n)

src/lib/Simplify.hs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,10 +373,10 @@ simplifyLam (CoreLamExpr _ ab) = go ab
373373
simplifyExpr :: Emits o => CExpr i -> SimplifyM i o (Atom o)
374374
simplifyExpr = \case
375375
CLit val -> return $ Lit val
376-
CPrimOp ty op -> do
377-
op' <- mapM simplifyExpr op
376+
CPrimOp ty op xs -> do
377+
xs' <- mapM simplifyExpr xs
378378
ty <- simplifyType ty
379-
emit $ PrimOp ty op'
379+
emit $ PrimOp ty op xs'
380380
e -> error $ show e
381381

382382
simplifyType :: CType i -> SimplifyM i o (Type o)

src/lib/ToLLVM.hs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import Types.Simple
1818
import Types.Primitives
1919
import PPrint
2020

21+
import Err
2122
import Debug.Trace
2223
import QueryTypePure
2324
import Util
@@ -52,13 +53,13 @@ data TranslateState i = TranslateState
5253
type TranslateSubst i = Subst (LiftE L.Operand) i VoidS
5354

5455
newtype TranslateM (i::S) (a:: *) =
55-
TranslateM { inner :: State (TranslateState i) a }
56-
deriving (Functor, Applicative, Monad)
56+
TranslateM { inner :: StateT (TranslateState i) Except a }
57+
deriving (Functor, Applicative, Monad, MonadFail)
5758

5859
runTranslateM :: Monad m => TranslateM VoidS a -> m (TranslateState VoidS)
5960
runTranslateM cont = do
6061
let initState = TranslateState [] [] (L.Name "__entry__") 0 voidSubst
61-
return $ execState cont.inner initState
62+
return $ ignoreExcept $ execStateT cont.inner initState
6263

6364
emitInstr :: L.Type -> L.Instruction -> TranslateM i L.Operand
6465
emitInstr resultTy instr = do
@@ -76,7 +77,7 @@ extendEnv :: NameBinder i i' -> L.Operand -> TranslateM i' a -> TranslateM i a
7677
extendEnv b x cont = TranslateM do
7778
prevState <- get
7879
let subst' = prevState.subst <>> (b @> LiftE x)
79-
let (ans, newState) = runState (cont.inner) $ updateSubst prevState subst'
80+
let (ans, newState) = ignoreExcept $ runStateT (cont.inner) $ updateSubst prevState subst'
8081
put $ updateSubst newState prevState.subst
8182
return ans
8283

@@ -117,10 +118,10 @@ toLLVMEntryFun' (TopLamExpr (Abs Empty body)) = do
117118
trExpr :: Expr i -> TranslateM i L.Operand
118119
trExpr = \case
119120
Block resultTy block -> trBlock block
120-
PrimOp resultTy op -> do
121+
PrimOp resultTy op xs -> do
121122
resultTy' <- trType resultTy
122-
op' <- forM op trAtom
123-
trPrimOp resultTy' op'
123+
xs' <- forM xs trAtom
124+
trPrimOp resultTy' op xs'
124125

125126
trType :: Type i -> TranslateM i L.Type
126127
trType = \case
@@ -142,11 +143,12 @@ trBlock (Abs decls result) = case decls of
142143
val <- trExpr expr
143144
extendEnv b val $ trBlock $ Abs rest result
144145

145-
trPrimOp :: L.Type -> PrimOp L.Operand -> TranslateM i L.Operand
146-
trPrimOp resultTy op = case op of
147-
BinOp b x y -> case b of
148-
FAdd -> emitInstr resultTy $ L.FAdd x y
149-
MiscOp op' -> case op' of
150-
DebugPrintInt x -> do
151-
emitStatement $ L.Call floatTy "printfloat" [x]
152-
return unitOperand
146+
trPrimOp :: L.Type -> PrimOp -> [L.Operand] -> TranslateM i L.Operand
147+
trPrimOp resultTy op xs = case op of
148+
FAdd -> do
149+
[x, y] <- return xs
150+
emitInstr resultTy $ L.FAdd x y
151+
DebugPrintInt -> do
152+
[x] <- return xs
153+
emitStatement $ L.Call floatTy "printfloat" [x]
154+
return unitOperand

src/lib/Types/Complicated.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ data CExpr (n::S) =
2828
CBlock (CType n) (CBlock n)
2929
| CVar (Name n) (CType n)
3030
| CLit LitVal
31-
| CPrimOp (CType n) (PrimOp (CExpr n))
31+
| CPrimOp (CType n) PrimOp [CExpr n]
3232
| CTyCon (CTyCon n)
3333
| Lam (CoreLamExpr n)
3434
| NewtypeCon (NewtypeCon n) (CExpr n)
@@ -179,7 +179,7 @@ instance Pretty (CExpr n) where
179179
CBlock _ b -> pr b
180180
CVar v _ -> pr v
181181
CLit l -> pr l
182-
CPrimOp _ op -> pr op
182+
CPrimOp _ op args -> app (pr op) (map pr args)
183183
CTyCon _ -> undefined
184184
Lam _ -> undefined
185185
NewtypeCon _ _ -> undefined

0 commit comments

Comments
 (0)