Skip to content

Commit

Permalink
🧩 [consolidate]: Mostly done with old WithGraphAllocT.
Browse files Browse the repository at this point in the history
  • Loading branch information
rolyp committed Sep 22, 2023
1 parent 0b9c360 commit 66c0b28
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 77 deletions.
6 changes: 3 additions & 3 deletions src/Eval.purs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import Pretty (prettyP)
import Primitive (intPair, string)
import Trace (AppTrace(..), Trace(..), VarDef(..)) as T
import Trace (AppTrace, ForeignTrace, ForeignTrace'(..), Match(..), Trace)
import Util (type (×), MayFailT, absurd, both, check, error, successful, orElse, with, (×))
import Util (type (×), absurd, both, check, error, successful, orElse, with, (×))
import Util.Pair (unzip) as P
import Val (Fun(..), Val(..)) as V
import Val (class Ann, DictRep(..), Env, ForeignOp'(..), MatrixRep(..), (<+>), Val, for, lookup', restrict)
Expand Down Expand Up @@ -161,10 +161,10 @@ eval γ (LetRec ρ e) α = do
t × v <- eval (γ <+> γ') e α
pure $ T.LetRec (erase <$> ρ) t × v

eval_module :: forall a m. Monad m => Ann a => Env a -> Module a -> a -> MayFailT m (Env a)
eval_module :: forall a m. MonadError String m => Ann a => Env a -> Module a -> a -> m (Env a)
eval_module γ = go empty
where
go :: Env a -> Module a -> a -> MayFailT m (Env a)
go :: Env a -> Module a -> a -> m (Env a)
go γ' (Module Nil) _ = pure γ'
go y' (Module (Left (VarDef σ e) : ds)) α = do
_ × v <- eval (γ <+> y') e α
Expand Down
86 changes: 20 additions & 66 deletions src/EvalGraph.purs
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,35 @@ module EvalGraph
, eval
, evalWithConfig
, eval_module
, graphGC
, match
, patternMismatch
) where

import Prelude hiding (apply, add)

import Bindings (varAnon)
import Control.Monad.Error.Class (class MonadError, class MonadThrow, throwError)
import Control.Monad.Error.Class (class MonadError, throwError)
import Data.Array (range, singleton) as A
import Data.Either (Either(..))
import Data.Exists (runExists)
--import Data.Identity (Identity(..))
import Data.List (List(..), (:), length, snoc, unzip, zip)
import Data.Set (Set, empty, insert, {-intersection, -}singleton, union)
import Data.Set (Set, empty, insert, intersection, singleton, union)
import Data.Set as S
import Data.Traversable (sequence, traverse)
import Data.Tuple (fst)
import DataType (checkArity, arity, consistentWith, dataTypeFor, showCtr)
import Dict (disjointUnion, fromFoldable, empty, get, keys, lookup, singleton) as D
import Expr (Cont(..), Elim(..), Expr(..), VarDef(..), RecDefs, Module(..), fv, asExpr)
--import GaloisConnection (GaloisConnection)
import GaloisConnection (GaloisConnection)
import Graph (class Graph, Vertex)
--import Graph (vertices) as G
import Graph.GraphWriter (class MonadGraphAlloc, {-WithGraphAllocT, -}alloc, new, {-runWithGraphAlloc2T, -}runWithGraphAllocT)
--import Graph.Slice (bwdSlice, fwdSlice, vertices)
--import Lattice (Raw)
import Graph (vertices) as G
import Graph.GraphWriter (class MonadGraphAlloc, alloc, new, runWithGraphAlloc2T)
import Graph.Slice (bwdSlice, fwdSlice, vertices)
import Lattice (Raw)
import Pretty (prettyP)
import Primitive (string, intPair)
import Util (type (+), type (×), (×), check, error, successful, with, orElse)
import Util (type (×), (×), check, error, successful, with, orElse)
import Util.Pair (unzip) as P
import Val (DictRep(..), Env, ForeignOp'(..), MatrixRep(..), Val, for, lookup', restrict, (<+>))
import Val (Val(..), Fun(..)) as V
Expand All @@ -47,13 +47,7 @@ type GraphConfig g =
patternMismatch :: String -> String -> String
patternMismatch s s' = "Pattern mismatch: found " <> s <> ", expected " <> s'

match
:: forall m
. MonadGraphAlloc m
=> MonadError String m
=> Val Vertex
-> Elim Vertex
-> m (Env Vertex × Cont Vertex × Set Vertex)
match :: forall m. MonadGraphAlloc m => Val Vertex -> Elim Vertex -> m (Env Vertex × Cont Vertex × Set Vertex)
match v (ElimVar x κ)
| x == varAnon = pure (D.empty × κ × empty)
| otherwise = pure (D.singleton x v × κ × empty)
Expand All @@ -73,13 +67,7 @@ match (V.Record α xvs) (ElimRecord xs κ) = do
pure $ γ × κ' × (insert α αs)
match v (ElimRecord xs _) = throwError (patternMismatch (prettyP v) (show xs))

matchMany
:: forall m
. MonadGraphAlloc m
=> MonadError String m
=> List (Val Vertex)
-> Cont Vertex
-> m (Env Vertex × Cont Vertex × Set Vertex)
matchMany :: forall m. MonadGraphAlloc m => List (Val Vertex) -> Cont Vertex -> m (Env Vertex × Cont Vertex × Set Vertex)
matchMany Nil κ = pure (D.empty × κ × empty)
matchMany (v : vs) (ContElim σ) = do
γ × κ × αs <- match v σ
Expand All @@ -89,29 +77,16 @@ matchMany (_ : vs) (ContExpr _) = throwError $
show (length vs + 1) <> " extra argument(s) to constructor/record; did you forget parentheses in lambda pattern?"
matchMany _ _ = error "absurd"

closeDefs
:: forall m
. MonadGraphAlloc m
=> MonadThrow String m
=> Env Vertex
-> RecDefs Vertex
-> Set Vertex
-> m (Env Vertex)
closeDefs :: forall m. MonadGraphAlloc m => Env Vertex -> RecDefs Vertex -> Set Vertex -> m (Env Vertex)
closeDefs γ ρ αs =
flip traverse ρ \σ ->
let
ρ' = ρ `for` σ
in
V.Fun <$> new αs <@> V.Closure`restrict` (fv ρ' `S.union` fv σ)) ρ' σ
V.Fun <$> new αs <@> V.Closure`restrict` (fv ρ' `union` fv σ)) ρ' σ

{-# Evaluation #-}
apply
:: forall m
. MonadGraphAlloc m
=> MonadError String m
=> Val Vertex
-> Val Vertex
-> m (Val Vertex)
apply :: forall m. MonadGraphAlloc m => Val Vertex -> Val Vertex -> m (Val Vertex)
apply (V.Fun α (V.Closure γ1 ρ σ)) v = do
γ2 <- closeDefs γ1 ρ (singleton α)
γ3 × κ × αs <- match v σ
Expand All @@ -134,14 +109,7 @@ apply (V.Fun α (V.PartialConstr c vs)) v = do
n = successful (arity c)
apply _ v = throwError $ "Found " <> prettyP v <> ", expected function"

eval
:: forall m
. MonadGraphAlloc m
=> MonadError String m
=> Env Vertex
-> Expr Vertex
-> Set Vertex
-> m (Val Vertex)
eval :: forall m. MonadGraphAlloc m => Env Vertex -> Expr Vertex -> Set Vertex -> m (Val Vertex)
eval γ (Var x) _ = lookup' x γ
eval γ (Op op) _ = lookup' op γ
eval _ (Int α n) αs = V.Int <$> new (insert α αs) <@> n
Expand Down Expand Up @@ -193,14 +161,7 @@ eval γ (LetRec ρ e) αs = do
γ' <- closeDefs γ ρ αs
eval (γ <+> γ') e αs

eval_module
:: forall m
. MonadGraphAlloc m
=> MonadError String m
=> Env Vertex
-> Module Vertex
-> Set Vertex
-> m (Env Vertex)
eval_module :: forall m. MonadGraphAlloc m => Env Vertex -> Module Vertex -> Set Vertex -> m (Env Vertex)
eval_module γ = go D.empty
where
go :: Env Vertex -> Module Vertex -> Set Vertex -> m (Env Vertex)
Expand All @@ -214,22 +175,16 @@ eval_module γ = go D.empty
go (γ' <+> γ'') (Module ds) αs

-- TODO: Inline into graphGC
evalWithConfig
:: forall g m a
. Monad m
=> Graph g
=> GraphConfig g
-> Expr a
-> m (String + ((g × Int) × Expr Vertex × Val Vertex))
evalWithConfig :: forall g m a. MonadError String m => Graph g => GraphConfig g -> Expr a -> m ((g × Int) × Expr Vertex × Val Vertex)
evalWithConfig { g, n, γα } e =
runWithGraphAllocT (g × n) $ do
runWithGraphAlloc2T (g × n) $ do
<- alloc e
<- eval γα eα S.empty
pure (eα × vα)
{-

graphGC
:: forall g m
. Monad m
. MonadError String m
=> Graph g
=> GraphConfig g
-> Raw Expr
Expand All @@ -245,4 +200,3 @@ graphGC { g: g0, n, γα } e = do
<- alloc e
<- eval γα eα S.empty
pure (vα × eα)
-}
12 changes: 11 additions & 1 deletion src/Graph/GraphWriter.purs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Graph.GraphWriter
( AdjMapEntries
, WithAllocT
, WithGraphAllocT
, WithGraphAlloc2T
, WithGraph
, WithGraphT
, class MonadAlloc
Expand Down Expand Up @@ -63,7 +64,13 @@ instance Monad m => MonadAlloc (WithAllocT m) where
instance Monad m => MonadAlloc (WithGraphAllocT m) where
fresh = lift fresh

instance (Monad m, MonadAlloc (WithGraphAllocT m), MonadGraph (WithGraphAllocT m)) => MonadGraphAlloc (WithGraphAllocT m) where
instance MonadError String m => MonadGraphAlloc (WithGraphAlloc2T m) where
new αs = do
α <- fresh
extend α αs
pure α

instance Monad m => MonadGraphAlloc (WithGraphAllocT m) where
new αs = do
α <- fresh
extend α αs
Expand All @@ -73,6 +80,9 @@ instance Monad m => MonadGraph (WithGraphT m) where
extend α αs =
void $ modify_ $ (:) (α × αs)

instance Monad m => MonadGraph (WithGraphAlloc2T m) where
extend α = lift <<< extend α

instance Monad m => MonadGraph (WithGraphAllocT m) where
extend α = lift <<< lift <<< extend α

Expand Down
1 change: 0 additions & 1 deletion src/Primitive/Defs.purs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ dict_intersectionWith = mkExists $ ForeignOp' { arity: 3, op': op, op: fwd, op_b
d'' <-
sequence $
D.intersectionWith (\(β × u) (β' × u') -> (β ∧ β' × _) <$> apply2 (v × u × u')) d d'
-- :: MayFailT m (Dict (_ × (AppTrace × AppTrace) × Val _))
pure $ (erase v × (d'' <#> snd >>> fst)) × Dictionary (α ∧ α') (DictRep (d'' <#> second snd))
fwd _ = throwError "Function and two dictionaries expected"

Expand Down
5 changes: 1 addition & 4 deletions src/Util.purs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Prelude hiding (absurd)
import Control.Apply (lift2)
import Control.Comonad (extract)
import Control.Monad.Error.Class (class MonadError, class MonadThrow, catchError, throwError)
import Control.Monad.Except (Except, ExceptT(..), runExceptT, except)
import Control.Monad.Except (Except, ExceptT(..), runExceptT)
import Control.MonadPlus (class MonadPlus, empty)
import Data.Array ((!!), updateAt)
import Data.Either (Either(..))
Expand Down Expand Up @@ -69,9 +69,6 @@ ignoreMessage = runExceptT >>> extract >>> case _ of
(Left _) -> Nothing
(Right x) -> Just x

report :: String -> forall a m. Applicative m => MayFailT m a
report s = except $ Left s

fromRight :: forall a. Either String a -> a
fromRight (Right x) = x
fromRight (Left msg) = error msg
Expand Down
4 changes: 2 additions & 2 deletions test/Util.purs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import App.Fig (LinkFigSpec)
import App.Util (Selector)
import Benchmark.Util (BenchRow(..), GraphRow, TraceRow, preciseTime, tdiff)
import Control.Monad.Error.Class (class MonadThrow)
import Control.Monad.Except (except, runExceptT)
import Control.Monad.Except (runExceptT)
import Control.Monad.Trans.Class (lift)
import Data.Either (Either(..))
import Data.Foldable (foldl)
Expand Down Expand Up @@ -106,7 +106,7 @@ testGraph s gconf { δv, bwd_expect, fwd_expect } = do
-- | Eval
e <- desug s
tEval1 <- preciseTime
(g × _) × eα × vα <- evalWithConfig gconf e >>= except
(g × _) × eα × vα <- evalWithConfig gconf e
tEval2 <- preciseTime

-- | Backward
Expand Down

0 comments on commit 66c0b28

Please sign in to comment.