diff --git a/src/Eval.purs b/src/Eval.purs index c7bef4ec2..6fcc99b24 100644 --- a/src/Eval.purs +++ b/src/Eval.purs @@ -24,11 +24,11 @@ import Lattice ((∧), erase, top) import Pretty (prettyP) import Primitive (intPair, string) import Trace (AppTrace(..), Trace(..), VarDef(..)) as T -import Trace (AppTrace, ForeignTrace, ForeignTrace'(..), Match(..), Trace) +import Trace (AppTrace, ForeignTrace(..), ForeignTrace'(..), Match(..), Trace) import Util (type (×), absurd, both, check, error, orElse, successful, throw, with, (×)) import Util.Pair (unzip) as P import Val (Fun(..), Val(..)) as V -import Val (class Ann, DictRep(..), Env, ForeignOp'(..), MatrixRep(..), (<+>), Val, for, lookup', restrict) +import Val (class Ann, DictRep(..), Env, ForeignOp(..), ForeignOp'(..), MatrixRep(..), Val, for, lookup', restrict, (<+>)) patternMismatch :: String -> String -> String patternMismatch s s' = "Pattern mismatch: found " <> s <> ", expected " <> s' @@ -77,7 +77,7 @@ apply (V.Fun β (V.Closure γ1 ρ σ) × v) = do γ3 × e'' × β' × w <- match v σ t'' × v'' <- eval (γ1 <+> γ2 <+> γ3) (asExpr e'') (β ∧ β') pure $ T.AppClosure (S.fromFoldable (keys ρ)) w t'' × v'' -apply (V.Fun α (V.Foreign φ vs) × v) = do +apply (V.Fun α (V.Foreign (ForeignOp (id × φ)) vs) × v) = do t × v'' <- runExists apply' φ pure $ T.AppForeign (length vs + 1) t × v'' where @@ -86,9 +86,9 @@ apply (V.Fun α (V.Foreign φ vs) × v) = do apply' :: forall t. ForeignOp' t -> m (ForeignTrace × Val _) apply' (ForeignOp' φ') = do t × v'' <- do - if φ'.arity > length vs' then pure $ Nothing × V.Fun α (V.Foreign φ vs') + if φ'.arity > length vs' then pure $ Nothing × V.Fun α (V.Foreign (ForeignOp (id × φ)) vs') else first Just <$> φ'.op vs' - pure $ mkExists (ForeignTrace' (ForeignOp' φ') t) × v'' + pure $ ForeignTrace (id × mkExists (ForeignTrace' (ForeignOp' φ') t)) × v'' apply (V.Fun α (V.PartialConstr c vs) × v) = do check (length vs < n) ("Too many arguments to " <> showCtr c) pure $ T.AppConstr c × v' diff --git a/src/EvalBwd.purs b/src/EvalBwd.purs index c97037bc6..daf7007a6 100644 --- a/src/EvalBwd.purs +++ b/src/EvalBwd.purs @@ -25,11 +25,11 @@ import GaloisConnection (GaloisConnection(..)) import Lattice (Raw, bot, botOf, expand, (∨)) import Partial.Unsafe (unsafePartial) import Trace (AppTrace(..), Trace(..), VarDef(..)) as T -import Trace (AppTrace, ForeignTrace'(..), Match(..), Trace) +import Trace (AppTrace, ForeignTrace(..), ForeignTrace'(..), Match(..), Trace) import Util (type (×), Endo, absurd, definitely', error, nonEmpty, successful, (!), (×)) import Util.Pair (zip) as P import Val (Fun(..), Val(..)) as V -import Val (class Ann, DictRep(..), Env, ForeignOp, ForeignOp'(..), MatrixRep(..), Val, append_inv, (<+>)) +import Val (class Ann, DictRep(..), Env, ForeignOp(..), ForeignOp'(..), MatrixRep(..), Val, append_inv, (<+>)) closeDefsBwd :: forall a. Ann a => Env a -> Env a × RecDefs a × a closeDefsBwd γ = @@ -79,14 +79,14 @@ applyBwd (T.AppClosure xs w t3 × v) = γ1 × γ2 = append_inv xs γ1γ2 γ1' × δ' × β' = closeDefsBwd γ2 v' × σ = matchBwd γ3 (ContExpr e) β w -applyBwd (T.AppForeign n t × v) = +applyBwd (T.AppForeign n (ForeignTrace (id × t)) × v) = V.Fun α (V.Foreign φ vs'') × v2' where φ × α × { init: vs'', last: v2' } = second (second (definitely' <<< unsnoc)) $ runExists applyBwd' t where applyBwd' :: forall t. ForeignTrace' t -> ForeignOp × a × List (Val _) applyBwd' (ForeignTrace' (ForeignOp' φ) t') = - mkExists (ForeignOp' φ) × + ForeignOp (id × mkExists (ForeignOp' φ)) × if φ.arity > n then unsafePartial $ let V.Fun α (V.Foreign _ vs'') = v in α × vs'' else bot × φ.op_bwd (definitely' t' × v) applyBwd (T.AppConstr c × v) = diff --git a/src/EvalGraph.purs b/src/EvalGraph.purs index 299b89dcd..f10afd1e9 100644 --- a/src/EvalGraph.purs +++ b/src/EvalGraph.purs @@ -34,7 +34,7 @@ import Pretty (prettyP) import Primitive (string, intPair) import Util (type (×), check, concatM, error, orElse, successful, throw, with, (×)) import Util.Pair (unzip) as P -import Val (DictRep(..), Env, ForeignOp'(..), MatrixRep(..), Val, for, lookup', restrict, (<+>)) +import Val (DictRep(..), Env, ForeignOp(..), ForeignOp'(..), MatrixRep(..), Val, for, lookup', restrict, (<+>)) import Val (Fun(..), Val(..)) as V type GraphConfig g = @@ -92,7 +92,7 @@ apply (V.Fun α (V.Closure γ1 ρ σ)) v = do γ2 <- closeDefs γ1 ρ (singleton α) γ3 × κ × αs <- match v σ eval (γ1 <+> γ2 <+> γ3) (asExpr κ) (insert α αs) -apply (V.Fun α (V.Foreign φ vs)) v = +apply (V.Fun α (V.Foreign (ForeignOp (id × φ)) vs)) v = runExists apply' φ where vs' = snoc vs v @@ -100,7 +100,7 @@ apply (V.Fun α (V.Foreign φ vs)) v = apply' :: forall t. ForeignOp' t -> m (Val Vertex) apply' (ForeignOp' φ') = if φ'.arity > length vs' then - V.Fun <$> new (singleton α) <@> V.Foreign φ vs' + V.Fun <$> new (singleton α) <@> V.Foreign (ForeignOp (id × φ)) vs' else φ'.op' vs' apply (V.Fun α (V.PartialConstr c vs)) v = do check (length vs < n) ("Too many arguments to " <> showCtr c) diff --git a/src/Expr.purs b/src/Expr.purs index a7fce7fee..618efe377 100644 --- a/src/Expr.purs +++ b/src/Expr.purs @@ -121,6 +121,80 @@ instance BV (Cont a) where bv (ContElim σ) = bv σ bv (ContExpr _) = empty +instance JoinSemilattice a => JoinSemilattice (Elim a) where + maybeJoin (ElimVar x κ) (ElimVar x' κ') = ElimVar <$> (x ≞ x') <*> maybeJoin κ κ' + maybeJoin (ElimConstr cκs) (ElimConstr cκs') = + ElimConstr <$> ((keys cκs `consistentWith` keys cκs') *> maybeJoin cκs cκs') + maybeJoin (ElimRecord xs κ) (ElimRecord ys κ') = ElimRecord <$> (xs ≞ ys) <*> maybeJoin κ κ' + maybeJoin _ _ = throw "Incompatible eliminators" + + join σ = definedJoin σ + +instance BoundedJoinSemilattice a => Expandable (Elim a) (Raw Elim) where + expand (ElimVar x κ) (ElimVar x' κ') = ElimVar (x ≜ x') (expand κ κ') + expand (ElimConstr cκs) (ElimConstr cκs') = ElimConstr (expand cκs cκs') + expand (ElimRecord xs κ) (ElimRecord ys κ') = ElimRecord (xs ≜ ys) (expand κ κ') + expand _ _ = error "Incompatible eliminators" + +instance JoinSemilattice a => JoinSemilattice (Cont a) where + maybeJoin ContNone ContNone = pure ContNone + maybeJoin (ContExpr e) (ContExpr e') = ContExpr <$> maybeJoin e e' + maybeJoin (ContElim σ) (ContElim σ') = ContElim <$> maybeJoin σ σ' + maybeJoin _ _ = throw "Incompatible continuations" + + join κ = definedJoin κ + +instance BoundedJoinSemilattice a => Expandable (Cont a) (Raw Cont) where + expand ContNone ContNone = ContNone + expand (ContExpr e) (ContExpr e') = ContExpr (expand e e') + expand (ContElim σ) (ContElim σ') = ContElim (expand σ σ') + expand _ _ = error "Incompatible continuations" + +instance JoinSemilattice a => JoinSemilattice (VarDef a) where + join def = definedJoin def + maybeJoin (VarDef σ e) (VarDef σ' e') = VarDef <$> maybeJoin σ σ' <*> maybeJoin e e' + +instance BoundedJoinSemilattice a => Expandable (VarDef a) (Raw VarDef) where + expand (VarDef σ e) (VarDef σ' e') = VarDef (expand σ σ') (expand e e') + +instance JoinSemilattice a => JoinSemilattice (Expr a) where + maybeJoin (Var x) (Var x') = Var <$> (x ≞ x') + maybeJoin (Op op) (Op op') = Op <$> (op ≞ op') + maybeJoin (Int α n) (Int α' n') = Int (α ∨ α') <$> (n ≞ n') + maybeJoin (Str α str) (Str α' str') = Str (α ∨ α') <$> (str ≞ str') + maybeJoin (Float α n) (Float α' n') = Float (α ∨ α') <$> (n ≞ n') + maybeJoin (Record α xes) (Record α' xes') = Record (α ∨ α') <$> maybeJoin xes xes' + maybeJoin (Dictionary α ees) (Dictionary α' ees') = Dictionary (α ∨ α') <$> maybeJoin ees ees' + maybeJoin (Constr α c es) (Constr α' c' es') = Constr (α ∨ α') <$> (c ≞ c') <*> maybeJoin es es' + maybeJoin (Matrix α e1 (x × y) e2) (Matrix α' e1' (x' × y') e2') = + Matrix (α ∨ α') <$> maybeJoin e1 e1' <*> ((x ≞ x') `lift2 (×)` (y ≞ y')) <*> maybeJoin e2 e2' + maybeJoin (Lambda α σ) (Lambda α' σ') = Lambda (α ∨ α') <$> maybeJoin σ σ' + maybeJoin (Project e x) (Project e' x') = Project <$> maybeJoin e e' <*> (x ≞ x') + maybeJoin (App e1 e2) (App e1' e2') = App <$> maybeJoin e1 e1' <*> maybeJoin e2 e2' + maybeJoin (Let def e) (Let def' e') = Let <$> maybeJoin def def' <*> maybeJoin e e' + maybeJoin (LetRec α ρ e) (LetRec α' ρ' e') = LetRec (α ∨ α') <$> maybeJoin ρ ρ' <*> maybeJoin e e' + maybeJoin _ _ = throw "Incompatible expressions" + + join e = definedJoin e + +instance BoundedJoinSemilattice a => Expandable (Expr a) (Raw Expr) where + expand (Var x) (Var x') = Var (x ≜ x') + expand (Op op) (Op op') = Op (op ≜ op') + expand (Int α n) (Int _ n') = Int α (n ≜ n') + expand (Str α str) (Str _ str') = Str α (str ≜ str') + expand (Float α n) (Float _ n') = Float α (n ≜ n') + expand (Record α xes) (Record _ xes') = Record α (expand xes xes') + expand (Dictionary α ees) (Dictionary _ ees') = Dictionary α (expand ees ees') + expand (Constr α c es) (Constr _ c' es') = Constr α (c ≜ c') (expand es es') + expand (Matrix α e1 (x × y) e2) (Matrix _ e1' (x' × y') e2') = + Matrix α (expand e1 e1') ((x ≜ x') × (y ≜ y')) (expand e2 e2') + expand (Lambda α σ) (Lambda _ σ') = Lambda α (expand σ σ') + expand (Project e x) (Project e' x') = Project (expand e e') (x ≜ x') + expand (App e1 e2) (App e1' e2') = App (expand e1 e1') (expand e2 e2') + expand (Let def e) (Let def' e') = Let (expand def def') (expand e e') + expand (LetRec α ρ e) (LetRec _ ρ' e') = LetRec α (expand ρ ρ') (expand e e') + expand _ _ = error "Incompatible expressions" + -- ====================== -- boilerplate -- ====================== @@ -221,76 +295,7 @@ instance Foldable ProgCxt where foldr f = foldrDefault f foldMap f = foldMapDefaultL f -instance JoinSemilattice a => JoinSemilattice (Elim a) where - maybeJoin (ElimVar x κ) (ElimVar x' κ') = ElimVar <$> (x ≞ x') <*> maybeJoin κ κ' - maybeJoin (ElimConstr cκs) (ElimConstr cκs') = - ElimConstr <$> ((keys cκs `consistentWith` keys cκs') *> maybeJoin cκs cκs') - maybeJoin (ElimRecord xs κ) (ElimRecord ys κ') = ElimRecord <$> (xs ≞ ys) <*> maybeJoin κ κ' - maybeJoin _ _ = throw "Incompatible eliminators" - - join σ = definedJoin σ - -instance BoundedJoinSemilattice a => Expandable (Elim a) (Raw Elim) where - expand (ElimVar x κ) (ElimVar x' κ') = ElimVar (x ≜ x') (expand κ κ') - expand (ElimConstr cκs) (ElimConstr cκs') = ElimConstr (expand cκs cκs') - expand (ElimRecord xs κ) (ElimRecord ys κ') = ElimRecord (xs ≜ ys) (expand κ κ') - expand _ _ = error "Incompatible eliminators" - -instance JoinSemilattice a => JoinSemilattice (Cont a) where - maybeJoin ContNone ContNone = pure ContNone - maybeJoin (ContExpr e) (ContExpr e') = ContExpr <$> maybeJoin e e' - maybeJoin (ContElim σ) (ContElim σ') = ContElim <$> maybeJoin σ σ' - maybeJoin _ _ = throw "Incompatible continuations" - - join κ = definedJoin κ - -instance BoundedJoinSemilattice a => Expandable (Cont a) (Raw Cont) where - expand ContNone ContNone = ContNone - expand (ContExpr e) (ContExpr e') = ContExpr (expand e e') - expand (ContElim σ) (ContElim σ') = ContElim (expand σ σ') - expand _ _ = error "Incompatible continuations" - -instance JoinSemilattice a => JoinSemilattice (VarDef a) where - join def = definedJoin def - maybeJoin (VarDef σ e) (VarDef σ' e') = VarDef <$> maybeJoin σ σ' <*> maybeJoin e e' - -instance BoundedJoinSemilattice a => Expandable (VarDef a) (Raw VarDef) where - expand (VarDef σ e) (VarDef σ' e') = VarDef (expand σ σ') (expand e e') - -instance JoinSemilattice a => JoinSemilattice (Expr a) where - maybeJoin (Var x) (Var x') = Var <$> (x ≞ x') - maybeJoin (Op op) (Op op') = Op <$> (op ≞ op') - maybeJoin (Int α n) (Int α' n') = Int (α ∨ α') <$> (n ≞ n') - maybeJoin (Str α str) (Str α' str') = Str (α ∨ α') <$> (str ≞ str') - maybeJoin (Float α n) (Float α' n') = Float (α ∨ α') <$> (n ≞ n') - maybeJoin (Record α xes) (Record α' xes') = Record (α ∨ α') <$> maybeJoin xes xes' - maybeJoin (Dictionary α ees) (Dictionary α' ees') = Dictionary (α ∨ α') <$> maybeJoin ees ees' - maybeJoin (Constr α c es) (Constr α' c' es') = Constr (α ∨ α') <$> (c ≞ c') <*> maybeJoin es es' - maybeJoin (Matrix α e1 (x × y) e2) (Matrix α' e1' (x' × y') e2') = - Matrix (α ∨ α') <$> maybeJoin e1 e1' <*> ((x ≞ x') `lift2 (×)` (y ≞ y')) <*> maybeJoin e2 e2' - maybeJoin (Lambda α σ) (Lambda α' σ') = Lambda (α ∨ α') <$> maybeJoin σ σ' - maybeJoin (Project e x) (Project e' x') = Project <$> maybeJoin e e' <*> (x ≞ x') - maybeJoin (App e1 e2) (App e1' e2') = App <$> maybeJoin e1 e1' <*> maybeJoin e2 e2' - maybeJoin (Let def e) (Let def' e') = Let <$> maybeJoin def def' <*> maybeJoin e e' - maybeJoin (LetRec α ρ e) (LetRec α' ρ' e') = LetRec (α ∨ α') <$> maybeJoin ρ ρ' <*> maybeJoin e e' - maybeJoin _ _ = throw "Incompatible expressions" - - join e = definedJoin e - -instance BoundedJoinSemilattice a => Expandable (Expr a) (Raw Expr) where - expand (Var x) (Var x') = Var (x ≜ x') - expand (Op op) (Op op') = Op (op ≜ op') - expand (Int α n) (Int _ n') = Int α (n ≜ n') - expand (Str α str) (Str _ str') = Str α (str ≜ str') - expand (Float α n) (Float _ n') = Float α (n ≜ n') - expand (Record α xes) (Record _ xes') = Record α (expand xes xes') - expand (Dictionary α ees) (Dictionary _ ees') = Dictionary α (expand ees ees') - expand (Constr α c es) (Constr _ c' es') = Constr α (c ≜ c') (expand es es') - expand (Matrix α e1 (x × y) e2) (Matrix _ e1' (x' × y') e2') = - Matrix α (expand e1 e1') ((x ≜ x') × (y ≜ y')) (expand e2 e2') - expand (Lambda α σ) (Lambda _ σ') = Lambda α (expand σ σ') - expand (Project e x) (Project e' x') = Project (expand e e') (x ≜ x') - expand (App e1 e2) (App e1' e2') = App (expand e1 e1') (expand e2 e2') - expand (Let def e) (Let def' e') = Let (expand def def') (expand e e') - expand (LetRec α ρ e) (LetRec _ ρ' e') = LetRec α (expand ρ ρ') (expand e e') - expand _ _ = error "Incompatible expressions" +derive instance Ord a => Ord (Expr a) +derive instance Ord a => Ord (Elim a) +derive instance Ord a => Ord (Cont a) +derive instance Ord a => Ord (VarDef a) diff --git a/src/Pretty.purs b/src/Pretty.purs index eee631f80..e596f3b27 100644 --- a/src/Pretty.purs +++ b/src/Pretty.purs @@ -4,7 +4,6 @@ import Prelude hiding (absurd, between) import Bindings (Bind, key, val, Var, (↦)) import Data.Array (foldl) -import Data.Exists (runExists) import Data.Foldable (class Foldable) import Data.List (List(..), fromFoldable, null, uncons, (:)) import Data.List.NonEmpty (NonEmptyList, groupBy, singleton, toList) @@ -29,7 +28,7 @@ import Util (type (+), type (×), Endo, absurd, assert, error, intersperse, (×) import Util.Pair (Pair(..), toTuple) import Util.Pretty (Doc(..), atop, beside, empty, hcat, render, text) import Val (Fun(..), Val(..)) as V -import Val (class Ann, class Highlightable, DictRep(..), ForeignOp', Fun, MatrixRep(..), Val, highlightIf) +import Val (class Ann, class Highlightable, DictRep(..), ForeignOp(..), Fun, MatrixRep(..), Val, highlightIf) class Pretty p where pretty :: p -> Doc @@ -434,11 +433,11 @@ instance Highlightable a => Pretty (Val a) where instance Highlightable a => Pretty (a × Fun a) where pretty (α × V.Closure _ _ _) = (highlightIf α $ text "") - pretty (_ × V.Foreign φ _) = runExists pretty φ + pretty (_ × V.Foreign φ _) = pretty φ pretty (α × V.PartialConstr c vs) = prettyConstr α c vs -instance Pretty (ForeignOp' t) where - pretty _ = text "" -- TODO +instance Pretty ForeignOp where + pretty (ForeignOp (s × _)) = text s instance (Pretty a, Pretty b) => Pretty (a + b) where pretty = pretty ||| pretty diff --git a/src/Primitive.purs b/src/Primitive.purs index 5a88d2045..cc16e0645 100644 --- a/src/Primitive.purs +++ b/src/Primitive.purs @@ -1,9 +1,9 @@ module Primitive where import Prelude hiding (absurd, apply, div, top) - +import Bindings (Bind) import Data.Either (Either(..)) -import Data.Exists (mkExists) +import Data.Exists (Exists, mkExists) import Data.Int (toNumber) import Data.List (List(..), (:)) import Data.Profunctor.Choice ((|||)) @@ -15,7 +15,7 @@ import Lattice (class BoundedJoinSemilattice, Raw, (∧), bot, erase) import Partial.Unsafe (unsafePartial) import Pretty (prettyP) import Util (type (+), type (×), (×), error) -import Val (class Ann, ForeignOp'(..), Fun(..), MatrixRep, OpBwd, OpFwd, OpGraph, Val(..)) +import Val (class Ann, ForeignOp(..), ForeignOp'(..), Fun(..), MatrixRep, OpBwd, OpFwd, OpGraph, Val(..)) -- Mediate between values of annotation type a and (potential) underlying datatype d, analogous to -- pattern-matching and construction for data types. Wasn't able to make a typeclass version of this @@ -153,59 +153,65 @@ type BinaryZero i o a = , fwd :: i -> i -> o } -unary :: forall i o a'. BoundedJoinSemilattice a' => (forall a. Unary i o a) -> Val a' -unary op = - Fun bot $ flip Foreign Nil - $ mkExists - $ ForeignOp' { arity: 1, op': unsafePartial op', op: unsafePartial fwd, op_bwd: unsafePartial bwd } +unary :: forall i o a'. BoundedJoinSemilattice a' => String -> (forall a. Unary i o a) -> Bind (Val a') +unary id f = + id × Fun bot (Foreign (ForeignOp (id × op)) Nil) where + op :: Exists ForeignOp' + op = mkExists $ + ForeignOp' { arity: 1, op': unsafePartial op', op: unsafePartial fwd, op_bwd: unsafePartial bwd } + op' :: Partial => OpGraph op' (v : Nil) = - op.o.pack <$> ((op.fwd x × _) <$> new (singleton α)) + f.o.pack <$> ((f.fwd x × _) <$> new (singleton α)) where - x × α = op.i.unpack v + x × α = f.i.unpack v fwd :: Partial => OpFwd (Raw Val) - fwd (v : Nil) = pure $ erase v × op.o.pack (op.fwd x × α) + fwd (v : Nil) = pure $ erase v × f.o.pack (f.fwd x × α) where - x × α = op.i.unpack v + x × α = f.i.unpack v bwd :: Partial => OpBwd (Raw Val) - bwd (u × v) = op.i.pack (x × α) : Nil + bwd (u × v) = f.i.pack (x × α) : Nil where - _ × α = op.o.unpack v - (x × _) = op.i.unpack u - -binary :: forall i1 i2 o a'. BoundedJoinSemilattice a' => (forall a. Binary i1 i2 o a) -> Val a' -binary op = - Fun bot $ flip Foreign Nil - $ mkExists - $ ForeignOp' { arity: 2, op': unsafePartial op', op: unsafePartial fwd, op_bwd: unsafePartial bwd } + _ × α = f.o.unpack v + (x × _) = f.i.unpack u + +binary :: forall i1 i2 o a'. BoundedJoinSemilattice a' => String -> (forall a. Binary i1 i2 o a) -> Bind (Val a') +binary id f = + id × Fun bot (Foreign (ForeignOp (id × op)) Nil) where + op :: Exists ForeignOp' + op = mkExists $ + ForeignOp' { arity: 2, op': unsafePartial op', op: unsafePartial fwd, op_bwd: unsafePartial bwd } + op' :: Partial => OpGraph op' (v1 : v2 : Nil) = - op.o.pack <$> ((op.fwd x y × _) <$> new (singleton α # insert β)) + f.o.pack <$> ((f.fwd x y × _) <$> new (singleton α # insert β)) where - (x × α) × (y × β) = op.i1.unpack v1 × op.i2.unpack v2 + (x × α) × (y × β) = f.i1.unpack v1 × f.i2.unpack v2 fwd :: Partial => OpFwd (Raw Val × Raw Val) - fwd (v1 : v2 : Nil) = pure $ (erase v1 × erase v2) × op.o.pack (op.fwd x y × (α ∧ β)) + fwd (v1 : v2 : Nil) = pure $ (erase v1 × erase v2) × f.o.pack (f.fwd x y × (α ∧ β)) where - (x × α) × (y × β) = op.i1.unpack v1 × op.i2.unpack v2 + (x × α) × (y × β) = f.i1.unpack v1 × f.i2.unpack v2 bwd :: Partial => OpBwd (Raw Val × Raw Val) - bwd ((u1 × u2) × v) = op.i1.pack (x × α) : op.i2.pack (y × α) : Nil + bwd ((u1 × u2) × v) = f.i1.pack (x × α) : f.i2.pack (y × α) : Nil where - _ × α = op.o.unpack v - (x × _) × (y × _) = op.i1.unpack u1 × op.i2.unpack u2 + _ × α = f.o.unpack v + (x × _) × (y × _) = f.i1.unpack u1 × f.i2.unpack u2 -- If both are zero, depend only on the first. -binaryZero :: forall i o a'. BoundedJoinSemilattice a' => IsZero i => (forall a. BinaryZero i o a) -> Val a' -binaryZero op = - Fun bot $ flip Foreign Nil - $ mkExists - $ ForeignOp' { arity: 2, op': unsafePartial op', op: unsafePartial fwd, op_bwd: unsafePartial bwd } +binaryZero :: forall i o a'. BoundedJoinSemilattice a' => IsZero i => String -> (forall a. BinaryZero i o a) -> Bind (Val a') +binaryZero id f = + id × Fun bot (Foreign (ForeignOp (id × op)) Nil) where + op :: Exists ForeignOp' + op = mkExists $ + ForeignOp' { arity: 2, op': unsafePartial op', op: unsafePartial fwd, op_bwd: unsafePartial bwd } + op' :: Partial => OpGraph op' (v1 : v2 : Nil) = let @@ -214,22 +220,22 @@ binaryZero op = else if isZero y then singleton β else singleton α # insert β in - op.o.pack <$> ((op.fwd x y × _) <$> new αs) + f.o.pack <$> ((f.fwd x y × _) <$> new αs) where - (x × α) × (y × β) = op.i.unpack v1 × op.i.unpack v2 + (x × α) × (y × β) = f.i.unpack v1 × f.i.unpack v2 fwd :: Partial => OpFwd (Raw Val × Raw Val) fwd (v1 : v2 : Nil) = pure $ (erase v1 × erase v2) × - op.o.pack (op.fwd x y × if isZero x then α else if isZero y then β else α ∧ β) + f.o.pack (f.fwd x y × if isZero x then α else if isZero y then β else α ∧ β) where - (x × α) × (y × β) = op.i.unpack v1 × op.i.unpack v2 + (x × α) × (y × β) = f.i.unpack v1 × f.i.unpack v2 bwd :: Partial => OpBwd (Raw Val × Raw Val) - bwd ((u1 × u2) × v) = op.i.pack (x × β1) : op.i.pack (y × β2) : Nil + bwd ((u1 × u2) × v) = f.i.pack (x × β1) : f.i.pack (y × β2) : Nil where - _ × α = op.o.unpack v - (x × _) × (y × _) = op.i.unpack u1 × op.i.unpack u2 + _ × α = f.o.unpack v + (x × _) × (y × _) = f.i.unpack u1 × f.i.unpack u2 β1 × β2 = if isZero x then α × bot else if isZero y then bot × α diff --git a/src/Primitive/Defs.purs b/src/Primitive/Defs.purs index 6a413ea04..52e0bcc3b 100644 --- a/src/Primitive/Defs.purs +++ b/src/Primitive/Defs.purs @@ -1,7 +1,7 @@ module Primitive.Defs where import Prelude hiding (absurd, apply, div, mod, top) - +import Bindings (Bind) import Data.Exists (mkExists) import Data.Foldable (foldl, foldM) import Data.FoldableWithIndex (foldWithIndexM) @@ -27,50 +27,51 @@ import Prelude (div, mod) as P import Primitive (binary, binaryZero, boolean, int, intOrNumber, intOrNumberOrString, number, string, unary, union, union1, unionStr) import Trace (AppTrace) import Util (type (+), type (×), Endo, error, orElse, throw, unimplemented, (×)) -import Val (Array2, DictRep(..), Env, ForeignOp, ForeignOp'(..), Fun(..), MatrixRep(..), OpBwd, OpFwd, OpGraph, Val(..), matrixGet, matrixPut) +import Val (Array2, DictRep(..), Env, ForeignOp(..), ForeignOp'(..), Fun(..), MatrixRep(..), OpBwd, OpFwd, OpGraph, Val(..), matrixGet, matrixPut) -extern :: forall a. BoundedJoinSemilattice a => ForeignOp -> Val a -extern = Fun bot <<< flip Foreign Nil +extern :: forall a. BoundedJoinSemilattice a => ForeignOp -> Bind (Val a) +extern (ForeignOp (id × φ)) = id × Fun bot ((Foreign (ForeignOp (id × φ))) Nil) primitives :: Raw Env primitives = D.fromFoldable [ ":" × Fun bot (PartialConstr cCons Nil) - , "ceiling" × unary { i: number, o: int, fwd: ceil } - , "debugLog" × extern debugLog - , "dims" × extern dims - , "error" × extern error_ - , "floor" × unary { i: number, o: int, fwd: floor } - , "log" × unary { i: intOrNumber, o: number, fwd: log } - , "numToStr" × unary { i: intOrNumber, o: string, fwd: numToStr } - , "+" × binary { i1: intOrNumber, i2: intOrNumber, o: intOrNumber, fwd: plus } - , "-" × binary { i1: intOrNumber, i2: intOrNumber, o: intOrNumber, fwd: minus } - , "*" × binaryZero { i: intOrNumber, o: intOrNumber, fwd: times } - , "**" × binaryZero { i: intOrNumber, o: intOrNumber, fwd: pow } - , "/" × binaryZero { i: intOrNumber, o: intOrNumber, fwd: divide } - , "==" × binary { i1: intOrNumberOrString, i2: intOrNumberOrString, o: boolean, fwd: equals } - , "/=" × binary { i1: intOrNumberOrString, i2: intOrNumberOrString, o: boolean, fwd: notEquals } - , "<" × binary { i1: intOrNumberOrString, i2: intOrNumberOrString, o: boolean, fwd: lessThan } - , ">" × binary { i1: intOrNumberOrString, i2: intOrNumberOrString, o: boolean, fwd: greaterThan } - , "<=" × binary { i1: intOrNumberOrString, i2: intOrNumberOrString, o: boolean, fwd: lessThanEquals } - , ">=" × binary { i1: intOrNumberOrString, i2: intOrNumberOrString, o: boolean, fwd: greaterThanEquals } - , "++" × binary { i1: string, i2: string, o: string, fwd: concat } - , "!" × extern matrixLookup - , "dict_difference" × extern dict_difference - , "dict_disjointUnion" × extern dict_disjointUnion - , "dict_foldl" × extern dict_foldl - , "dict_fromRecord" × extern dict_fromRecord - , "dict_get" × extern dict_get - , "dict_intersectionWith" × extern dict_intersectionWith - , "dict_map" × extern dict_map - , "div" × binaryZero { i: int, o: int, fwd: div } - , "matrixUpdate" × extern matrixUpdate - , "mod" × binaryZero { i: int, o: int, fwd: mod } - , "quot" × binaryZero { i: int, o: int, fwd: quot } - , "rem" × binaryZero { i: int, o: int, fwd: rem } + , unary "ceiling" { i: number, o: int, fwd: ceil } + , extern debugLog + , extern dims + , extern error_ + , unary "floor" { i: number, o: int, fwd: floor } + , unary "log" { i: intOrNumber, o: number, fwd: log } + , unary "numToStr" { i: intOrNumber, o: string, fwd: numToStr } + , binary "+" { i1: intOrNumber, i2: intOrNumber, o: intOrNumber, fwd: plus } + , binary "-" { i1: intOrNumber, i2: intOrNumber, o: intOrNumber, fwd: minus } + , binaryZero "*" { i: intOrNumber, o: intOrNumber, fwd: times } + , binaryZero "**" { i: intOrNumber, o: intOrNumber, fwd: pow } + , binaryZero "/" { i: intOrNumber, o: intOrNumber, fwd: divide } + , binary "==" { i1: intOrNumberOrString, i2: intOrNumberOrString, o: boolean, fwd: equals } + , binary "/=" { i1: intOrNumberOrString, i2: intOrNumberOrString, o: boolean, fwd: notEquals } + , binary "<" { i1: intOrNumberOrString, i2: intOrNumberOrString, o: boolean, fwd: lessThan } + , binary ">" { i1: intOrNumberOrString, i2: intOrNumberOrString, o: boolean, fwd: greaterThan } + , binary "<=" { i1: intOrNumberOrString, i2: intOrNumberOrString, o: boolean, fwd: lessThanEquals } + , binary ">=" { i1: intOrNumberOrString, i2: intOrNumberOrString, o: boolean, fwd: greaterThanEquals } + , binary "++" { i1: string, i2: string, o: string, fwd: concat } + , extern matrixLookup + , extern dict_difference + , extern dict_disjointUnion + , extern dict_foldl + , extern dict_fromRecord + , extern dict_get + , extern dict_intersectionWith + , extern dict_map + , binaryZero "div" { i: int, o: int, fwd: div } + , extern matrixUpdate + , binaryZero "mod" { i: int, o: int, fwd: mod } + , binaryZero "quot" { i: int, o: int, fwd: quot } + , binaryZero "rem" { i: int, o: int, fwd: rem } ] error_ :: ForeignOp -error_ = mkExists $ ForeignOp' { arity: 1, op': op', op: fwd, op_bwd: unsafePartial bwd } +error_ = + ForeignOp ("error" × mkExists (ForeignOp' { arity: 1, op': op', op: fwd, op_bwd: unsafePartial bwd })) where op' :: OpGraph op' (Str _ s : Nil) = pure $ error s @@ -84,7 +85,8 @@ error_ = mkExists $ ForeignOp' { arity: 1, op': op', op: fwd, op_bwd: unsafePart bwd _ = error unimplemented debugLog :: ForeignOp -debugLog = mkExists $ ForeignOp' { arity: 1, op': op', op: fwd, op_bwd: unsafePartial bwd } +debugLog = + ForeignOp ("debugLog" × mkExists (ForeignOp' { arity: 1, op': op', op: fwd, op_bwd: unsafePartial bwd })) where op' :: OpGraph op' (x : Nil) = pure $ trace x (const x) @@ -98,7 +100,8 @@ debugLog = mkExists $ ForeignOp' { arity: 1, op': op', op: fwd, op_bwd: unsafePa bwd _ = error unimplemented dims :: ForeignOp -dims = mkExists $ ForeignOp' { arity: 1, op': op, op: fwd, op_bwd: unsafePartial bwd } +dims = + ForeignOp ("dims" × mkExists (ForeignOp' { arity: 1, op': op, op: fwd, op_bwd: unsafePartial bwd })) where op :: OpGraph op (Matrix α (MatrixRep (_ × (i × β1) × (j × β2))) : Nil) = do @@ -117,7 +120,8 @@ dims = mkExists $ ForeignOp' { arity: 1, op': op, op: fwd, op_bwd: unsafePartial Matrix α (MatrixRep (((<$>) botOf <$> vss) × (i × β1) × (j × β2))) : Nil matrixLookup :: ForeignOp -matrixLookup = mkExists $ ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: bwd } +matrixLookup = + ForeignOp ("!" × mkExists (ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: bwd })) where op :: OpGraph op (Matrix _ r : Constr _ c (Int _ i : Int _ j : Nil) : Nil) | c == cPair = @@ -136,7 +140,8 @@ matrixLookup = mkExists $ ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: bwd } : Nil matrixUpdate :: ForeignOp -matrixUpdate = mkExists $ ForeignOp' { arity: 3, op': op, op: fwd, op_bwd: unsafePartial bwd } +matrixUpdate = + ForeignOp ("matrixUpdate" × mkExists (ForeignOp' { arity: 3, op': op, op: fwd, op_bwd: unsafePartial bwd })) where op :: OpGraph op (Matrix _ r : Constr _ c (Int _ i : Int _ j : Nil) : v : Nil) @@ -156,7 +161,8 @@ matrixUpdate = mkExists $ ForeignOp' { arity: 3, op': op, op: fwd, op_bwd: unsaf : Nil dict_difference :: ForeignOp -dict_difference = mkExists $ ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: unsafePartial bwd } +dict_difference = + ForeignOp ("dict_difference" × mkExists (ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: unsafePartial bwd })) where op :: OpGraph op (Dictionary α (DictRep d) : Dictionary β (DictRep d') : Nil) = @@ -173,7 +179,8 @@ dict_difference = mkExists $ ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: un Dictionary α d : Dictionary α (DictRep D.empty) : Nil dict_fromRecord :: ForeignOp -dict_fromRecord = mkExists $ ForeignOp' { arity: 1, op': op, op: fwd, op_bwd: unsafePartial bwd } +dict_fromRecord = + ForeignOp ("dict_fromRecord" × mkExists (ForeignOp' { arity: 1, op': op, op: fwd, op_bwd: unsafePartial bwd })) where op :: OpGraph op (Record α xvs : Nil) = do @@ -191,7 +198,8 @@ dict_fromRecord = mkExists $ ForeignOp' { arity: 1, op': op, op: fwd, op_bwd: un Record (foldl (∨) α (d <#> fst)) (d <#> snd) : Nil dict_disjointUnion :: ForeignOp -dict_disjointUnion = mkExists $ ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: unsafePartial bwd } +dict_disjointUnion = + ForeignOp ("dict_disjointUnion" × mkExists (ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: unsafePartial bwd })) where op :: OpGraph op (Dictionary α (DictRep d) : Dictionary β (DictRep d') : Nil) = do @@ -208,7 +216,8 @@ dict_disjointUnion = mkExists $ ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: Dictionary α (DictRep (d'' \\ d')) : Dictionary α (DictRep (d'' \\ d)) : Nil dict_foldl :: ForeignOp -dict_foldl = mkExists $ ForeignOp' { arity: 3, op': op, op: fwd, op_bwd: unsafePartial bwd } +dict_foldl = + ForeignOp ("dict_foldl" × mkExists (ForeignOp' { arity: 3, op': op, op: fwd, op_bwd: unsafePartial bwd })) where op :: OpGraph op (v : u : Dictionary _ (DictRep d) : Nil) = @@ -237,7 +246,8 @@ dict_foldl = mkExists $ ForeignOp' { arity: 3, op': op, op: fwd, op_bwd: unsafeP ts dict_get :: ForeignOp -dict_get = mkExists $ ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: unsafePartial bwd } +dict_get = + ForeignOp ("dict_get" × mkExists (ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: unsafePartial bwd })) where op :: OpGraph op (Str _ s : Dictionary _ (DictRep d) : Nil) = @@ -254,7 +264,8 @@ dict_get = mkExists $ ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: unsafePar Str bot s : Dictionary bot (DictRep $ D.singleton s (bot × v)) : Nil dict_intersectionWith :: ForeignOp -dict_intersectionWith = mkExists $ ForeignOp' { arity: 3, op': op, op: fwd, op_bwd: unsafePartial bwd } +dict_intersectionWith = + ForeignOp ("dict_intersectionWith" × mkExists (ForeignOp' { arity: 3, op': op, op: fwd, op_bwd: unsafePartial bwd })) where op :: OpGraph op (v : Dictionary α (DictRep d1) : Dictionary α' (DictRep d2) : Nil) = @@ -287,7 +298,8 @@ dict_intersectionWith = mkExists $ ForeignOp' { arity: 3, op': op, op: fwd, op_b :: Dict (_ × Val _ × Val _ × Val _) dict_map :: ForeignOp -dict_map = mkExists $ ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: unsafePartial bwd } +dict_map = + ForeignOp ("dict_map" × mkExists (ForeignOp' { arity: 2, op': op, op: fwd, op_bwd: unsafePartial bwd })) where op :: OpGraph op (v : Dictionary α (DictRep d) : Nil) = do diff --git a/src/Trace.purs b/src/Trace.purs index e762b210b..4b5aadf26 100644 --- a/src/Trace.purs +++ b/src/Trace.purs @@ -34,7 +34,7 @@ data AppTrace | AppConstr Ctr data ForeignTrace' t = ForeignTrace' (ForeignOp' t) (Maybe t) -type ForeignTrace = Exists ForeignTrace' +newtype ForeignTrace = ForeignTrace (String × Exists ForeignTrace') data VarDef = VarDef Match Trace diff --git a/src/Util/Pair.purs b/src/Util/Pair.purs index b42f08fa2..f026aefe3 100644 --- a/src/Util/Pair.purs +++ b/src/Util/Pair.purs @@ -11,8 +11,26 @@ import Util (type (×), (×)) -- a |-> a × a can't derive a functor instance, so use this data Pair a = Pair a a -instance Eq a => Eq (Pair a) where - eq (Pair x1 y1) (Pair x2 y2) = (eq x1 x2) && (eq y1 y2) +instance Show a => Show (Pair a) where + show (Pair x y) = "(Pair " <> show x <> " " <> show y <> ")" + +toTuple :: forall a. Pair a -> a × a +toTuple (Pair x y) = x × y + +fromTuple :: forall a. a × a -> Pair a +fromTuple (x × y) = Pair x y + +zip :: forall a. List a -> List a -> List (Pair a) +zip xs ys = L.zip xs ys <#> fromTuple + +unzip :: forall a. List (Pair a) -> List a × List a +unzip xys = xys <#> toTuple # L.unzip + +-- ====================== +-- boilerplate +-- ====================== +derive instance Eq a => Eq (Pair a) +derive instance Ord a => Ord (Pair a) instance Functor Pair where map f (Pair x y) = Pair (f x) (f y) @@ -31,18 +49,3 @@ instance Foldable Pair where instance Traversable Pair where traverse f (Pair x y) = Pair <$> f x <*> f y sequence = sequenceDefault - -instance Show a => Show (Pair a) where - show (Pair x y) = "(Pair " <> show x <> " " <> show y <> ")" - -toTuple :: forall a. Pair a -> a × a -toTuple (Pair x y) = x × y - -fromTuple :: forall a. a × a -> Pair a -fromTuple (x × y) = Pair x y - -zip :: forall a. List a -> List a -> List (Pair a) -zip xs ys = L.zip xs ys <#> fromTuple - -unzip :: forall a. List (Pair a) -> List a × List a -unzip xys = xys <#> toTuple # L.unzip diff --git a/src/Val.purs b/src/Val.purs index 29bc9fea2..1d7618f09 100644 --- a/src/Val.purs +++ b/src/Val.purs @@ -63,7 +63,13 @@ data ForeignOp' t = ForeignOp' , op_bwd :: OpBwd t } -type ForeignOp = Exists ForeignOp' +newtype ForeignOp = ForeignOp (String × Exists ForeignOp') -- string is unique identifier (for Eq) + +instance Eq ForeignOp where + eq (ForeignOp (s × _)) (ForeignOp (s' × _)) = s == s' + +instance Ord ForeignOp where + compare (ForeignOp (s × _)) (ForeignOp (s' × _)) = compare s s' -- Environments. type Env a = Dict (Val a) @@ -254,3 +260,13 @@ instance BoundedJoinSemilattice a => Expandable (Fun a) (Raw Fun) where instance Neg a => Neg (Val a) where neg = (<$>) neg + +derive instance Eq a => Eq (Val a) +derive instance Eq a => Eq (DictRep a) +derive instance Eq a => Eq (MatrixRep a) +derive instance Eq a => Eq (Fun a) + +derive instance Ord a => Ord (Val a) +derive instance Ord a => Ord (DictRep a) +derive instance Ord a => Ord (MatrixRep a) +derive instance Ord a => Ord (Fun a) diff --git a/test/Main.purs b/test/Main.purs index 77054e736..d4655dd3e 100644 --- a/test/Main.purs +++ b/test/Main.purs @@ -17,7 +17,6 @@ main :: Effect Unit main = run tests tests :: Array (String × Aff Unit) -{- tests = concat [ test_desugaring , test_misc @@ -25,9 +24,10 @@ tests = concat , test_graphics , test_linking ] --} +{- tests = concat [ test_scratchpad ] +-} test_scratchpad :: Array (String × Aff Unit) test_scratchpad = second void <$> bwdMany diff --git a/test/Util.purs b/test/Util.purs index 540e43281..3a4deb5fb 100644 --- a/test/Util.purs +++ b/test/Util.purs @@ -109,8 +109,9 @@ testTrace s γα spec@{ δv } = do validate method spec s𝔹 v𝔹' - let γ𝔹_top × e𝔹_top × _ = eval.bwd (topOf v) - v𝔹_top' = eval.fwd (γ𝔹_top × e𝔹_top × top) + let + γ𝔹_top × e𝔹_top × _ = eval.bwd (topOf v) + v𝔹_top' = eval.fwd (γ𝔹_top × e𝔹_top × top) PrettyShow v𝔹_top' `shouldSatisfy "fwd ⚬ bwd round-tripping property"` (const true) testGraph :: forall m. MonadWriter BenchRow m => Raw SE.Expr -> GraphConfig GraphImpl -> TestConfig -> Boolean -> AffError m Unit