Skip to content

Commit 7cb8cf3

Browse files
authored
Merge pull request #5778 from unisonweb/fix/opt-captures
Fix some variable naming problems with new optimizations
2 parents c42cbb2 + d686343 commit 7cb8cf3

File tree

6 files changed

+99
-37
lines changed

6 files changed

+99
-37
lines changed

unison-core/src/Unison/ABT/Normalized.hs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,16 @@ freshenBinder fvs rn0@(RN cf rn) u = (rn', u')
230230
}
231231
| otherwise = rn0
232232

233+
-- Simultaneously freshens some binders. This ensures not just that
234+
-- they're fresh with respect to the given set of variables, but
235+
-- mutually distinct.
233236
freshenBinders ::
234237
(Var v) => Set v -> Renaming v -> [v] -> (Renaming v, [v])
235-
freshenBinders fvs = mapAccumL (freshenBinder fvs)
238+
freshenBinders fvs rn0 = first snd . mapAccumL f (Set.empty, rn0)
239+
where
240+
f (avoid, rn) u
241+
| (rn, v) <- freshenBinder (Set.union avoid fvs) rn u =
242+
((Set.insert v avoid, rn), v)
236243

237244
-- Simultaneous variable renaming and freshening implementation.
238245
--
@@ -254,10 +261,10 @@ renamesAndFreshen0 ::
254261
Term f v ->
255262
Term f v
256263
renamesAndFreshen0 rn0 tm = case tm of
257-
TAbs u body
258-
| (rn, u') <- freshenBinder (freeVars body) rn u,
259-
u /= u' || not (isEmptyRenaming rn) ->
260-
TAbs u' (renamesAndFreshen0 rn body)
264+
TAbs u (TAbss us body)
265+
| (rn, vs) <- freshenBinders (freeVars body) rn (u : us),
266+
u : us /= vs || not (isEmptyRenaming rn) ->
267+
TAbss vs (renamesAndFreshen0 rn body)
261268
TTm body
262269
| not $ isEmptyRenaming rn ->
263270
TTm $ bimap (renameVar rn) (renamesAndFreshen0 rn) body

unison-runtime/package.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ flags:
1717
manual: true
1818
default: false
1919

20+
# Run code through the serializer to test code loading as if it were
21+
# remote.
22+
codeserialchecks:
23+
manual: true
24+
default: false
25+
2026
# Dumps core for debugging to unison-runtime/.stack-work/dist/<arch>/ghc-x.y.z/build/
2127
dumpcore:
2228
manual: true
@@ -32,6 +38,8 @@ when:
3238
cpp-options: -DOPT_CHECK
3339
dependencies:
3440
- inspection-testing
41+
- condition: flag(codeserialchecks)
42+
ghc-options: -DCODE_SERIAL_CHECK
3543
- condition: flag(dumpcore)
3644
ghc-options: -ddump-simpl -ddump-stg-final -ddump-to-file -dsuppress-coercions -dsuppress-idinfo -dsuppress-module-prefixes -ddump-str-signatures -ddump-simpl-stats # -dsuppress-type-applications -dsuppress-type-signatures
3745

unison-runtime/src/Unison/Runtime/ANF.hs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2538,10 +2538,10 @@ prettyBranches ind bs = case bs of
25382538
MatchText bs df ->
25392539
maybe id (\e -> prettyCase ind (showString "_") e id) df
25402540
. foldr (uncurry $ prettyCase ind . shows) id (Map.toList bs)
2541-
MatchData _ bs df ->
2541+
MatchData r bs df ->
25422542
maybe id (\e -> prettyCase ind (showString "_") e id) df
25432543
. foldr
2544-
(uncurry $ prettyCase ind . shows)
2544+
(uncurry $ prettyCase ind . prettyTag r)
25452545
id
25462546
(mapToList $ snd <$> bs)
25472547
MatchRequest bs df ->
@@ -2572,6 +2572,13 @@ prettyBranches ind bs = case bs of
25722572
. shows c
25732573
. showString ")"
25742574

2575+
prettyTag r c =
2576+
showString "CON("
2577+
. showsShort r
2578+
. showString ","
2579+
. shows c
2580+
. showString ")"
2581+
25752582
prettyCase :: (Var v) => Int -> ShowS -> ANormal v -> ShowS -> ShowS
25762583
prettyCase ind sc (ABTN.TAbss vs e) r =
25772584
showString "\n"

unison-runtime/src/Unison/Runtime/ANF/Optimize.hs

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -165,46 +165,52 @@ whenChanged f act = do
165165

166166
descend ::
167167
(Memo m, Var v) =>
168-
(Bool -> ANormal v -> m (ANormal v)) ->
168+
(Bool -> Set v -> ANormal v -> m (ANormal v)) ->
169169
Bool ->
170+
Set v ->
170171
ANormal v ->
171172
m (ANormal v)
172-
descend rec tail tm = memo tm $ case tm of
173+
descend rec tail bound tm = memo tm $ case tm of
173174
TLets d vs ccs bn bd ->
174-
TLets d vs ccs <$> rec False bn <*> rec tail bd
175+
TLets d vs ccs <$> rec False bound bn <*> rec tail bnd bd
176+
where
177+
bnd = Set.union (Set.fromList vs) bound
175178
TName v f vs bd ->
176-
TName v f vs <$> rec tail bd
179+
TName v f vs <$> rec tail (Set.insert v bound) bd
177180
TMatch v bs ->
178-
TMatch v <$> traverse (rec tail) bs
181+
TMatch v <$> traverse (rec tail bound) bs
179182
TShift r v bd ->
180-
TShift r v <$> rec tail bd
183+
TShift r v <$> rec tail (Set.insert v bound) bd
181184
THnd rs hn ha bd ->
182-
THnd rs hn ha <$> rec tail bd
185+
THnd rs hn ha <$> rec tail bound bd
183186
TLocal v bd ->
184-
TLocal v <$> rec tail bd
187+
TLocal v <$> rec tail bound bd
185188
ABTN.TAbs v (ABTN.TAbss vs bd) ->
186-
ABTN.TAbss (v : vs) <$> rec tail bd
189+
ABTN.TAbss (v : vs) <$> rec tail bnd bd
190+
where
191+
bnd = Set.union (Set.fromList $ v : vs) bound
187192
_ -> pure tm
188193

189194
-- Rewrites a term from the top down, first applying the step
190195
-- transform given, then descending to children.
191196
rewriteDown ::
192197
(Memo m, Var v) =>
193-
(Bool -> ANormal v -> m (ANormal v)) ->
198+
(Bool -> Set v -> ANormal v -> m (ANormal v)) ->
194199
ANormal v ->
195200
m (ANormal v)
196-
rewriteDown step = go True
201+
rewriteDown step = go True Set.empty
197202
where
198-
go tail tm = step tail tm >>= descend go tail
203+
go tail bound tm = step tail bound tm >>= descend go tail bound
199204

200205
rewriteUp ::
201206
(Memo m, Var v) =>
202-
(Bool -> ANormal v -> m (ANormal v)) ->
207+
(Bool -> Set v -> ANormal v -> m (ANormal v)) ->
203208
ANormal v ->
204209
m (ANormal v)
205-
rewriteUp step = go True
210+
rewriteUp step = go True Set.empty
206211
where
207-
go tail tm = memo tm (descend go tail tm) >>= step tail
212+
go tail bound tm =
213+
memo tm (descend go tail bound tm) >>= step tail bound
208214

209215
-- Performs inlining on a `SuperGroup` using the inlining information
210216
-- in the map. The map can be created from typical `SuperGroup` data
@@ -225,15 +231,15 @@ inline avoid (arities, inls) n0 = memo n0 $ go (30 :: Int) n0
225231
| n <= 0 = pure tm
226232
| otherwise = rewriteUp (step n) tm
227233

228-
step n tail (TApp (FComb r) args)
229-
| Just new <- findInline tail r args =
234+
step n tail bound (TApp (FComb r) args)
235+
| Just new <- findInline tail bound r args =
230236
dirty *> go (n - 1) new
231-
step _ _tail tm = pure tm
237+
step _ _tail _bound tm = pure tm
232238

233-
findInline tail r args = do
239+
findInline tail bound r args = do
234240
info <- Map.lookup r inls
235241
arity <- Map.lookup r arities
236-
tweak tail args arity info
242+
tweak tail bound args arity info
237243

238244
don'tInline Don'tInl _ = True
239245
don'tInline TailInl isTail = not isTail
@@ -245,17 +251,17 @@ inline avoid (arities, inls) n0 = memo n0 $ go (30 :: Int) n0
245251
-- multiple inlining steps, so we freshen anything else we inline
246252
-- to not be capable of capturing the variables from the entry
247253
-- code.
248-
tweak isTail args arity (InlInfo clazz (ABTN.TAbss vs body))
254+
tweak isTail bound args arity (InlInfo clazz (ABTN.TAbss vs body))
249255
| don'tInline clazz isTail = Nothing
250256
-- exactly saturated
251257
| length args == arity,
252258
rn <- Map.fromList (zip vs args) =
253-
Just $ ABTN.renamesAvoiding avoid rn body
259+
Just $ ABTN.renamesAvoiding (avoid `Set.union` bound) rn body
254260
-- oversaturated, only makes sense if body is a call
255261
| length args > arity,
256262
(pre, post) <- splitAt arity args,
257263
rn <- Map.fromList (zip vs pre),
258-
TApp f pre <- ABTN.renamesAvoiding avoid rn body =
264+
TApp f pre <- ABTN.renamesAvoiding (avoid `Set.union` bound) rn body =
259265
Just $ TApp f (pre ++ post)
260266
| otherwise = Nothing
261267

@@ -276,7 +282,7 @@ peephole arities affine n0 = memo n0 $ go (30 :: Int) n0
276282
where
277283
go 0 = pure
278284
go n =
279-
whenChanged (go $ n - 1) . rewriteDown \tail -> \case
285+
whenChanged (go $ n - 1) . rewriteDown \tail _bound -> \case
280286
-- eliminate `v = u` bindings in affine contexts
281287
TLet _ v _ (TVar u) bd
282288
| affine -> ABTN.rename v u bd <$ dirty
@@ -338,9 +344,11 @@ optSuper ::
338344
Bool ->
339345
SuperNormal v ->
340346
m (SuperNormal v)
341-
optSuper opts avoid affine sn@(Lambda ccs (ABTN.TAbss vs bd)) =
347+
optSuper opts avoid0 affine sn@(Lambda ccs (ABTN.TAbss vs bd)) =
342348
memo sn $
343349
Lambda ccs . ABTN.TAbss vs <$> optNormal opts avoid affine bd
350+
where
351+
avoid = Set.union (Set.fromList vs) avoid0
344352

345353
-- Optimizes a single group
346354
optGroup ::
@@ -716,6 +724,7 @@ translateHandlerMatch ::
716724
(Var v) => OptInfos v -> v -> v -> SuperNormal v -> Maybe (SuperNormal v)
717725
translateHandlerMatch opts self ah (Lambda ccs (ABTN.TAbss args body))
718726
| v : vs <- shiftArgs args,
727+
bound <- Set.fromList (self : args),
719728
TMatch u branches <- body,
720729
u == v,
721730
MatchRequest cs df <- branches,
@@ -725,7 +734,7 @@ translateHandlerMatch opts self ah (Lambda ccs (ABTN.TAbss args body))
725734
. ABTN.TAbss args
726735
. TMatch u
727736
. flip MatchRequest df
728-
<$> traverse3 (affineHandlerCase opts self vs ah) cs
737+
<$> traverse3 (affineHandlerCase opts self bound vs ah) cs
729738
| otherwise = Nothing
730739
where
731740
ar = freshAff 2
@@ -753,12 +762,12 @@ augmentHandlerEntry thunk0 mv0 ah body
753762
-- Recognizes an affine handler case, yielding a translated efficient
754763
-- version if it is one.
755764
affineHandlerCase ::
756-
(Var v) => OptInfos v -> v -> [v] -> v -> ANormal v -> Maybe (ANormal v)
757-
affineHandlerCase opts self vs rec br
765+
(Var v) => OptInfos v -> v -> Set v -> [v] -> v -> ANormal v -> Maybe (ANormal v)
766+
affineHandlerCase opts self bound vs rec br
758767
| ABTN.TAbss us body <- br,
759768
TShift _ kf0 body <- body,
760769
TName kf (Left (Builtin "jumpCont")) [kf1] body <- body,
761-
bound <- Set.fromList (kf0 : kf : us),
770+
bound <- Set.union bound (Set.fromList (kf0 : kf : us)),
762771
kf0 == kf1 =
763772
ABTN.TAbss us
764773
<$> affinePreBranch opts self bound vs rec ar kf body

unison-runtime/src/Unison/Runtime/Machine.hs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ import Unison.Runtime.ANF as ANF
6565
)
6666
import Unison.Runtime.ANF qualified as ANF
6767
import Unison.Runtime.ANF.Optimize qualified as ANF
68+
#ifdef CODE_SERIAL_CHECK
69+
import Unison.Runtime.ANF.Serialize (serializeCode, deserializeCode)
70+
#endif
6871
import Unison.Runtime.Array as PA
6972
import Unison.Runtime.Builtin hiding (unitValue)
7073
import Unison.Runtime.Exception hiding (die)
@@ -1233,13 +1236,35 @@ addRefs vfrsh vfrom vto rs = do
12331236
evaluateSTM :: a -> STM a
12341237
evaluateSTM x = unsafeIOToSTM (evaluate x)
12351238

1239+
-- If this flag is set, all code is run through serialization before
1240+
-- loading. This renames variables, and it's possible a problem would
1241+
-- only be visible with the renamed variables. This allows testing
1242+
-- these cases just by rebuilding ucm, rather than actually concocting
1243+
-- a test that involves remote code loading.
1244+
#if defined(CODE_SERIAL_CHECK)
1245+
1246+
normalizeCode :: Code -> Code
1247+
normalizeCode co = case deserializeCode (serializeCode False co) of
1248+
Left _ -> error "normalizeCode: impossible"
1249+
Right co -> co
1250+
1251+
normalizeCodes :: [(Reference, Code)] -> [(Reference, Code)]
1252+
normalizeCodes = fmap $ second normalizeCode
1253+
1254+
#else
1255+
1256+
normalizeCodes :: [(Reference, Code)] -> [(Reference, Code)]
1257+
normalizeCodes = id
1258+
1259+
#endif
1260+
12361261
cacheAdd0 ::
12371262
S.Set Reference ->
12381263
[(Reference, Code)] ->
12391264
[(Reference, Set Reference)] ->
12401265
CCache ->
12411266
IO ()
1242-
cacheAdd0 ntys0 termSuperGroups sands cc = do
1267+
cacheAdd0 ntys0 (normalizeCodes -> termSuperGroups) sands cc = do
12431268
let toAdd = M.fromList (termSuperGroups <&> second codeGroup)
12441269
(unresolvedCacheableCombs, unresolvedNonCacheableCombs) <- atomically $ do
12451270
have <- readTVar (intermed cc)

unison-runtime/unison-runtime.cabal

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ flag optchecks
2929
manual: True
3030
default: False
3131

32+
flag codeserialchecks
33+
manual: True
34+
default: False
35+
3236
flag stackchecks
3337
manual: True
3438
default: False
@@ -161,6 +165,8 @@ library
161165
cpp-options: -DOPT_CHECK
162166
build-depends:
163167
inspection-testing
168+
if flag(codeserialchecks)
169+
cpp-options: -DCODE_SERIAL_CHECK
164170
if flag(dumpcore)
165171
ghc-options: -ddump-simpl -ddump-stg-final -ddump-to-file -dsuppress-coercions -dsuppress-idinfo -dsuppress-module-prefixes -ddump-str-signatures -ddump-simpl-stats
166172

0 commit comments

Comments
 (0)