Skip to content

Commit 42747bc

Browse files
committed
Use QuantifiedConstraints instead of Eq1,Ord1,Show1
1 parent 94339b9 commit 42747bc

22 files changed

+72
-126
lines changed

README.md

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,22 @@ Our second step is to instance `Language` for our `SymExpr`
4747
represented in e-graph and on which equality saturation can be run:
4848

4949
```hs
50-
class (Analysis l, Traversable l, Ord1 l) => Language l
50+
type Language l = (Traversable l, a. Ord a => Ord (l a))
5151
```
5252

53-
To declare a `Language` we must write the "base functor" of `SymExpr`
54-
(i.e. use a type parameter where the recursion points used to be in the original `SymExpr`),
55-
then instance `Traversable`, `Ord1`, and write an `Analysis` instance for it (see next section).
53+
To declare a `Language` we must write the "base functor" of `SymExpr` (i.e. use
54+
a type parameter where the recursion points used to be in the original
55+
`SymExpr`), then instance `Traversable l`, ` a. Ord a => Ord (l a)` (we can do
56+
it automatically through deriving), and write an `Analysis` instance for it (see
57+
next section).
5658

5759
```hs
5860
data SymExpr a = Const Double
5961
| Symbol String
6062
| a :+: a
6163
| a :*: a
6264
| a :/: a
63-
deriving (Functor, Foldable, Traversable)
65+
deriving (Eq, Ord, Show, Functor, Foldable, Traversable)
6466
infix 6 :+:
6567
infix 7 :*:, :/:
6668
```
@@ -76,14 +78,6 @@ fixed-point form
7678
e1 :: Fix SymExpr
7779
e1 = Fix (Fix (Fix (Symbol "x") :*: Fix (Const 2)) :/: (Fix (Const 2))) -- (x*2)/2
7880
```
79-
80-
We've already automagically derived `Functor`, `Foldable` and `Traversable`
81-
instances, and can use the following template haskell functions from `derive-compat` to derive `Ord1`.
82-
```hs
83-
deriveEq1 ''SymExpr
84-
deriveOrd1 ''SymExpr
85-
```
86-
8781
Then, we define an `Analysis` for our `SymExpr`.
8882

8983
### Analysis

hegg.cabal

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ test-suite hegg-test
114114
build-depends: base,
115115
hegg,
116116
containers,
117-
deriving-compat >= 0.6 && < 0.7,
118117
tasty >= 1.4 && < 1.5,
119118
tasty-hunit >= 0.10 && < 0.11,
120119
tasty-quickcheck >= 0.10 && < 0.11
@@ -127,7 +126,6 @@ benchmark hegg-bench
127126
type: exitcode-stdio-1.0
128127
build-depends: base, hegg,
129128
containers,
130-
deriving-compat,
131129
tasty,
132130
tasty-hunit,
133131
tasty-quickcheck,

src/Data/Equality/Analysis.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
{-# LANGUAGE FlexibleContexts #-}
77
{-# LANGUAGE TypeFamilies #-}
88
{-# LANGUAGE MultiParamTypeClasses #-}
9+
{-# LANGUAGE ImpredicativeTypes #-}
910
{-|
1011
1112
E-class analysis, which allows the concise expression of a program analysis over

src/Data/Equality/Extraction.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{-# LANGUAGE ScopedTypeVariables #-}
22
{-# LANGUAGE ViewPatterns #-}
3+
{-# LANGUAGE MonoLocalBinds #-}
34
{-|
45
Given an e-graph representing expressions of our language, we might want to
56
extract, out of all expressions represented by some equivalence class, /the best/

src/Data/Equality/Graph/Classes.hs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ module Data.Equality.Graph.Classes
1010

1111
import qualified Data.Set as S
1212

13-
import Data.Functor.Classes
14-
1513
import Data.Equality.Graph.Classes.Id
1614
import Data.Equality.Graph.Nodes
1715

@@ -30,6 +28,6 @@ data EClass analysis_domain language = EClass
3028
, eClassParents :: !(SList (ClassId, ENode language)) -- ^ E-nodes which are parents of this e-class and their corresponding e-class ids.
3129
}
3230

33-
instance (Show a, Show1 l) => Show (EClass a l) where
31+
instance (Show a, Show (l ClassId)) => Show (EClass a l) where
3432
show (EClass a b d (SList c _)) = "Id: " <> show a <> "\nNodes: " <> show b <> "\nParents: " <> show c <> "\nData: " <> show d
3533

src/Data/Equality/Graph/Internal.hs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
-}
77
module Data.Equality.Graph.Internal where
88

9-
import Data.Functor.Classes
10-
119
import Data.Equality.Graph.ReprUnionFind
1210
import Data.Equality.Graph.Classes
1311
import Data.Equality.Graph.Nodes
@@ -31,7 +29,7 @@ type Memo l = NodeMap l ClassId
3129
-- | Maintained worklist of e-class ids that need to be “upward merged”
3230
type Worklist l = [(ClassId, ENode l)]
3331

34-
instance (Show a, Show1 l) => Show (EGraph a l) where
32+
instance (Show a, Show (l ClassId)) => Show (EGraph a l) where
3533
show (EGraph a b c d e) =
3634
"UnionFind: " <> show a <>
3735
"\n\nE-Classes: " <> show b <>

src/Data/Equality/Graph/Lens.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import Data.Equality.Graph.ReprUnionFind
2424

2525
-- | A 'Lens'' as defined in other lenses libraries
2626
type Lens' s a = forall f. Functor f => (a -> f a) -> (s -> f s)
27-
type Traversal s t a b = forall f. Applicative f => (a -> f b) -> s -> f t
27+
type Traversal s t a b = forall f. Applicative f => (a -> f b) -> (s -> f t)
2828

2929
-- outdated comment for "getClass":
3030
--

src/Data/Equality/Graph/Monad.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{-# LANGUAGE TupleSections #-}
2+
{-# LANGUAGE MonoLocalBinds #-}
23
{-|
34
Monadic interface to e-graph stateful computations
45
-}

src/Data/Equality/Graph/Nodes.hs

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
{-# LANGUAGE TypeFamilies #-}
2-
{-# LANGUAGE TupleSections #-}
32
{-# LANGUAGE FlexibleInstances #-}
4-
{-# LANGUAGE DeriveGeneric #-}
5-
{-# LANGUAGE ViewPatterns #-}
3+
{-# LANGUAGE UnicodeSyntax #-}
64
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
75
{-# LANGUAGE DeriveTraversable #-}
6+
{-# LANGUAGE RankNTypes #-}
7+
{-# LANGUAGE QuantifiedConstraints #-}
8+
{-# LANGUAGE StandaloneDeriving #-}
9+
{-# LANGUAGE UndecidableInstances #-}
810
{-|
911
1012
Module defining e-nodes ('ENode'), the e-node function symbol ('Operator'), and
@@ -13,7 +15,6 @@ mappings from e-nodes ('NodeMap').
1315
-}
1416
module Data.Equality.Graph.Nodes where
1517

16-
import Data.Functor.Classes
1718
import Data.Foldable
1819
import Data.Bifunctor
1920

@@ -34,6 +35,10 @@ import Data.Equality.Graph.Classes.Id
3435
-- parametrized over 'ClassId', i.e. all recursive fields are rather e-class ids.
3536
newtype ENode l = Node { unNode :: l ClassId }
3637

38+
deriving instance Eq (l ClassId) => (Eq (ENode l))
39+
deriving instance Ord (l ClassId) => (Ord (ENode l))
40+
deriving instance Show (l ClassId) => (Show (ENode l))
41+
3742
-- | Get the children e-class ids of an e-node
3843
children :: Traversable l => ENode l -> [ClassId]
3944
children = toList . unNode
@@ -45,68 +50,54 @@ children = toList . unNode
4550
-- this means children e-classes are ignored.
4651
newtype Operator l = Operator { unOperator :: l () }
4752

53+
deriving instance Eq (l ()) => (Eq (Operator l))
54+
deriving instance Ord (l ()) => (Ord (Operator l))
55+
deriving instance Show (l ()) => (Show (Operator l))
56+
4857
-- | Get the operator (function symbol) of an e-node
4958
operator :: Traversable l => ENode l -> Operator l
5059
operator = Operator . void . unNode
5160
{-# INLINE operator #-}
5261

53-
instance Eq1 l => (Eq (ENode l)) where
54-
(==) (Node a) (Node b) = liftEq (==) a b
55-
{-# INLINE (==) #-}
56-
57-
instance Ord1 l => (Ord (ENode l)) where
58-
compare (Node a) (Node b) = liftCompare compare a b
59-
{-# INLINE compare #-}
60-
61-
instance Show1 l => (Show (ENode l)) where
62-
showsPrec p (Node l) = liftShowsPrec showsPrec showList p l
63-
64-
instance Eq1 l => (Eq (Operator l)) where
65-
(==) (Operator a) (Operator b) = liftEq (\_ _ -> True) a b
66-
{-# INLINE (==) #-}
67-
68-
instance Ord1 l => (Ord (Operator l)) where
69-
compare (Operator a) (Operator b) = liftCompare (\_ _ -> EQ) a b
70-
{-# INLINE compare #-}
71-
72-
instance Show1 l => (Show (Operator l)) where
73-
showsPrec p (Operator l) = liftShowsPrec (const . const $ showString "") (const $ showString "") p l
74-
7562
-- * Node Map
7663

7764
-- | A mapping from e-nodes of @l@ to @a@
7865
newtype NodeMap (l :: Type -> Type) a = NodeMap { unNodeMap :: M.Map (ENode l) a }
7966
-- TODO: Investigate whether it would be worth it requiring a trie-map for the
8067
-- e-node definition. Probably it isn't better since e-nodes aren't recursive.
81-
deriving (Show, Functor, Foldable, Traversable, Semigroup, Monoid)
68+
deriving (Functor, Foldable, Traversable)
69+
70+
deriving instance (Show a, Show (l ClassId)) => Show (NodeMap l a)
71+
deriving instance Ord (l ClassId) => Semigroup (NodeMap l a)
72+
deriving instance Ord (l ClassId) => Monoid (NodeMap l a)
8273

8374
-- | Insert a value given an e-node in a 'NodeMap'
84-
insertNM :: Ord1 l => ENode l -> a -> NodeMap l a -> NodeMap l a
75+
insertNM :: Ord (l ClassId) => ENode l -> a -> NodeMap l a -> NodeMap l a
8576
insertNM e v (NodeMap m) = NodeMap (M.insert e v m)
8677
{-# INLINE insertNM #-}
8778

8879
-- | Lookup an e-node in a 'NodeMap'
89-
lookupNM :: Ord1 l => ENode l -> NodeMap l a -> Maybe a
80+
lookupNM :: Ord (l ClassId) => ENode l -> NodeMap l a -> Maybe a
9081
lookupNM e = M.lookup e . unNodeMap
9182
{-# INLINE lookupNM #-}
9283

9384
-- | Delete an e-node in a 'NodeMap'
94-
deleteNM :: Ord1 l => ENode l -> NodeMap l a -> NodeMap l a
85+
deleteNM :: Ord (l ClassId) => ENode l -> NodeMap l a -> NodeMap l a
9586
deleteNM e (NodeMap m) = NodeMap (M.delete e m)
9687
{-# INLINE deleteNM #-}
9788

9889
-- | Insert a value and lookup by e-node in a 'NodeMap'
99-
insertLookupNM :: Ord1 l => ENode l -> a -> NodeMap l a -> (Maybe a, NodeMap l a)
90+
insertLookupNM :: Ord (l ClassId) => ENode l -> a -> NodeMap l a -> (Maybe a, NodeMap l a)
10091
insertLookupNM e v (NodeMap m) = second NodeMap $ M.insertLookupWithKey (\_ a _ -> a) e v m
10192
{-# INLINE insertLookupNM #-}
10293

10394
-- | As 'Data.Map.foldlWithKeyNM'' but in a 'NodeMap'
104-
foldlWithKeyNM' :: Ord1 l => (b -> ENode l -> a -> b) -> b -> NodeMap l a -> b
95+
foldlWithKeyNM' :: Ord (l ClassId) => (b -> ENode l -> a -> b) -> b -> NodeMap l a -> b
10596
foldlWithKeyNM' f b = M.foldlWithKey' f b . unNodeMap
10697
{-# INLINE foldlWithKeyNM' #-}
10798

10899
-- | As 'Data.Map.foldrWithKeyNM'' but in a 'NodeMap'
109-
foldrWithKeyNM' :: Ord1 l => (ENode l -> a -> b -> b) -> b -> NodeMap l a -> b
100+
foldrWithKeyNM' :: Ord (l ClassId) => (ENode l -> a -> b -> b) -> b -> NodeMap l a -> b
110101
foldrWithKeyNM' f b = M.foldrWithKey' f b . unNodeMap
111102
{-# INLINE foldrWithKeyNM' #-}
112103

src/Data/Equality/Language.hs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
{-# LANGUAGE FlexibleContexts #-}
2+
{-# LANGUAGE UnicodeSyntax #-}
3+
{-# LANGUAGE RankNTypes #-}
4+
{-# LANGUAGE QuantifiedConstraints #-}
25
{-# LANGUAGE ConstraintKinds #-}
6+
{-# LANGUAGE StandaloneKindSignatures #-}
7+
{-# LANGUAGE FlexibleInstances #-}
8+
{-# LANGUAGE UndecidableInstances #-}
39
{-|
410
511
Defines 'Language', which is the required constraint on /expressions/ that are
@@ -29,7 +35,7 @@ instance Language Expr
2935
-}
3036
module Data.Equality.Language where
3137

32-
import Data.Functor.Classes
38+
import Data.Kind
3339

3440
-- | A 'Language' is the required constraint on /expressions/ that are to be
3541
-- represented in an e-graph.
@@ -39,5 +45,7 @@ import Data.Functor.Classes
3945
-- e-graphs), note that it must satisfy the other class constraints. In
4046
-- particular an 'Data.Equality.Analysis.Analysis' must be defined for the
4147
-- language.
42-
type Language l = (Traversable l, Ord1 l)
48+
type Language :: (Type -> Type) -> Constraint
49+
class ( a. Ord a => Ord (l a), Traversable l) => Language l
50+
instance ( a. Ord a => Ord (l a), Traversable l) => Language l
4351

0 commit comments

Comments
 (0)