@@ -242,6 +242,10 @@ public export
242
242
liftList : Foldable f => f TTImp -> TTImp
243
243
liftList = foldr (\ l, r => `(~ l :: ~ r)) `([])
244
244
245
+ public export
246
+ liftList' : Foldable f => f TTImp -> TTImp
247
+ liftList' = foldr (\ l, r => `(Prelude . (:: ) ~ l ~ r)) `(Prelude . Nil )
248
+
245
249
export
246
250
liftWeight1 : TTImp
247
251
liftWeight1 = `(Data . Nat1 . one)
@@ -795,27 +799,25 @@ getConsRecs = do
795
799
consRecs <- for niit. types $ \ targetType => logBounds {level= DetailedTrace } " consRec" [targetType] $ do
796
800
crsForTy <- for targetType. cons $ \ con => do
797
801
tuneImpl <- search $ ProbabilityTuning con. name
798
- w : Either Nat1 (TTImp -> TTImp, Maybe $ SortedSet $ Fin con.args.length) <- case isRecursive {containingType=Just targetType} con of
799
- -- ^^^^^^^^^^^^^^ ^^^^^ ^^^^^^^^^^^^^^^ <- set of directly recursive constructor arguments
800
- -- | \-- `Just` in this `Maybe` means that this constructor only contains direct recursion (not mutual one)
802
+ w : Either Nat1 (TTImp -> TTImp, SortedSet $ Fin con.args.length) <- case isRecursive {containingType=Just targetType} con of
803
+ -- ^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^ <- set of directly recursive constructor arguments
801
804
-- \------ Modifier of the standard weight expression
802
805
False => pure $ Left $ maybe one (\ impl => tuneWeight @{impl} one) tuneImpl
803
806
True => Right <$> do
804
807
fuelWeightExpr <- case tuneImpl of
805
808
Nothing => pure id
806
809
Just impl => quote (tuneWeight @{impl}) <&> \ wm, expr => workaroundFromNat $ wm `applySyn` expr
807
- let directlyRec = filter (not . null ) $ map (fromList . mapMaybe id ) $ for con. args. withIdx $ \ (idx, arg) => do
808
- case (== targetType. name) <$> getAppVar arg. type of
809
- Just True => Just $ Just idx
810
- _ => if hasNameInsideDeep targetType. name arg. type then Nothing else Just Nothing
811
- whenJust directlyRec $ \ ars =>
812
- logPoint {level= DetailedTrace } " consRec" [targetType, con] " - directly recursive, rec args: \{show $ finToNat <$> ars.asList}"
813
- pure (fuelWeightExpr, directlyRec)
810
+ let directlyRecArgs : List $ Fin con.args.length := flip mapMaybe con.args.withIdx $ \idxarg => do
811
+ argTy <- getAppVar (snd idxarg). type
812
+ whenT .| argTy == targetType. name .| fst idxarg
813
+ when (not $ null directlyRecArgs) $
814
+ logPoint {level= DetailedTrace } " consRec" [targetType, con] " - directly recursive args: \{show $ finToNat <$> directlyRecArgs}"
815
+ pure (fuelWeightExpr, fromList directlyRecArgs)
814
816
pure (con ** w)
815
817
-- determine if this type is a nat-or-list-like data, i.e. one which we can measure for the probability
816
- let weightable = flip all crsForTy $ \ case (_ ** Right (_ , Nothing )) => False ; _ => True
818
+ let weightable = flip any crsForTy $ \ case (_ ** Right (_ , dra )) => not $ null dra ; _ => False
817
819
pure (toMaybe weightable targetType, crsForTy)
818
- let 0 _ : SortedMap Name (Maybe TypeInfo, List (con : Con ** Either Nat1 (TTImp -> TTImp, Maybe $ SortedSet $ Fin con.args.length))) := consRecs
820
+ let 0 _ : SortedMap Name (Maybe TypeInfo, List (con : Con ** Either Nat1 (TTImp -> TTImp, SortedSet $ Fin con.args.length))) := consRecs
819
821
820
822
let weightableTyArgs : (ars : List Arg) -> SortedMap Nat (TypeInfo, Name) -- <- a map from Fin ars.length to a weightable type and its argument name
821
823
weightableTyArgs ars = fromList $ flip List . mapMaybe ars. withIdx $ \ (idx, ar) =>
@@ -853,7 +855,7 @@ getConsRecs = do
853
855
let funSig = export' weightFunName $ piAll `(Data . Nat1 . Nat1 ) $ map {piInfo : = ImplicitArg} ty.args ++ [inTyArg]
854
856
855
857
let wClauses = cons <&> \ (con ** e) => do
856
- let wArgs = either (const empty) (fromMaybe empty . snd ) e
858
+ let wArgs = either (const empty) snd e
857
859
let lhsArgs : List (_ , _ ) = mapI con.args $ \idx, arg => appArg arg <$> if contains idx wArgs
858
860
then let bindName = " arg^\ {show idx}" in (Just bindName, bindVar bindName)
859
861
else (Nothing , implicitTrue)
@@ -862,7 +864,7 @@ getConsRecs = do
862
864
patClause (var weightFunName .$ (reAppAny .| var con. name .| snd <$> lhsArgs)) $ case mapMaybe (map (UN . Basic ) . fst ) lhsArgs of
863
865
[] => liftWeight1
864
866
[x] => `(succ ~ (callSelfOn x))
865
- xs => `(succ $ foldMap @{% search} @{ Maximal } ~ (liftList $ xs <&> callSelfOn))
867
+ xs => `(succ $ Prelude . concat @{Maximum } ~ (liftList' $ xs <&> callSelfOn))
866
868
867
869
pure (funSig, def weightFunName wClauses)
868
870
0 commit comments