From 0d8cc5d7467b79184fc126f7b8607944c88bf18c Mon Sep 17 00:00:00 2001 From: Roly Perera Date: Wed, 27 Sep 2023 14:50:02 +0100 Subject: [PATCH 1/2] =?UTF-8?q?=E2=9D=97=20[incomplete]:=20Now=20for=20pri?= =?UTF-8?q?mitives.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/BoolAlg.purs | 2 +- src/Eval2.purs | 163 ++++++++++++++++++++++++++++++++++++++ src/EvalBwd.purs | 7 +- src/GaloisConnection.purs | 6 +- 4 files changed, 172 insertions(+), 6 deletions(-) create mode 100644 src/Eval2.purs diff --git a/src/BoolAlg.purs b/src/BoolAlg.purs index e9c53e568..13b73fd88 100644 --- a/src/BoolAlg.purs +++ b/src/BoolAlg.purs @@ -38,7 +38,7 @@ powerset xs = , neg: (xs `S.difference` _) } -slices :: forall f a. Apply f => BoolAlg a -> f a -> BoolAlg (f a) +slices :: forall f a b. Apply f => BoolAlg a -> f b -> BoolAlg (f a) slices 𝒶 x = { top: x <#> const 𝒶.top , bot: x <#> const 𝒶.bot diff --git a/src/Eval2.purs b/src/Eval2.purs new file mode 100644 index 000000000..089408a37 --- /dev/null +++ b/src/Eval2.purs @@ -0,0 +1,163 @@ +module Eval2 where + +import Prelude hiding (absurd, apply) +import Bindings (varAnon) +import BoolAlg (BoolAlg) +import Control.Monad.Error.Class (class MonadError) +import Data.Array (fromFoldable) as A +import Data.Bifunctor (bimap) +import Data.Either (Either(..)) +import Data.Exists (mkExists, runExists) +import Data.List (List(..), (:), length, range, singleton, unzip, zip) +import Data.Maybe (Maybe(..)) +import Data.Profunctor.Strong (first) +import Data.Set (fromFoldable, toUnfoldable, singleton) as S +import Data.Set (union, subset) +import Data.Traversable (sequence, traverse) +import Data.Tuple (fst, snd) +import DataType (Ctr, arity, consistentWith, dataTypeFor, showCtr) +import Dict (disjointUnion, get, empty, lookup, keys) +import Dict (fromFoldable, singleton, unzip) as D +import Effect.Exception (Error) +import Expr (Cont(..), Elim(..), Expr(..), Module(..), RecDefs, VarDef(..), asExpr, fv) +import Lattice ((∧), erase, top) +import Pretty (prettyP) +import Primitive (intPair, string) +import Trace (AppTrace(..), Trace(..), VarDef(..)) as T +import Trace (AppTrace, ForeignTrace, ForeignTrace'(..), Match(..), Trace) +import Util (type (×), absurd, both, check, error, orElse, successful, throw, with, (×)) +import Util.Pair (unzip) as P +import Val (Fun(..), Val(..)) as V +import Val (class Ann, class Highlightable, DictRep(..), Env, ForeignOp'(..), MatrixRep(..), Val, for, lookup', restrict, (<+>)) + +patternMismatch :: String -> String -> String +patternMismatch s s' = "Pattern mismatch: found " <> s <> ", expected " <> s' + +match :: forall a m. MonadError Error m => Highlightable a => BoolAlg a -> Val a -> Elim a -> m (Env a × Cont a × a × Match) +match 𝒶 v (ElimVar x κ) + | x == varAnon = pure (empty × κ × 𝒶.top × MatchVarAnon (erase v)) + | otherwise = pure (D.singleton x v × κ × 𝒶.top × MatchVar x (erase v)) +match 𝒶 (V.Constr α c vs) (ElimConstr m) = do + with "Pattern mismatch" $ S.singleton c `consistentWith` keys m + κ <- lookup c m # orElse ("Incomplete patterns: no branch for " <> showCtr c) + γ × κ' × α' × ws <- matchMany 𝒶 vs κ + pure (γ × κ' × (α `𝒶.meet` α') × MatchConstr c ws) +match _ v (ElimConstr m) = do + d <- dataTypeFor $ keys m + throw $ patternMismatch (prettyP v) (show d) +match 𝒶 (V.Record α xvs) (ElimRecord xs κ) = do + check (subset xs (S.fromFoldable $ keys xvs)) $ patternMismatch (show (keys xvs)) (show xs) + let xs' = xs # S.toUnfoldable + γ × κ' × α' × ws <- matchMany 𝒶 (xs' <#> flip get xvs) κ + pure (γ × κ' × (α `𝒶.meet` α') × MatchRecord (D.fromFoldable (zip xs' ws))) +match _ v (ElimRecord xs _) = throw $ patternMismatch (prettyP v) (show xs) + +matchMany :: forall a m. MonadError Error m => Highlightable a => BoolAlg a -> List (Val a) -> Cont a -> m (Env a × Cont a × a × List Match) +matchMany 𝒶 Nil κ = pure (empty × κ × 𝒶.top × Nil) +matchMany 𝒶 (v : vs) (ContElim σ) = do + γ × κ' × α × w <- match 𝒶 v σ + γ' × κ'' × β × ws <- matchMany 𝒶 vs κ' + pure $ γ `disjointUnion` γ' × κ'' × (α `𝒶.meet` β) × (w : ws) +matchMany _ (_ : vs) (ContExpr _) = throw $ + show (length vs + 1) <> " extra argument(s) to constructor/record; did you forget parentheses in lambda pattern?" +matchMany _ _ _ = error absurd + +closeDefs :: forall a. Env a -> RecDefs a -> a -> Env a +closeDefs γ ρ α = ρ <#> \σ -> + let ρ' = ρ `for` σ in V.Fun α $ V.Closure (γ `restrict` (fv ρ' `union` fv σ)) ρ' σ + +checkArity :: forall m. MonadError Error m => Ctr -> Int -> m Unit +checkArity c n = do + n' <- arity c + check (n' >= n) (showCtr c <> " got " <> show n <> " argument(s), expects at most " <> show n') + +apply :: forall a m. MonadError Error m => Highlightable a => BoolAlg a -> Val a × Val a -> m (AppTrace × Val a) +apply 𝒶 (V.Fun β (V.Closure γ1 ρ σ) × v) = do + let γ2 = closeDefs γ1 ρ β + γ3 × e'' × β' × w <- match 𝒶 v σ + t'' × v'' <- eval 𝒶 (γ1 <+> γ2 <+> γ3) (asExpr e'') (β `𝒶.meet` β') + pure $ T.AppClosure (S.fromFoldable (keys ρ)) w t'' × v'' +apply 𝒶 (V.Fun α (V.Foreign φ vs) × v) = do + t × v'' <- runExists apply' φ + pure $ T.AppForeign (length vs + 1) t × v'' + where + vs' = vs <> singleton v + + apply' :: forall t. ForeignOp' t -> m (ForeignTrace × Val _) + apply' (ForeignOp' φ') = do + t × v'' <- do + if φ'.arity > length vs' then pure $ Nothing × V.Fun α (V.Foreign φ vs') + else first Just <$> φ'.op vs' + pure $ mkExists (ForeignTrace' (ForeignOp' φ') t) × v'' +apply _ (V.Fun α (V.PartialConstr c vs) × v) = do + check (length vs < n) ("Too many arguments to " <> showCtr c) + pure $ T.AppConstr c × v' + where + n = successful (arity c) + v' = + if length vs < n - 1 then V.Fun α $ V.PartialConstr c (vs <> singleton v) + else V.Constr α c (vs <> singleton v) +apply _ (_ × v) = throw $ "Found " <> prettyP v <> ", expected function" + +apply2 :: forall a m. MonadError Error m => Ann a => Val a × Val a × Val a -> m ((AppTrace × AppTrace) × Val a) +apply2 (u1 × v1 × v2) = do + t1 × u2 <- apply (u1 × v1) + t2 × v <- apply (u2 × v2) + pure $ (t1 × t2) × v + +eval :: forall a m. MonadError Error m => Highlightable a => BoolAlg a -> Env a -> Expr a -> a -> m (Trace × Val a) +eval _ γ (Var x) _ = (T.Var x × _) <$> lookup' x γ +eval _ γ (Op op) _ = (T.Op op × _) <$> lookup' op γ +eval 𝒶 _ (Int α n) α' = pure (T.Const × V.Int (α `𝒶.meet` α') n) +eval 𝒶 _ (Float α n) α' = pure (T.Const × V.Float (α `𝒶.meet` α') n) +eval 𝒶 _ (Str α str) α' = pure (T.Const × V.Str (α `𝒶.meet` α') str) +eval 𝒶 γ (Record α xes) α' = do + xts × xvs <- traverse (flip (eval 𝒶 γ) α') xes <#> D.unzip + pure $ T.Record xts × V.Record (α `𝒶.meet` α') xvs +eval 𝒶 γ (Dictionary α ees) α' = do + (ts × vs) × (ts' × us) <- traverse (traverse (flip (eval 𝒶 γ) α')) ees <#> (P.unzip >>> (unzip # both)) + let + ss × αs = (vs <#> \u -> string.match u) # unzip + d = D.fromFoldable $ zip ss (zip αs us) + pure $ T.Dictionary (zip ss (zip ts ts')) (d <#> snd >>> erase) × V.Dictionary (α `𝒶.meet` α') (DictRep d) +eval 𝒶 γ (Constr α c es) α' = do + checkArity c (length es) + ts × vs <- traverse (flip (eval 𝒶 γ) α') es <#> unzip + pure (T.Constr c ts × V.Constr (α `𝒶.meet` α') c vs) +eval 𝒶 γ (Matrix α e (x × y) e') α' = do + t × v <- eval 𝒶 γ e' α' + let (i' × β) × (j' × β') = fst (intPair.match v) + check (i' × j' >= 1 × 1) ("array must be at least (" <> show (1 × 1) <> "); got (" <> show (i' × j') <> ")") + tss × vss <- unzipToArray <$> ((<$>) unzipToArray) <$> + ( sequence $ do + i <- range 1 i' + singleton $ sequence $ do + j <- range 1 j' + let γ' = D.singleton x (V.Int β i) `disjointUnion` (D.singleton y (V.Int β' j)) + singleton (eval 𝒶 (γ <+> γ') e α') + ) + pure $ T.Matrix tss (x × y) (i' × j') t × V.Matrix (α `𝒶.meet` α') (MatrixRep (vss × (i' × β) × (j' × β'))) + where + unzipToArray :: forall b c. List (b × c) -> Array b × Array c + unzipToArray = unzip >>> bimap A.fromFoldable A.fromFoldable +eval _ γ (Lambda σ) α = + pure $ T.Const × V.Fun α (V.Closure (γ `restrict` fv σ) empty σ) +eval 𝒶 γ (Project e x) α = do + t × v <- eval 𝒶 γ e α + case v of + V.Record _ xvs -> (T.Project t x × _) <$> lookup' x xvs + _ -> throw $ "Found " <> prettyP v <> ", expected record" +eval 𝒶 γ (App e e') α = do + t × v <- eval 𝒶 γ e α + t' × v' <- eval 𝒶 γ e' α + t'' × v'' <- apply 𝒶 (v × v') + pure $ T.App t t' t'' × v'' +eval 𝒶 γ (Let (VarDef σ e) e') α = do + t × v <- eval 𝒶 γ e α + γ' × _ × α' × w <- match 𝒶 v σ -- terminal meta-type of eliminator is meta-unit + t' × v' <- eval 𝒶 (γ <+> γ') e' α' -- (α ∧ α') for consistency with functions? (similarly for module defs) + pure $ T.Let (T.VarDef w t) t' × v' +eval 𝒶 γ (LetRec ρ e) α = do + let γ' = closeDefs γ ρ α + t × v <- eval 𝒶 (γ <+> γ') e α + pure $ T.LetRec (erase <$> ρ) t × v diff --git a/src/EvalBwd.purs b/src/EvalBwd.purs index 0ad4101c9..29798f332 100644 --- a/src/EvalBwd.purs +++ b/src/EvalBwd.purs @@ -202,7 +202,10 @@ evalBwd' v (T.LetRec ρ t) = evalBwd' _ _ = error absurd type EvalGaloisConnection a = GaloisConnection (EvalBwdResult a) (Val a) - ( v :: Raw Val + ( γ :: Raw Env + , e :: Raw Expr + , t :: Trace + , v :: Raw Val ) traceGC :: forall a m. MonadError Error m => Ann a => Raw Env -> Raw Expr -> m (EvalGaloisConnection a) @@ -211,4 +214,4 @@ traceGC γ e = do let bwd v' = evalBwd γ e v' t fwd { γ: γ', e: e', α } = snd $ fromRight $ runExcept $ eval γ' e' α - pure { v, fwd, bwd } + pure { γ, e, t, v, fwd, bwd } diff --git a/src/GaloisConnection.purs b/src/GaloisConnection.purs index 37e89c157..411f1badc 100644 --- a/src/GaloisConnection.purs +++ b/src/GaloisConnection.purs @@ -5,10 +5,10 @@ import Prelude import BoolAlg (BoolAlg) import Util (Endo) --- a and b are posets, but we don't enforce that here. Use record rather than type class so we can extend with --- explicit value-level representation of index (e.g. graph or trace) for families of GCs. +-- Galois connections are more general, this is specialised to Boolean algebras. type GaloisConnection a b r = - { fwd :: a -> b + { + fwd :: a -> b , bwd :: b -> a | r } From d680d44518bca80f45088c2eb9e6f45f20df55d8 Mon Sep 17 00:00:00 2001 From: Roly Perera Date: Wed, 27 Sep 2023 15:06:03 +0100 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=A7=A9=20[add-unused]:=20Version=20of?= =?UTF-8?q?=20eval=20and=20primitives=20based=20on=20BoolAlg.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/Ann.purs | 9 + src/Eval2.purs | 37 ++- src/GaloisConnection.purs | 3 +- src/Pretty2.purs | 461 ++++++++++++++++++++++++++++++++++++++ src/Primitive2.purs | 311 +++++++++++++++++++++++++ src/Trace2.purs | 51 +++++ src/Val2.purs | 255 +++++++++++++++++++++ 7 files changed, 1113 insertions(+), 14 deletions(-) create mode 100644 src/Ann.purs create mode 100644 src/Pretty2.purs create mode 100644 src/Primitive2.purs create mode 100644 src/Trace2.purs create mode 100644 src/Val2.purs diff --git a/src/Ann.purs b/src/Ann.purs new file mode 100644 index 000000000..664a26c69 --- /dev/null +++ b/src/Ann.purs @@ -0,0 +1,9 @@ +module Ann where + +import Prelude + +erase :: forall t a. Functor t => t a -> Raw t +erase = (<$>) (const unit) + +type 𝔹 = Boolean +type Raw (c :: Type -> Type) = c Unit diff --git a/src/Eval2.purs b/src/Eval2.purs index 089408a37..cfcaea23f 100644 --- a/src/Eval2.purs +++ b/src/Eval2.purs @@ -1,6 +1,8 @@ module Eval2 where import Prelude hiding (absurd, apply) + +import Ann (erase) import Bindings (varAnon) import BoolAlg (BoolAlg) import Control.Monad.Error.Class (class MonadError) @@ -20,15 +22,14 @@ import Dict (disjointUnion, get, empty, lookup, keys) import Dict (fromFoldable, singleton, unzip) as D import Effect.Exception (Error) import Expr (Cont(..), Elim(..), Expr(..), Module(..), RecDefs, VarDef(..), asExpr, fv) -import Lattice ((∧), erase, top) -import Pretty (prettyP) -import Primitive (intPair, string) -import Trace (AppTrace(..), Trace(..), VarDef(..)) as T -import Trace (AppTrace, ForeignTrace, ForeignTrace'(..), Match(..), Trace) +import Pretty2 (prettyP) +import Primitive2 (intPair, string) +import Trace2 (AppTrace(..), Trace(..), VarDef(..)) as T +import Trace2 (AppTrace, ForeignTrace, ForeignTrace'(..), Match(..), Trace) import Util (type (×), absurd, both, check, error, orElse, successful, throw, with, (×)) import Util.Pair (unzip) as P -import Val (Fun(..), Val(..)) as V -import Val (class Ann, class Highlightable, DictRep(..), Env, ForeignOp'(..), MatrixRep(..), Val, for, lookup', restrict, (<+>)) +import Val2 (Fun(..), Val(..)) as V +import Val2 (class Highlightable, DictRep(..), Env, ForeignOp'(..), MatrixRep(..), Val, for, lookup', restrict, (<+>)) patternMismatch :: String -> String -> String patternMismatch s s' = "Pattern mismatch: found " <> s <> ", expected " <> s' @@ -87,7 +88,7 @@ apply 𝒶 (V.Fun α (V.Foreign φ vs) × v) = do apply' (ForeignOp' φ') = do t × v'' <- do if φ'.arity > length vs' then pure $ Nothing × V.Fun α (V.Foreign φ vs') - else first Just <$> φ'.op vs' + else first Just <$> φ'.op 𝒶 vs' pure $ mkExists (ForeignTrace' (ForeignOp' φ') t) × v'' apply _ (V.Fun α (V.PartialConstr c vs) × v) = do check (length vs < n) ("Too many arguments to " <> showCtr c) @@ -99,10 +100,10 @@ apply _ (V.Fun α (V.PartialConstr c vs) × v) = do else V.Constr α c (vs <> singleton v) apply _ (_ × v) = throw $ "Found " <> prettyP v <> ", expected function" -apply2 :: forall a m. MonadError Error m => Ann a => Val a × Val a × Val a -> m ((AppTrace × AppTrace) × Val a) -apply2 (u1 × v1 × v2) = do - t1 × u2 <- apply (u1 × v1) - t2 × v <- apply (u2 × v2) +apply2 :: forall a m. MonadError Error m => Highlightable a => BoolAlg a -> Val a × Val a × Val a -> m ((AppTrace × AppTrace) × Val a) +apply2 𝒶 (u1 × v1 × v2) = do + t1 × u2 <- apply 𝒶 (u1 × v1) + t2 × v <- apply 𝒶 (u2 × v2) pure $ (t1 × t2) × v eval :: forall a m. MonadError Error m => Highlightable a => BoolAlg a -> Env a -> Expr a -> a -> m (Trace × Val a) @@ -161,3 +162,15 @@ eval 𝒶 γ (LetRec ρ e) α = do let γ' = closeDefs γ ρ α t × v <- eval 𝒶 (γ <+> γ') e α pure $ T.LetRec (erase <$> ρ) t × v + +eval_module :: forall a m. MonadError Error m => Highlightable a => BoolAlg a -> Env a -> Module a -> a -> m (Env a) +eval_module 𝒶 γ = go empty + where + go :: Env a -> Module a -> a -> m (Env a) + go γ' (Module Nil) _ = pure γ' + go y' (Module (Left (VarDef σ e) : ds)) α = do + _ × v <- eval 𝒶 (γ <+> y') e α + γ'' × _ × α' × _ <- match 𝒶 v σ + go (y' <+> γ'') (Module ds) α' + go γ' (Module (Right ρ : ds)) α = + go (γ' <+> closeDefs (γ <+> γ') ρ α) (Module ds) α diff --git a/src/GaloisConnection.purs b/src/GaloisConnection.purs index 411f1badc..97ca08d7d 100644 --- a/src/GaloisConnection.purs +++ b/src/GaloisConnection.purs @@ -7,8 +7,7 @@ import Util (Endo) -- Galois connections are more general, this is specialised to Boolean algebras. type GaloisConnection a b r = - { - fwd :: a -> b + { fwd :: a -> b , bwd :: b -> a | r } diff --git a/src/Pretty2.purs b/src/Pretty2.purs new file mode 100644 index 000000000..aa06770b2 --- /dev/null +++ b/src/Pretty2.purs @@ -0,0 +1,461 @@ +module Pretty2 (class Pretty, pretty, prettyP) where + +import Prelude hiding (absurd, between) + +import Bindings (Bind, key, val, Var, (↦)) +import Data.Array (foldl) +import Data.Exists (runExists) +import Data.Foldable (class Foldable) +import Data.List (List(..), fromFoldable, null, uncons, (:)) +import Data.List.NonEmpty (NonEmptyList, groupBy, singleton, toList) +import Data.Map (lookup) +import Data.Maybe (Maybe(..)) +import Data.Profunctor.Choice ((|||)) +import Data.Profunctor.Strong (first) +import Data.Set (Set, toUnfoldable) as S +import Data.String (Pattern(..), Replacement(..), contains) as DS +import Data.String (drop, replaceAll) +import DataType (Ctr, cCons, cNil, cPair, showCtr) +import Dict (Dict) +import Dict (toUnfoldable) as D +import Expr (Cont(..), Elim(..)) +import Expr (Expr(..), VarDef(..)) as E +import Graph (Vertex(..)) +import Graph.GraphImpl (GraphImpl(..)) +import Parse (str) +import Primitive.Parse (opDefs) +import SExpr (Branch, Clause(..), Clauses(..), Expr(..), ListRest(..), ListRestPattern(..), Pattern(..), Qualifier(..), RecDefs, VarDef(..), VarDefs) +import Util (type (+), type (×), Endo, absurd, assert, error, intersperse, (×)) +import Util.Pair (Pair(..), toTuple) +import Util.Pretty (Doc(..), atop, beside, empty, hcat, render, text) +import Val2 (Fun(..), Val(..)) as V +import Val2 (class Ann, class Highlightable, DictRep(..), ForeignOp', Fun, MatrixRep(..), Val, highlightIf) + +replacement :: Array (String × String) +replacement = + [ "( " × "(" + , " )" × ")" + , "[ " × "[" + , " ]" × "]" + , "{ " × "{" + , " }" × "}" + , ". " × "." + , " ." × "." + , ". " × "." + , " ," × "," + , " ;" × ";" + , "| " × "|" + , " |" × "|" + , "⸨ " × "⸨" + , " ⸩" × "⸩" + ] + +pattRepPairs :: Array (DS.Pattern × DS.Replacement) +pattRepPairs = map (\(x × y) -> (DS.Pattern x × DS.Replacement y)) replacement + +newtype FirstGroup a = First (RecDefs a) +data ExprType = Simple | Expression +type Sep = Doc -> Doc -> Doc + +exprType :: forall a. Expr a -> ExprType +exprType (Var _) = Simple +exprType (Op _) = Simple +exprType (Int _ _) = Simple +exprType (Float _ _) = Simple +exprType (Str _ _) = Simple +exprType (Constr _ _ Nil) = Simple +exprType (Constr _ _ _) = Expression +exprType (Record _ _) = Simple +exprType (Dictionary _ _) = Simple +exprType (Matrix _ _ _ _) = Simple +exprType (Lambda _) = Simple +exprType (Project _ _) = Simple +exprType (App _ _) = Expression +exprType (BinaryApp _ _ _) = Expression +exprType (MatchAs _ _) = Simple +exprType (IfElse _ _ _) = Simple +exprType (ListEmpty _) = Simple -- try +exprType (ListNonEmpty _ _ _) = Simple +exprType (ListEnum _ _) = Simple +exprType (ListComp _ _ _) = Simple +exprType (Let _ _) = Expression +exprType (LetRec _ _) = Expression + +prettySimple :: forall a. Ann a => Expr a -> Doc +prettySimple s = case exprType s of + Simple -> pretty s + Expression -> parentheses (pretty s) + +prettyAppChain :: forall a. Ann a => Expr a -> Doc +prettyAppChain (App s s') = prettyAppChain s .<>. prettySimple s' +prettyAppChain s = prettySimple s + +prettyBinApp :: forall a. Ann a => Int -> Expr a -> Doc +prettyBinApp n (BinaryApp s op s') = + let + prec' = getPrec op + in + case getPrec op of + -1 -> prettyBinApp prec' s .<>. (text ("`" <> op <> "`")) .<>. prettyBinApp prec' s' + _ -> + if prec' <= n then + parentheses (prettyBinApp prec' s .<>. text op .<>. prettyBinApp prec' s') + else + prettyBinApp prec' s .<>. text op .<>. prettyBinApp prec' s' +prettyBinApp _ s = prettyAppChain s + +getPrec :: String -> Int +getPrec x = case lookup x opDefs of + Just y -> y.prec + Nothing -> -1 + +infixl 5 beside as .<>. +-- infixl 5 space as .<>. +infixl 5 atop as .-. + +class Pretty p where + pretty :: p -> Doc + +removeLineWS :: String -> String +removeLineWS str = foldl (\curr (x × y) -> replaceAll x y curr) str pattRepPairs + +removeDocWS :: Doc -> Doc +removeDocWS (Doc d) = Doc + { width: d.width + , height: d.height + , lines: map (\x -> removeLineWS (drop 1 x)) d.lines + } + +instance Ann a => Pretty (Expr a) where + pretty (Var x) = text x + pretty (Op op) = parentheses (text op) + pretty (Int ann n) = (highlightIf ann $ text (show n)) + pretty (Float ann n) = (highlightIf ann $ text (show n)) + pretty (Str ann str) = (highlightIf ann $ (text ("\"" <> str <> "\""))) + pretty (Constr ann c x) = (prettyConstr ann c x) + pretty (Record ann xss) = (highlightIf ann $ curlyBraces (prettyOperator (.-.) xss)) + pretty (Dictionary ann sss) = (highlightIf ann $ dictBrackets (pretty sss)) + pretty (Matrix ann e (x × y) e') = + highlightIf ann $ arrayBrackets + ( pretty e .<>. text str.bar .<>. parentheses (text x .<>. text str.comma .<>. text y) + .<>. text str.in_ + .<>. pretty e' + ) + pretty (Lambda cs) = parentheses (text str.fun .<>. pretty cs) + pretty (Project s x) = pretty s .<>. text str.dot .<>. text x + pretty (App s s') = prettyAppChain (App s s') + pretty (BinaryApp s op s') = prettyBinApp 0 (BinaryApp s op s') + pretty (MatchAs s cs) = ((text str.match .<>. pretty s .<>. text str.as)) .-. curlyBraces (pretty cs) + pretty (IfElse s1 s2 s3) = text str.if_ .<>. pretty s1 .<>. text str.then_ .<>. pretty s2 .<>. text str.else_ .<>. pretty s3 + pretty (ListEmpty ann) = (highlightIf ann $ brackets empty) + pretty (ListNonEmpty ann (Record _ xss) l) = + (((highlightIf ann $ text str.lBracket)) .<>. ((highlightIf ann $ curlyBraces (prettyOperator (.<>.) xss)))) .-. pretty l + pretty (ListNonEmpty ann e l) = ((highlightIf ann $ text str.lBracket)) .<>. pretty e .<>. pretty l + pretty (ListEnum s s') = brackets (pretty s .<>. text str.ellipsis .<>. pretty s') + pretty (ListComp ann s qs) = (highlightIf ann $ brackets (pretty s .<>. text str.bar .<>. pretty qs)) + pretty (Let ds s) = (text str.let_ .<>. pretty ds .<>. text str.in_) .-. pretty s + pretty (LetRec h s) = (text str.let_ .<>. pretty (First h) .<>. text str.in_) .-. pretty s + +prettyOperator :: forall a. Ann a => (Doc -> Doc -> Doc) -> List (Bind (Expr a)) -> Doc +prettyOperator _ (Cons s Nil) = text (key s) .<>. text str.colon .<>. pretty (val s) +prettyOperator sep (Cons s xss) = sep (prettyOperator sep (toList (singleton s)) .<>. text str.comma) (prettyOperator sep xss) +prettyOperator _ Nil = empty + +instance Ann a => Pretty (ListRest a) where + pretty (Next ann (Record _ xss) l) = ((highlightIf ann $ text str.comma)) .<>. ((highlightIf ann $ curlyBraces (prettyOperator (.<>.) xss))) .-. pretty l + pretty (Next ann s l) = ((highlightIf ann $ text str.comma)) .<>. pretty s .<>. pretty l + pretty (End ann) = (highlightIf ann $ text str.rBracket) + +instance Ann a => Pretty (List (Pair (Expr a))) where + pretty (Cons (Pair e e') Nil) = prettyPairs (Pair e e') + pretty (Cons (Pair e e') sss) = prettyPairs (Pair e e') .<>. text str.comma .<>. pretty sss + pretty Nil = empty + +prettyPairs :: forall a. Ann a => (Pair (Expr a)) -> Doc +prettyPairs (Pair e e') = pretty e .<>. text str.colonEq .<>. pretty e' + +instance Pretty Pattern where + pretty (PVar x) = text x + pretty (PRecord xps) = curlyBraces (pretty xps) + pretty (PConstr c ps) = case uncons ps of + Just { head: p, tail: Nil } -> pretty c .<>. pretty p + _ -> + if c == cPair then (parentheses (prettyPattConstr (text str.comma) ps)) + else if c == cCons then (parentheses (prettyPattConstr (text str.colon) ps)) + else + parentheses (text c .<>. prettyPattConstr empty ps) + + --pretty (PConstr c ps) = if c == cPair then parentheses (prettyPattConstr (text str.comma) ps) + -- if case uncons ps of + -- Nothing -> pretty ps + -- -- Just {head : p, tail: Nil} -> text c .<>. pretty p + -- -- _ -> parentheses (text c .<>. prettyPattConstr empty ps) + pretty (PListEmpty) = brackets empty + pretty (PListNonEmpty p l) = text str.lBracket .<>. pretty p .<>. pretty l + +instance Pretty (List (Bind Pattern)) where + pretty (Cons xp Nil) = text (key xp) .<>. text str.colon .<>. pretty (val xp) + pretty (Cons xp xps) = text (key xp) .<>. text str.colon .<>. pretty (val xp) .<>. text str.comma .-. pretty xps + pretty Nil = empty + +prettyPattConstr :: Doc -> List Pattern -> Doc +prettyPattConstr _ Nil = empty +prettyPattConstr _ (Cons p Nil) = pretty p +prettyPattConstr sep (Cons p ps) = pretty p .<>. sep .<>. prettyPattConstr sep ps + +instance Pretty ListRestPattern where + pretty (PNext p l) = text str.comma .<>. pretty p .<>. pretty l + pretty PEnd = text str.rBracket + +prettyClause :: forall a. Ann a => Doc -> Clause a -> Doc +prettyClause sep (Clause (ps × e)) = prettyPattConstr empty (toList ps) .<>. sep .<>. pretty e + +instance Ann a => Pretty (Clauses a) where + pretty (Clauses cs) = intersperse' (toList (map (prettyClause (text str.equals)) (cs))) (text str.semiColon) + +instance Ann a => Pretty (Branch a) where + pretty (x × Clause (ps × e)) = text x .<>. prettyClause (text str.equals) (Clause (ps × e)) + +instance Ann a => Pretty (NonEmptyList (Branch a)) where + pretty h = intersperse' (toList (map pretty h)) (text str.semiColon) + +instance Ann a => Pretty (NonEmptyList (NonEmptyList (Branch a))) where + pretty hs = intersperse' (toList (map pretty hs)) (text str.semiColon) + +instance Ann a => Pretty (FirstGroup a) where + pretty (First h) = pretty (groupBy (\p q -> key p == key q) h) + +instance Ann a => Pretty (NonEmptyList (Pattern × Expr a)) where + pretty pss = intersperse' (map (prettyClause (text str.rArrow)) (map Clause (toList (helperMatch pss)))) (text str.semiColon) + +instance Ann a => Pretty (VarDef a) where + pretty (VarDef p s) = pretty p .<>. text str.equals .<>. pretty s + +instance Ann a => Pretty (VarDefs a) where + pretty ds = intersperse' (toList (map pretty ds)) (text str.semiColon) + +instance Ann a => Pretty (List (Expr a)) where + pretty (Cons s Nil) = pretty s + pretty (Cons s ss) = pretty s .<>. pretty ss + pretty Nil = empty + +instance Ann a => Pretty (List (Qualifier a)) where + pretty (Cons (Guard s) Nil) = pretty s + pretty (Cons (Declaration d) Nil) = text str.let_ .<>. pretty d + pretty (Cons (Generator p s) Nil) = pretty p .<>. text str.lArrow .<>. pretty s + pretty (Cons q qs) = pretty (toList (singleton q)) .<>. text str.comma .<>. pretty qs + pretty Nil = empty + +intersperse' :: List Doc -> Doc -> Doc +intersperse' (Cons dc Nil) _ = dc +intersperse' (Cons dc dcs) dc' = dc .<>. dc' .-. intersperse' dcs dc' +intersperse' Nil _ = empty + +helperMatch :: forall a. NonEmptyList (Pattern × Expr a) -> NonEmptyList (NonEmptyList Pattern × Expr a) +helperMatch pss = map (\(x × y) -> singleton x × y) pss + +prettyP :: forall d. Pretty d => d -> String +prettyP x = render (removeDocWS (pretty x)) + +between :: Doc -> Doc -> Endo Doc +between l r doc = l .<>. doc .<>. r + +brackets :: Endo Doc +brackets = between (text str.lBracket) (text str.rBracket) + +dictBrackets :: Endo Doc +dictBrackets = between (text str.dictLBracket) (text str.dictRBracket) + +parentheses :: Endo Doc +parentheses = between (text str.lparenth) (text str.rparenth) + +-- slashes :: Endo Doc +-- slashes = between (text str.slash) (text str.slash) + +-- backTicks :: Endo Doc +-- backTicks = between (text str.backtick) (text str.backtick) + +curlyBraces :: Endo Doc +curlyBraces = between (text str.curlylBrace) (text str.curlyrBrace) + +arrayBrackets :: Endo Doc +arrayBrackets = between (text str.arrayLBracket) (text str.arrayRBracket) + +comma :: Doc +comma = text str.comma + +semi :: Doc +semi = text str.semiColon + +hcomma :: forall f. Foldable f => f Doc -> Doc +hcomma = fromFoldable >>> intersperse comma >>> hcat + +parens :: Endo Doc +parens = between (text "(") (text ")") + +class ToList a where + toList2 :: a -> List a + +class ToPair a where + toPair :: a -> a × a + +instance ToPair (E.Expr a) where + toPair (E.Constr _ c (e : e' : Nil)) | c == cPair = e × e' + toPair _ = error absurd + +instance ToPair (Val a) where + toPair (V.Constr _ c (v : v' : Nil)) | c == cPair = v × v' + toPair _ = error absurd + +instance Pretty String where + pretty = text + +vert :: forall f. Foldable f => Doc -> f Doc -> Doc +vert delim = fromFoldable >>> vert' + where + vert' :: List Doc -> Doc + vert' Nil = empty + vert' (x : Nil) = x + vert' (x : y : xs) = atop (x .<>. delim) (vert' (y : xs)) + +prettyCtr :: Ctr -> Doc +prettyCtr = showCtr >>> text + +-- Cheap hack; revisit. +prettyParensOpt :: forall a. Pretty a => a -> Doc +prettyParensOpt x = + if DS.contains (DS.Pattern " ") (render doc) then parens doc + else doc + where + doc = pretty x + +nil :: Doc +nil = text (str.lBracket <> str.rBracket) + +-- (highlightIf α $ parens (hcomma [ pretty x, pretty y ])) +-- (highlightIf ann $ text str.lBracket) +prettyConstr :: forall d a. Pretty d => Highlightable a => a -> Ctr -> List d -> Doc +prettyConstr α c (x : y : ys) + | c == cPair = assert (null ys) $ (highlightIf α $ parens (hcomma [ pretty x, pretty y ])) +prettyConstr α c ys + | c == cNil = assert (null ys) $ (highlightIf α nil) +prettyConstr α c (x : y : ys) + | c == cCons = assert (null ys) $ parens (hcat [ pretty x, (highlightIf α $ text str.colon), pretty y ]) +prettyConstr α c (x : Nil) = (highlightIf α (prettyCtr c .<>. pretty x)) +prettyConstr α c xs = hcat ((highlightIf α (prettyCtr c)) : (prettyParensOpt <$> xs)) + +prettyRecordOrDict + :: forall d b a + . Pretty d + => Highlightable a + => Doc + -> Endo Doc + -> (b -> Doc) + -> a + -> List (b × d) + -> Doc +prettyRecordOrDict sep bracify prettyKey α xvs = + xvs <#> first prettyKey <#> (\(x × v) -> hcat [ x .<>. sep, pretty v ]) + # hcomma >>> bracify >>> highlightIf α + +prettyDict :: forall d b a. Pretty d => Highlightable a => (b -> Doc) -> a -> List (b × d) -> Doc +prettyDict = between (text str.dictLBracket) (text str.dictRBracket) # prettyRecordOrDict (text str.colonEq) + +prettyRecord :: forall d b a. Pretty d => Highlightable a => (b -> Doc) -> a -> List (b × d) -> Doc +prettyRecord = curlyBraces # prettyRecordOrDict (text str.colon) + +prettyMatrix :: forall a. Highlightable a => E.Expr a -> Var -> Var -> E.Expr a -> Doc +prettyMatrix e1 i j e2 = arrayBrackets (pretty e1 .<>. text str.lArrow .<>. text (i <> "×" <> j) .<>. text str.in_ .<>. pretty e2) + +instance Highlightable a => Pretty (E.Expr a) where + pretty (E.Var x) = text x + pretty (E.Int α n) = (highlightIf α (text (show n))) + pretty (E.Float _ n) = text (show n) + pretty (E.Str _ str) = text (show str) + pretty (E.Record α xes) = prettyRecord text α (xes # D.toUnfoldable) + pretty (E.Dictionary α ees) = prettyDict pretty α (ees <#> toTuple) + pretty (E.Constr α c es) = prettyConstr α c es + pretty (E.Matrix α e1 (i × j) e2) = (highlightIf α (prettyMatrix e1 i j e2)) + pretty (E.Lambda σ) = hcat [ text str.fun, pretty σ ] + pretty (E.Op op) = parens (text op) + pretty (E.Let (E.VarDef σ e) e') = atop (hcat [ text str.let_, pretty σ, text str.equals, pretty e, text str.in_ ]) + (pretty e') + pretty (E.LetRec δ e) = atop (hcat [ text str.let_, pretty δ, text str.in_ ]) (pretty e) + pretty (E.Project e x) = pretty e .<>. text str.dot .<>. pretty x + pretty (E.App e e') = hcat [ pretty e, pretty e' ] + +instance Highlightable a => Pretty (Dict (Elim a)) where + pretty x = go (D.toUnfoldable x) + where + go :: List (Var × Elim a) -> Doc + go Nil = error absurd -- non-empty + go (xσ : Nil) = pretty xσ + go (xσ : δ) = atop (go δ .<>. semi) (pretty xσ) + +instance Highlightable a => Pretty (Dict (Val a)) where + pretty γ = brackets $ go (D.toUnfoldable γ) + where + go :: List (Var × Val a) -> Doc + go Nil = empty + go ((x × v) : rest) = parens (text x .<>. text str.comma .<>. pretty v) .<>. text str.comma .<>. go rest + +instance Pretty (Dict (S.Set Vertex)) where + pretty d = brackets $ go (D.toUnfoldable d) + where + go :: List (String × S.Set Vertex) -> Doc + go Nil = empty + go ((α × βs) : rest) = text α .<>. text " ↦ " .<>. pretty (βs :: S.Set Vertex) .<>. text str.comma .<>. go rest + +instance Highlightable a => Pretty (Bind (Elim a)) where + pretty (x ↦ σ) = hcat [ text x, text str.equals, pretty σ ] + +instance Highlightable a => Pretty (Cont a) where + pretty ContNone = empty + pretty (ContExpr e) = pretty e + pretty (ContElim σ) = pretty σ + +instance Highlightable a => Pretty (Ctr × Cont a) where + pretty (c × κ) = hcat [ text (showCtr c), text str.rArrow, pretty κ ] + +instance Highlightable a => Pretty (Elim a) where + pretty (ElimVar x κ) = hcat [ text x, text str.rArrow, pretty κ ] + pretty (ElimConstr κs) = hcomma (pretty <$> κs) -- looks dodgy + pretty (ElimRecord xs κ) = hcat [ curlyBraces $ hcomma (text <$> (S.toUnfoldable xs :: List String)), text str.rArrow, curlyBraces (pretty κ) ] + +instance Highlightable a => Pretty (Val a) where + pretty (V.Int α n) = (highlightIf α (text (show n))) + pretty (V.Float α n) = (highlightIf α (text (show n))) + pretty (V.Str α str) = (highlightIf α (text (show str))) + pretty (V.Record α xvs) = prettyRecord text α (xvs # D.toUnfoldable) + pretty (V.Dictionary α (DictRep svs)) = prettyDict + (\(s × β) -> highlightIf β (text (show s))) + α + (svs # D.toUnfoldable <#> \(s × (β × v)) -> (s × β) × v) + pretty (V.Constr α c vs) = prettyConstr α c vs + pretty (V.Matrix _ (MatrixRep (vss × _ × _))) = vert comma (((<$>) pretty >>> hcomma) <$> vss) + pretty (V.Fun α φ) = pretty (α × φ) + +instance Highlightable a => Pretty (a × Fun a) where + pretty (α × V.Closure _ _ _) = (highlightIf α $ text "") + pretty (_ × V.Foreign φ _) = runExists pretty φ + pretty (α × V.PartialConstr c vs) = prettyConstr α c vs + +instance Pretty (ForeignOp' t) where + pretty _ = text "" -- TODO + +instance (Pretty a, Pretty b) => Pretty (a + b) where + pretty = pretty ||| pretty + +instance Pretty GraphImpl where + pretty (GraphImpl g) = + text "GraphImpl \n " .<>. + atop + ( text "{\n" .<>. + atop (text "OUT: " .<>. pretty g.out) (text "IN: " .<>. pretty g.in) + ) + (text "}") + +instance Pretty (S.Set Vertex) where + pretty αs = curlyBraces (hcomma (text <<< unwrap <$> (S.toUnfoldable αs :: List Vertex))) + where + unwrap (Vertex α) = α diff --git a/src/Primitive2.purs b/src/Primitive2.purs new file mode 100644 index 000000000..ae0e7c219 --- /dev/null +++ b/src/Primitive2.purs @@ -0,0 +1,311 @@ +module Primitive2 where + +import Prelude hiding (absurd, apply, div, top) + +import Ann (Raw, erase) +import Data.Either (Either(..)) +import Data.Exists (mkExists) +import Data.Int (toNumber) +import Data.List (List(..), (:)) +import Data.Profunctor.Choice ((|||)) +import Data.Set (singleton, insert) +import DataType (cFalse, cPair, cTrue) +import Dict (Dict) +import Graph.GraphWriter (new) +import Lattice (class BoundedJoinSemilattice, bot) +import Partial.Unsafe (unsafePartial) +import Pretty2 (prettyP) +import Util (type (+), type (×), (×), error) +import Val2 (class Ann, ForeignOp'(..), Fun(..), MatrixRep, OpBwd, OpFwd, OpGraph, Val(..)) + +-- Mediate between values of annotation type a and (potential) underlying datatype d, analogous to +-- pattern-matching and construction for data types. Wasn't able to make a typeclass version of this +-- work with the required higher-rank polymorphism. +type ToFrom d a = + { constr :: d × a -> Val a + , constr_bwd :: Val a -> d × a -- equivalent to match (except at Val) + , match :: Val a -> d × a + } + +int :: forall a. ToFrom Int a +int = + { constr: \(n × α) -> Int α n + , constr_bwd: match' + , match: match' + } + where + match' :: _ + match' (Int α n) = n × α + match' v = error ("Int expected; got " <> prettyP (erase v)) + +number :: forall a. ToFrom Number a +number = + { constr: \(n × α) -> Float α n + , constr_bwd: match' + , match: match' + } + where + match' :: _ + match' (Float α n) = n × α + match' v = error ("Float expected; got " <> prettyP (erase v)) + +string :: forall a. ToFrom String a +string = + { constr: \(str × α) -> Str α str + , constr_bwd: match' + , match: match' + } + where + match' :: _ + match' (Str α str) = str × α + match' v = error ("Str expected; got " <> prettyP (erase v)) + +intOrNumber :: forall a. ToFrom (Int + Number) a +intOrNumber = + { constr: case _ of + Left n × α -> Int α n + Right n × α -> Float α n + , constr_bwd: match' + , match: match' + } + where + match' :: Val a -> (Int + Number) × a + match' (Int α n) = Left n × α + match' (Float α n) = Right n × α + match' v = error ("Int or Float expected; got " <> prettyP (erase v)) + +intOrNumberOrString :: forall a. ToFrom (Int + Number + String) a +intOrNumberOrString = + { constr: case _ of + Left n × α -> Int α n + Right (Left n) × α -> Float α n + Right (Right str) × α -> Str α str + , constr_bwd: match' + , match: match' + } + where + match' :: Val a -> (Int + Number + String) × a + match' (Int α n) = Left n × α + match' (Float α n) = Right (Left n) × α + match' (Str α str) = Right (Right str) × α + match' v = error ("Int, Float or Str expected; got " <> prettyP (erase v)) + +intPair :: forall a. ToFrom ((Int × a) × (Int × a)) a +intPair = + { constr: \((nβ × mβ') × α) -> Constr α cPair (int.constr nβ : int.constr mβ' : Nil) + , constr_bwd: match' + , match: match' + } + where + match' :: Val a -> ((Int × a) × (Int × a)) × a + match' (Constr α c (v : v' : Nil)) | c == cPair = (int.match v × int.match v') × α + match' v = error ("Pair expected; got " <> prettyP (erase v)) + +matrixRep :: forall a. Ann a => ToFrom (MatrixRep a) a +matrixRep = + { constr: \(m × α) -> Matrix α m + , constr_bwd: match' + , match: match' + } + where + match' :: Ann a => Val a -> MatrixRep a × a + match' (Matrix α m) = m × α + match' v = error ("Matrix expected; got " <> prettyP v) + +record :: forall a. Ann a => ToFrom (Dict (Val a)) a +record = + { constr: \(xvs × α) -> Record α xvs + , constr_bwd: match' + , match: match' + } + where + match' :: Ann a => _ + match' (Record α xvs) = xvs × α + match' v = error ("Record expected; got " <> prettyP v) + +boolean :: forall a. ToFrom Boolean a +boolean = + { constr: case _ of + true × α -> Constr α cTrue Nil + false × α -> Constr α cFalse Nil + , constr_bwd: match' + , match: match' + } + where + match' :: Val a -> Boolean × a + match' (Constr α c Nil) + | c == cTrue = true × α + | c == cFalse = false × α + match' v = error ("Boolean expected; got " <> prettyP (erase v)) + +class IsZero a where + isZero :: a -> Boolean + +instance IsZero Int where + isZero = ((==) 0) + +instance IsZero Number where + isZero = ((==) 0.0) + +instance (IsZero a, IsZero b) => IsZero (a + b) where + isZero = isZero ||| isZero + +-- Need to be careful about type variables escaping higher-rank quantification. +type Unary i o a = + { i :: ToFrom i a + , o :: ToFrom o a + , fwd :: i -> o + } + +type Binary i1 i2 o a = + { i1 :: ToFrom i1 a + , i2 :: ToFrom i2 a + , o :: ToFrom o a + , fwd :: i1 -> i2 -> o + } + +type BinaryZero i o a = + { i :: ToFrom i a + , o :: ToFrom o a + , fwd :: i -> i -> o + } + +unary :: forall i o a'. BoundedJoinSemilattice a' => (forall a. Unary i o a) -> Val a' +unary op = + Fun bot $ flip Foreign Nil + $ mkExists + $ ForeignOp' { arity: 1, op': unsafePartial op', op: unsafePartial fwd, op_bwd: unsafePartial bwd } + where + op' :: Partial => OpGraph + op' (v : Nil) = + op.o.constr <$> ((op.fwd x × _) <$> new (singleton α)) + where + x × α = op.i.match v + + fwd :: Partial => OpFwd (Raw Val) + fwd _ (v : Nil) = pure $ erase v × op.o.constr (op.fwd x × α) + where + x × α = op.i.match v + + bwd :: Partial => OpBwd (Raw Val) + bwd _ (u × v) = op.i.constr (x × α) : Nil + where + _ × α = op.o.constr_bwd v + (x × _) = op.i.match u + +binary :: forall i1 i2 o a'. BoundedJoinSemilattice a' => (forall a. Binary i1 i2 o a) -> Val a' +binary op = + Fun bot $ flip Foreign Nil + $ mkExists + $ ForeignOp' { arity: 2, op': unsafePartial op', op: unsafePartial fwd, op_bwd: unsafePartial bwd } + where + op' :: Partial => OpGraph + op' (v1 : v2 : Nil) = + op.o.constr <$> ((op.fwd x y × _) <$> new (singleton α # insert β)) + where + (x × α) × (y × β) = op.i1.match v1 × op.i2.match v2 + + fwd :: Partial => OpFwd (Raw Val × Raw Val) + fwd 𝒶 (v1 : v2 : Nil) = pure $ (erase v1 × erase v2) × op.o.constr (op.fwd x y × (α `𝒶.meet` β)) + where + (x × α) × (y × β) = op.i1.match v1 × op.i2.match v2 + + bwd :: Partial => OpBwd (Raw Val × Raw Val) + bwd _ ((u1 × u2) × v) = op.i1.constr (x × α) : op.i2.constr (y × α) : Nil + where + _ × α = op.o.constr_bwd v + (x × _) × (y × _) = op.i1.match u1 × op.i2.match u2 + +-- If both are zero, depend only on the first. +binaryZero :: forall i o a'. BoundedJoinSemilattice a' => IsZero i => (forall a. BinaryZero i o a) -> Val a' +binaryZero op = + Fun bot $ flip Foreign Nil + $ mkExists + $ ForeignOp' { arity: 2, op': unsafePartial op', op: unsafePartial fwd, op_bwd: unsafePartial bwd } + where + op' :: Partial => OpGraph + op' (v1 : v2 : Nil) = + let + αs = + if isZero x then singleton α + else if isZero y then singleton β + else singleton α # insert β + in + op.o.constr <$> ((op.fwd x y × _) <$> new αs) + where + (x × α) × (y × β) = op.i.match v1 × op.i.match v2 + + fwd :: Partial => OpFwd (Raw Val × Raw Val) + fwd 𝒶 (v1 : v2 : Nil) = + pure $ (erase v1 × erase v2) × + op.o.constr (op.fwd x y × if isZero x then α else if isZero y then β else α `𝒶.meet` β) + where + (x × α) × (y × β) = op.i.match v1 × op.i.match v2 + + bwd :: Partial => OpBwd (Raw Val × Raw Val) + bwd 𝒶 ((u1 × u2) × v) = op.i.constr (x × β1) : op.i.constr (y × β2) : Nil + where + _ × α = op.o.constr_bwd v + (x × _) × (y × _) = op.i.match u1 × op.i.match u2 + β1 × β2 = + if isZero x then α × 𝒶.bot + else if isZero y then 𝒶.bot × α + else α × α + +class As a b where + as :: a -> b + +union1 :: forall a1 b. (a1 -> b) -> (Number -> b) -> a1 + Number -> b +union1 f _ (Left x) = f x +union1 _ g (Right x) = g x + +-- Biased towards g: if arguments are of mixed types, we try to coerce to an application of g. +union + :: forall a1 b1 c1 a2 b2 c2 c + . As c1 c + => As c2 c + => As a1 a2 + => As b1 b2 + => (a1 -> b1 -> c1) + -> (a2 -> b2 -> c2) + -> a1 + a2 + -> b1 + b2 + -> c +union f _ (Left x) (Left y) = as (f x y) +union _ g (Left x) (Right y) = as (g (as x) y) +union _ g (Right x) (Right y) = as (g x y) +union _ g (Right x) (Left y) = as (g x (as y)) + +-- Helper to avoid some explicit type annotations when defining primitives. +unionStr + :: forall a b + . As a a + => As b String + => (b -> b -> a) + -> (String -> String -> a) + -> b + String + -> b + String + -> a +unionStr = union + +instance asIntIntOrNumber :: As Int (Int + a) where + as = Left + +instance asNumberIntOrNumber :: As Number (a + Number) where + as = Right + +instance asIntNumber :: As Int Number where + as = toNumber + +instance asBooleanBoolean :: As Boolean Boolean where + as = identity + +instance asNumberString :: As Number String where + as _ = error "Non-uniform argument types" + +instance asIntNumberOrString :: As Int (Number + a) where + as = toNumber >>> Left + +instance asIntorNumberNumber :: As (Int + Number) Number where + as (Left n) = as n + as (Right n) = n \ No newline at end of file diff --git a/src/Trace2.purs b/src/Trace2.purs new file mode 100644 index 000000000..11a9a5378 --- /dev/null +++ b/src/Trace2.purs @@ -0,0 +1,51 @@ +module Trace2 where + +import Prelude + +import Bindings (Var) +import Data.Exists (Exists) +import Data.List (List) +import Data.Maybe (Maybe) +import Data.Set (Set, empty, singleton, unions) +import DataType (Ctr) +import Dict (Dict) +import Expr (class BV, RecDefs, bv) +import Lattice (Raw) +import Util (type (×)) +import Val2 (Array2, ForeignOp', Val) + +data Trace + = Var Var + | Op Var + | Const + | Record (Dict Trace) + | Dictionary (List (String × Trace × Trace)) (Dict (Raw Val)) + | Constr Ctr (List Trace) + | Matrix (Array2 Trace) (Var × Var) (Int × Int) Trace + | Project Trace Var + | App Trace Trace AppTrace + | Let VarDef Trace + | LetRec (Raw RecDefs) Trace + +data AppTrace + = AppClosure (Set Var) Match Trace + -- these two forms represent partial (unsaturated) applications + | AppForeign Int ForeignTrace -- record number of arguments + | AppConstr Ctr + +data ForeignTrace' t = ForeignTrace' (ForeignOp' t) (Maybe t) +type ForeignTrace = Exists ForeignTrace' + +data VarDef = VarDef Match Trace + +data Match + = MatchVar Var (Raw Val) + | MatchVarAnon (Raw Val) + | MatchConstr Ctr (List Match) + | MatchRecord (Dict Match) + +instance BV Match where + bv (MatchVar x _) = singleton x + bv (MatchVarAnon _) = empty + bv (MatchConstr _ ws) = unions (bv <$> ws) + bv (MatchRecord xws) = unions (bv <$> xws) diff --git a/src/Val2.purs b/src/Val2.purs new file mode 100644 index 000000000..ae196645e --- /dev/null +++ b/src/Val2.purs @@ -0,0 +1,255 @@ +module Val2 where + +import Prelude hiding (absurd, append) + +import Bindings (Var) +import BoolAlg (BoolAlg) +import Control.Apply (lift2) +import Control.Monad.Error.Class (class MonadError, class MonadThrow) +import Data.Array ((!!)) +import Data.Array (zipWith) as A +import Data.Bitraversable (bitraverse) +import Data.Exists (Exists) +import Data.Foldable (class Foldable, foldl, foldrDefault, foldMapDefaultL) +import Data.List (List(..), (:), zipWith) +import Data.Set (Set, empty, fromFoldable, intersection, member, singleton, toUnfoldable, union) +import Data.Traversable (class Traversable, sequenceDefault, traverse) +import DataType (Ctr) +import Dict (Dict, get) +import Dict (apply2, intersectionWith) as D +import Effect.Exception (Error) +import Expr (Elim, RecDefs, fv) +import Foreign.Object (filterKeys, lookup, unionWith) +import Foreign.Object (keys) as O +import Graph (Vertex(..)) +import Graph.GraphWriter (class MonadGraphAlloc) +import Lattice (class BoundedJoinSemilattice, class BoundedLattice, class Expandable, class JoinSemilattice, class Neg, Raw, definedJoin, expand, maybeJoin, neg, (∨)) +import Util (type (×), Endo, error, orElse, throw, unsafeUpdateAt, (!), (×), (≜), (≞)) +import Util.Pretty (Doc, beside, text) + +data Val a + = Int a Int + | Float a Number + | Str a String + | Constr a Ctr (List (Val a)) -- always saturated + | Record a (Dict (Val a)) -- always saturated + | Dictionary a (DictRep a) + | Matrix a (MatrixRep a) + | Fun a (Fun a) + +data Fun a + = Closure (Env a) (RecDefs a) (Elim a) + | Foreign ForeignOp (List (Val a)) -- never saturated + | PartialConstr Ctr (List (Val a)) -- never saturated + +class (Highlightable a, BoundedLattice a) <= Ann a + +instance Ann Boolean +instance Ann Unit + +instance Highlightable a => Highlightable (a × b) where + highlightIf (a × _) doc = highlightIf a doc + +instance (Ann a, BoundedLattice b) => Ann (a × b) + +-- similar to an isomorphism lens with complement t +type OpFwd t = forall a m. Highlightable a => MonadError Error m => BoolAlg a -> List (Val a) -> m (t × Val a) +type OpBwd t = forall a. Highlightable a => BoolAlg a -> t × Val a -> List (Val a) +type OpGraph = forall m. MonadGraphAlloc m => MonadError Error m => List (Val Vertex) -> m (Val Vertex) + +data ForeignOp' t = ForeignOp' + { arity :: Int + , op :: OpFwd t + , op' :: OpGraph + , op_bwd :: OpBwd t + } + +type ForeignOp = Exists ForeignOp' + +-- Environments. +type Env a = Dict (Val a) + +lookup' :: forall a m. MonadThrow Error m => Var -> Dict a -> m a +lookup' x γ = lookup x γ # orElse ("variable " <> x <> " not found") + +-- Want a monoid instance but needs a newtype +append :: forall a. Env a -> Endo (Env a) +append = unionWith (const identity) + +infixl 5 append as <+> + +append_inv :: forall a. Set Var -> Env a -> Env a × Env a +append_inv xs γ = filterKeys (_ `not <<< member` xs) γ × restrict γ xs + +restrict :: forall a. Dict a -> Set Var -> Dict a +restrict γ xs = filterKeys (_ `member` xs) γ + +reaches :: forall a. RecDefs a -> Endo (Set Var) +reaches ρ xs = go (toUnfoldable xs) empty + where + dom_ρ = fromFoldable $ O.keys ρ + + go :: List Var -> Endo (Set Var) + go Nil acc = acc + go (x : xs') acc | x `member` acc = go xs' acc + go (x : xs') acc | otherwise = + go (toUnfoldable (fv σ `intersection` dom_ρ) <> xs') + (singleton x `union` acc) + where + σ = get x ρ + +for :: forall a. RecDefs a -> Elim a -> RecDefs a +for ρ σ = ρ `restrict` reaches ρ (fv σ `intersection` (fromFoldable $ O.keys ρ)) + +-- Wrap internal representations to provide foldable/traversable instances. +newtype DictRep a = DictRep (Dict (a × Val a)) +newtype MatrixRep a = MatrixRep (Array2 (Val a) × (Int × a) × (Int × a)) + +type Array2 a = Array (Array a) + +matrixGet :: forall a m. MonadThrow Error m => Int -> Int -> MatrixRep a -> m (Val a) +matrixGet i j (MatrixRep (vss × _ × _)) = + orElse "Index out of bounds" $ do + us <- vss !! (i - 1) + us !! (j - 1) + +matrixUpdate :: forall a. Int -> Int -> Endo (Val a) -> Endo (MatrixRep a) +matrixUpdate i j δv (MatrixRep (vss × h × w)) = + MatrixRep (vss' × h × w) + where + vs_i = vss ! (i - 1) + v_j = vs_i ! (j - 1) + vss' = unsafeUpdateAt (i - 1) (unsafeUpdateAt (j - 1) (δv v_j) vs_i) vss + +class Highlightable a where + highlightIf :: a -> Endo Doc + +instance Highlightable Unit where + highlightIf _ = identity + +instance Highlightable Boolean where + highlightIf false = identity + highlightIf true = \doc -> text "⸨" `beside` doc `beside` text "⸩" + +instance Highlightable Vertex where + highlightIf (Vertex α) = \doc -> doc `beside` text "_" `beside` text ("⟨" <> α <> "⟩") + +-- ====================== +-- boilerplate +-- ====================== +derive instance Functor DictRep +derive instance Functor MatrixRep +derive instance Functor Val +derive instance Foldable Val +derive instance Traversable Val +derive instance Functor Fun +derive instance Foldable Fun +derive instance Traversable Fun + +instance Apply Val where + apply (Int fα n) (Int α _) = Int (fα α) n + apply (Float fα n) (Float α _) = Float (fα α) n + apply (Str fα s) (Str α _) = Str (fα α) s + apply (Constr fα c fes) (Constr α _ es) = Constr (fα α) c (zipWith (<*>) fes es) + apply (Record fα fxvs) (Record α xvs) = Record (fα α) (D.apply2 fxvs xvs) + apply (Dictionary fα fxvs) (Dictionary α xvs) = Dictionary (fα α) (fxvs <*> xvs) + apply (Matrix fα fm) (Matrix α m) = Matrix (fα α) (fm <*> m) + apply (Fun fα ff) (Fun α f) = Fun (fα α) (ff <*> f) + apply _ _ = error "Apply Expr: shape mismatch" + +instance Apply Fun where + apply (Closure fγ fρ fσ) (Closure γ ρ σ) = Closure (D.apply2 fγ γ) (D.apply2 fρ ρ) (fσ <*> σ) + apply (Foreign op fvs) (Foreign _ vs) = Foreign op (zipWith (<*>) fvs vs) + apply (PartialConstr c fvs) (PartialConstr _ vs) = PartialConstr c (zipWith (<*>) fvs vs) + apply _ _ = error "Apply Fun: shape mismatch" + +instance Apply DictRep where + apply (DictRep fxvs) (DictRep xvs) = DictRep $ D.intersectionWith (\(fα' × fv') (α' × v') -> (fα' α') × (fv' <*> v')) fxvs xvs + +instance Apply MatrixRep where + apply (MatrixRep (fvss × (n × fnα) × (m × fmα))) (MatrixRep (vss × (_ × nα) × (_ × mα))) = MatrixRep $ (A.zipWith (A.zipWith (<*>)) fvss vss) × (n × fnα nα) × (m × fmα mα) + +instance Foldable DictRep where + foldl f acc (DictRep d) = foldl (\acc' (a × v) -> foldl f (acc' `f` a) v) acc d + foldr f = foldrDefault f + foldMap f = foldMapDefaultL f + +instance Traversable DictRep where + traverse f (DictRep d) = DictRep <$> traverse (bitraverse f (traverse f)) d + sequence = sequenceDefault + +instance Foldable MatrixRep where + foldl f acc (MatrixRep (vss × (_ × βi) × (_ × βj))) = foldl (foldl (foldl f)) (acc `f` βi `f` βj) vss + foldr f = foldrDefault f + foldMap f = foldMapDefaultL f + +instance Traversable MatrixRep where + traverse f (MatrixRep m) = + MatrixRep <$> bitraverse (traverse (traverse (traverse f))) + (bitraverse (traverse f) (traverse f)) + m + sequence = sequenceDefault + +instance JoinSemilattice a => JoinSemilattice (DictRep a) where + maybeJoin (DictRep svs) (DictRep svs') = DictRep <$> maybeJoin svs svs' + join v = definedJoin v + +instance JoinSemilattice a => JoinSemilattice (MatrixRep a) where + maybeJoin (MatrixRep (vss × (i × βi) × (j × βj))) (MatrixRep (vss' × (i' × βi') × (j' × βj'))) = + MatrixRep <$> + ( maybeJoin vss vss' + `lift2 (×)` (((_ × (βi ∨ βi')) <$> (i ≞ i')) `lift2 (×)` ((_ × (βj ∨ βj')) <$> (j ≞ j'))) + ) + join v = definedJoin v + +instance JoinSemilattice a => JoinSemilattice (Val a) where + maybeJoin (Int α n) (Int α' n') = Int (α ∨ α') <$> (n ≞ n') + maybeJoin (Float α n) (Float α' n') = Float (α ∨ α') <$> (n ≞ n') + maybeJoin (Str α s) (Str α' s') = Str (α ∨ α') <$> (s ≞ s') + maybeJoin (Record α xvs) (Record α' xvs') = Record (α ∨ α') <$> maybeJoin xvs xvs' + maybeJoin (Dictionary α d) (Dictionary α' d') = Dictionary (α ∨ α') <$> maybeJoin d d' + maybeJoin (Constr α c vs) (Constr α' c' us) = Constr (α ∨ α') <$> (c ≞ c') <*> maybeJoin vs us + maybeJoin (Matrix α m) (Matrix α' m') = Matrix (α ∨ α') <$> maybeJoin m m' + maybeJoin (Fun α φ) (Fun α' φ') = Fun (α ∨ α') <$> maybeJoin φ φ' + maybeJoin _ _ = throw "Incompatible values" + + join v = definedJoin v + +instance JoinSemilattice a => JoinSemilattice (Fun a) where + maybeJoin (Closure γ ρ σ) (Closure γ' ρ' σ') = + Closure <$> maybeJoin γ γ' <*> maybeJoin ρ ρ' <*> maybeJoin σ σ' + maybeJoin (Foreign φ vs) (Foreign _ vs') = + Foreign φ <$> maybeJoin vs vs' -- TODO: require φ == φ' + maybeJoin (PartialConstr c vs) (PartialConstr c' us) = + PartialConstr <$> (c ≞ c') <*> maybeJoin vs us + maybeJoin _ _ = throw "Incompatible functions" + + join v = definedJoin v + +instance BoundedJoinSemilattice a => Expandable (DictRep a) (Raw DictRep) where + expand (DictRep svs) (DictRep svs') = DictRep (expand svs svs') + +instance BoundedJoinSemilattice a => Expandable (MatrixRep a) (Raw MatrixRep) where + expand (MatrixRep (vss × (i × βi) × (j × βj))) (MatrixRep (vss' × (i' × _) × (j' × _))) = + MatrixRep (expand vss vss' × ((i ≜ i') × βi) × ((j ≜ j') × βj)) + +instance BoundedJoinSemilattice a => Expandable (Val a) (Raw Val) where + expand (Int α n) (Int _ n') = Int α (n ≜ n') + expand (Float α n) (Float _ n') = Float α (n ≜ n') + expand (Str α s) (Str _ s') = Str α (s ≜ s') + expand (Record α xvs) (Record _ xvs') = Record α (expand xvs xvs') + expand (Dictionary α d) (Dictionary _ d') = Dictionary α (expand d d') + expand (Constr α c vs) (Constr _ c' us) = Constr α (c ≜ c') (expand vs us) + expand (Matrix α m) (Matrix _ m') = Matrix α (expand m m') + expand (Fun α φ) (Fun _ φ') = Fun α (expand φ φ') + expand _ _ = error "Incompatible values" + +instance BoundedJoinSemilattice a => Expandable (Fun a) (Raw Fun) where + expand (Closure γ ρ σ) (Closure γ' ρ' σ') = + Closure (expand γ γ') (expand ρ ρ') (expand σ σ') + expand (Foreign φ vs) (Foreign _ vs') = Foreign φ (expand vs vs') -- TODO: require φ == φ' + expand (PartialConstr c vs) (PartialConstr c' us) = PartialConstr (c ≜ c') (expand vs us) + expand _ _ = error "Incompatible values" + +instance Neg a => Neg (Val a) where + neg = (<$>) neg