Skip to content

Commit

Permalink
Merge pull request #791 from explorable-viz/primitives
Browse files Browse the repository at this point in the history
Unify `constr_bwd` and `match` in Primitive
  • Loading branch information
rolyp authored Oct 7, 2023
2 parents 65d9cd4 + 7957397 commit 0aabb64
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 96 deletions.
4 changes: 2 additions & 2 deletions src/App/BarChart.purs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ foreign import drawBarChart :: Renderer BarChart

instance Reflect (Dict (Val Boolean)) BarChartRecord where
from r = BarChartRecord
{ x: string.match (get f_x r)
{ x: string.unpack (get f_x r)
, y: get_intOrNumber f_y r
}

instance Reflect (Dict (Val Boolean)) BarChart where
from r = BarChart
{ caption: string.match (get f_caption r)
{ caption: string.unpack (get f_caption r)
, data: record from <$> from (get f_data r)
}

Expand Down
2 changes: 1 addition & 1 deletion src/App/Fig.purs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ view _ (Constr _ c (u1 : Nil)) | c == cLineChart =
view title u@(Constr _ c _) | c == cNil || c == cCons =
EnergyTableView (EnergyTable { title, table: unsafePartial $ record energyRecord <$> from u })
view title u@(Matrix _ _) =
MatrixFig (MatrixView { title, matrix: matrixRep $ fst (P.matrixRep.match u) })
MatrixFig (MatrixView { title, matrix: matrixRep $ fst (P.matrixRep.unpack u) })
view _ _ = error absurd

-- An example of the form (let <defs> in expr) can be decomposed as follows.
Expand Down
4 changes: 2 additions & 2 deletions src/App/LineChart.purs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ instance Reflect (Dict (Val Boolean)) Point where

instance Reflect (Dict (Val Boolean)) LinePlot where
from r = LinePlot
{ name: string.match (get f_name r)
{ name: string.unpack (get f_name r)
, data: record from <$> from (get f_data r)
}

instance Reflect (Dict (Val Boolean)) LineChart where
from r = LineChart
{ caption: string.match (get f_caption r)
{ caption: string.unpack (get f_caption r)
, plots: from <$> (from (get f_plots r) :: Array (Val 𝔹)) :: Array LinePlot
}

Expand Down
2 changes: 1 addition & 1 deletion src/App/MatrixView.purs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ foreign import drawMatrix :: Renderer MatrixView

matrixRep :: MatrixRep 𝔹 -> IntMatrix
matrixRep (MatrixRep (vss × (i × _) × (j × _))) =
((<$>) ((<$>) (\x -> int.match x))) vss × i × j
((int.unpack <$> _) <$> vss) × i × j

matrixViewHandler :: Handler
matrixViewHandler ev = flip (uncurry matrixElement) neg $ unsafePos $ target ev
Expand Down
6 changes: 3 additions & 3 deletions src/App/TableView.purs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ foreign import drawTable :: Renderer EnergyTable

energyRecord :: Dict (Val 𝔹) -> EnergyRecord
energyRecord r =
{ year: int.match (get "year" r)
, country: string.match (get "country" r)
, energyType: string.match (get "energyType" r)
{ year: int.unpack (get "year" r)
, country: string.unpack (get "country" r)
, energyType: string.unpack (get "energyType" r)
, output: get_intOrNumber "output" r
}

Expand Down
4 changes: 2 additions & 2 deletions src/App/Util.purs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ doNothing :: OnSel
doNothing = const $ pure unit

get_intOrNumber :: Var -> Dict (Val 𝔹) -> Number × 𝔹
get_intOrNumber x r = first as (intOrNumber.match (get x r))
get_intOrNumber x r = first as (intOrNumber.unpack (get x r))

-- Assumes fields are all of primitive type.
record :: forall a. (Dict (Val 𝔹) -> a) -> Val 𝔹 -> a
record toRecord u = toRecord (fst (P.record.match u))
record toRecord u = toRecord (fst (P.record.unpack u))

class Reflect a b where
from :: Partial => a -> b
Expand Down
4 changes: 2 additions & 2 deletions src/Eval.purs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ eval γ (Record α xes) α' = do
eval γ (Dictionary α ees) α' = do
(ts × vs) × (ts' × us) <- traverse (traverse (flip (eval γ) α')) ees <#> (P.unzip >>> (unzip # both))
let
ss × αs = (vs <#> \u -> string.match u) # unzip
ss × αs = (vs <#> \u -> string.unpack u) # unzip
d = D.fromFoldable $ zip ss (zip αs us)
pure $ T.Dictionary (zip ss (zip ts ts')) (d <#> snd >>> erase) × V.Dictionary (α ∧ α') (DictRep d)
eval γ (Constr α c es) α' = do
Expand All @@ -126,7 +126,7 @@ eval γ (Constr α c es) α' = do
pure (T.Constr c ts × V.Constr (α ∧ α') c vs)
eval γ (Matrix α e (x × y) e') α' = do
t × v <- eval γ e' α'
let (i' × β) × (j' × β') = fst (intPair.match v)
let (i' × β) × (j' × β') = fst (intPair.unpack v)
check (i' × j' >= 1 × 1) ("array must be at least (" <> show (1 × 1) <> "); got (" <> show (i' × j') <> ")")
tss × vss <- unzipToArray <$> ((<$>) unzipToArray) <$>
( sequence do
Expand Down
4 changes: 2 additions & 2 deletions src/EvalGraph.purs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ eval γ (Record α xes) αs = do
eval γ (Dictionary α ees) αs = do
vs × us <- traverse (traverse (flip (eval γ) αs)) ees <#> P.unzip
let
ss × βs = (vs <#> string.match) # unzip
ss × βs = (vs <#> string.unpack) # unzip
d = D.fromFoldable $ zip ss (zip βs us)
V.Dictionary <$> new (insert α αs) <@> DictRep d
eval γ (Constr α c es) αs = do
Expand All @@ -131,7 +131,7 @@ eval γ (Constr α c es) αs = do
V.Constr <$> new (insert α αs) <@> c <@> vs
eval γ (Matrix α e (x × y) e') αs = do
v <- eval γ e' αs
let (i' × β) × (j' × β') = fst (intPair.match v)
let (i' × β) × (j' × β') = fst (intPair.unpack v)
check
(i' × j' >= 1 × 1)
("array must be at least (" <> show (1 × 1) <> "); got (" <> show (i' × j') <> ")")
Expand Down
146 changes: 65 additions & 81 deletions src/Primitive.purs
Original file line number Diff line number Diff line change
Expand Up @@ -21,121 +21,105 @@ import Val (class Ann, ForeignOp'(..), Fun(..), MatrixRep, OpBwd, OpFwd, OpGraph
-- pattern-matching and construction for data types. Wasn't able to make a typeclass version of this
-- work with the required higher-rank polymorphism.
type ToFrom d a =
{ constr :: d × a -> Val a
, constr_bwd :: Val a -> d × a -- equivalent to match (except at Val)
, match :: Val a -> d × a
{ pack :: d × a -> Val a
, unpack :: Val a -> d × a
}

typeError :: forall a b. Val a -> String -> b
typeError v typeName = error (typeName <> " expected; got " <> prettyP (erase v))

int :: forall a. ToFrom Int a
int =
{ constr: \(n × α) -> Int α n
, constr_bwd: match'
, match: match'
{ pack: \(n × α) -> Int α n
, unpack
}
where
match' :: _
match' (Int α n) = n × α
match' v = error ("Int expected; got " <> prettyP (erase v))
unpack (Int α n) = n × α
unpack v = typeError v "Int"

number :: forall a. ToFrom Number a
number =
{ constr: \(n × α) -> Float α n
, constr_bwd: match'
, match: match'
{ pack: \(n × α) -> Float α n
, unpack
}
where
match' :: _
match' (Float α n) = n × α
match' v = error ("Float expected; got " <> prettyP (erase v))
unpack (Float α n) = n × α
unpack v = typeError v "Float"

string :: forall a. ToFrom String a
string =
{ constr: \(str × α) -> Str α str
, constr_bwd: match'
, match: match'
{ pack: \(str × α) -> Str α str
, unpack
}
where
match' :: _
match' (Str α str) = str × α
match' v = error ("Str expected; got " <> prettyP (erase v))
unpack (Str α str) = str × α
unpack v = typeError v "Str"

intOrNumber :: forall a. ToFrom (Int + Number) a
intOrNumber =
{ constr: case _ of
{ pack: case _ of
Left n × α -> Int α n
Right n × α -> Float α n
, constr_bwd: match'
, match: match'
, unpack
}
where
match' :: Val a -> (Int + Number) × a
match' (Int α n) = Left n × α
match' (Float α n) = Right n × α
match' v = error ("Int or Float expected; got " <> prettyP (erase v))
unpack (Int α n) = Left n × α
unpack (Float α n) = Right n × α
unpack v = typeError v "Int or Float"

intOrNumberOrString :: forall a. ToFrom (Int + Number + String) a
intOrNumberOrString =
{ constr: case _ of
{ pack: case _ of
Left n × α -> Int α n
Right (Left n) × α -> Float α n
Right (Right str) × α -> Str α str
, constr_bwd: match'
, match: match'
, unpack
}
where
match' :: Val a -> (Int + Number + String) × a
match' (Int α n) = Left n × α
match' (Float α n) = Right (Left n) × α
match' (Str α str) = Right (Right str) × α
match' v = error ("Int, Float or Str expected; got " <> prettyP (erase v))
unpack (Int α n) = Left n × α
unpack (Float α n) = Right (Left n) × α
unpack (Str α str) = Right (Right str) × α
unpack v = typeError v "Int, Float or Str"

intPair :: forall a. ToFrom ((Int × a) × (Int × a)) a
intPair =
{ constr: \((nβ × mβ') × α) -> Constr α cPair (int.constr nβ : int.constr mβ' : Nil)
, constr_bwd: match'
, match: match'
{ pack: \((nβ × mβ') × α) -> Constr α cPair (int.pack nβ : int.pack mβ' : Nil)
, unpack
}
where
match' :: Val a -> ((Int × a) × (Int × a)) × a
match' (Constr α c (v : v' : Nil)) | c == cPair = (int.match v × int.match v') × α
match' v = error ("Pair expected; got " <> prettyP (erase v))
unpack (Constr α c (v : v' : Nil)) | c == cPair = (int.unpack v × int.unpack v') × α
unpack v = typeError v "Pair"

matrixRep :: forall a. Ann a => ToFrom (MatrixRep a) a
matrixRep =
{ constr: \(m × α) -> Matrix α m
, constr_bwd: match'
, match: match'
{ pack: \(m × α) -> Matrix α m
, unpack
}
where
match' :: Ann a => Val a -> MatrixRep a × a
match' (Matrix α m) = m × α
match' v = error ("Matrix expected; got " <> prettyP v)
unpack (Matrix α m) = m × α
unpack v = typeError v "Matrix"

record :: forall a. Ann a => ToFrom (Dict (Val a)) a
record =
{ constr: \(xvs × α) -> Record α xvs
, constr_bwd: match'
, match: match'
{ pack: \(xvs × α) -> Record α xvs
, unpack
}
where
match' :: Ann a => _
match' (Record α xvs) = xvs × α
match' v = error ("Record expected; got " <> prettyP v)
unpack (Record α xvs) = xvs × α
unpack v = typeError v "Record"

boolean :: forall a. ToFrom Boolean a
boolean =
{ constr: case _ of
{ pack: case _ of
true × α -> Constr α cTrue Nil
false × α -> Constr α cFalse Nil
, constr_bwd: match'
, match: match'
, unpack
}
where
match' :: Val a -> Boolean × a
match' (Constr α c Nil)
unpack (Constr α c Nil)
| c == cTrue = true × α
| c == cFalse = false × α
match' v = error ("Boolean expected; got " <> prettyP (erase v))
unpack v = typeError v "Boolean"

class IsZero a where
isZero :: a -> Boolean
Expand Down Expand Up @@ -177,20 +161,20 @@ unary op =
where
op' :: Partial => OpGraph
op' (v : Nil) =
op.o.constr <$> ((op.fwd x × _) <$> new (singleton α))
op.o.pack <$> ((op.fwd x × _) <$> new (singleton α))
where
x × α = op.i.match v
x × α = op.i.unpack v

fwd :: Partial => OpFwd (Raw Val)
fwd (v : Nil) = pure $ erase v × op.o.constr (op.fwd x × α)
fwd (v : Nil) = pure $ erase v × op.o.pack (op.fwd x × α)
where
x × α = op.i.match v
x × α = op.i.unpack v

bwd :: Partial => OpBwd (Raw Val)
bwd (u × v) = op.i.constr (x × α) : Nil
bwd (u × v) = op.i.pack (x × α) : Nil
where
_ × α = op.o.constr_bwd v
(x × _) = op.i.match u
_ × α = op.o.unpack v
(x × _) = op.i.unpack u

binary :: forall i1 i2 o a'. BoundedJoinSemilattice a' => (forall a. Binary i1 i2 o a) -> Val a'
binary op =
Expand All @@ -200,20 +184,20 @@ binary op =
where
op' :: Partial => OpGraph
op' (v1 : v2 : Nil) =
op.o.constr <$> ((op.fwd x y × _) <$> new (singleton α # insert β))
op.o.pack <$> ((op.fwd x y × _) <$> new (singleton α # insert β))
where
(x × α) × (y × β) = op.i1.match v1 × op.i2.match v2
(x × α) × (y × β) = op.i1.unpack v1 × op.i2.unpack v2

fwd :: Partial => OpFwd (Raw Val × Raw Val)
fwd (v1 : v2 : Nil) = pure $ (erase v1 × erase v2) × op.o.constr (op.fwd x y × (α ∧ β))
fwd (v1 : v2 : Nil) = pure $ (erase v1 × erase v2) × op.o.pack (op.fwd x y × (α ∧ β))
where
(x × α) × (y × β) = op.i1.match v1 × op.i2.match v2
(x × α) × (y × β) = op.i1.unpack v1 × op.i2.unpack v2

bwd :: Partial => OpBwd (Raw Val × Raw Val)
bwd ((u1 × u2) × v) = op.i1.constr (x × α) : op.i2.constr (y × α) : Nil
bwd ((u1 × u2) × v) = op.i1.pack (x × α) : op.i2.pack (y × α) : Nil
where
_ × α = op.o.constr_bwd v
(x × _) × (y × _) = op.i1.match u1 × op.i2.match u2
_ × α = op.o.unpack v
(x × _) × (y × _) = op.i1.unpack u1 × op.i2.unpack u2

-- If both are zero, depend only on the first.
binaryZero :: forall i o a'. BoundedJoinSemilattice a' => IsZero i => (forall a. BinaryZero i o a) -> Val a'
Expand All @@ -230,22 +214,22 @@ binaryZero op =
else if isZero y then singleton β
else singleton α # insert β
in
op.o.constr <$> ((op.fwd x y × _) <$> new αs)
op.o.pack <$> ((op.fwd x y × _) <$> new αs)
where
(x × α) × (y × β) = op.i.match v1 × op.i.match v2
(x × α) × (y × β) = op.i.unpack v1 × op.i.unpack v2

fwd :: Partial => OpFwd (Raw Val × Raw Val)
fwd (v1 : v2 : Nil) =
pure $ (erase v1 × erase v2) ×
op.o.constr (op.fwd x y × if isZero x then α else if isZero y then β else α ∧ β)
op.o.pack (op.fwd x y × if isZero x then α else if isZero y then β else α ∧ β)
where
(x × α) × (y × β) = op.i.match v1 × op.i.match v2
(x × α) × (y × β) = op.i.unpack v1 × op.i.unpack v2

bwd :: Partial => OpBwd (Raw Val × Raw Val)
bwd ((u1 × u2) × v) = op.i.constr (x × β1) : op.i.constr (y × β2) : Nil
bwd ((u1 × u2) × v) = op.i.pack (x × β1) : op.i.pack (y × β2) : Nil
where
_ × α = op.o.constr_bwd v
(x × _) × (y × _) = op.i.match u1 × op.i.match u2
_ × α = op.o.unpack v
(x × _) × (y × _) = op.i.unpack u1 × op.i.unpack u2
β1 × β2 =
if isZero x then α × bot
else if isZero y then bot × α
Expand Down

0 comments on commit 0aabb64

Please sign in to comment.