Skip to content

Commit

Permalink
Merge pull request #814 from explorable-viz/matrix-update-bug
Browse files Browse the repository at this point in the history
Fix `evalBwd` bug in matrix comprehensions
  • Loading branch information
rolyp authored Oct 26, 2023
2 parents cfc7ad1 + af765cd commit e2a70e2
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 18 deletions.
2 changes: 1 addition & 1 deletion fluid/example/slicing/dtw/compute-dtw.expect.fld
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
let d x y = (x - y) * (x - y) in
let seq1 = ⸨[⸩ ⸨3⸩ ⸨,⸩ ⸨1⸩ ⸨,⸩ 2 ⸨,⸩ 2 ⸨,⸩ 1 ⸨]⸩;
seq2 = [⸨2⸩, ⸨0⸩, ⸨0⸩, 3, 3, 1, 0];
seq2 = ⸨[⸩ ⸨2⸩ ⸨,⸩ ⸨0⸩ ⸨,⸩ ⸨0⸩ ⸨,⸩ 3 ⸨,⸩ 3 ⸨,⸩ 1 ⸨,⸩ 0 ⸨]⸩;
window = ⸨2⸩;
(costs, matched) = computeDTW seq1 seq2 d window in
matched
2 changes: 1 addition & 1 deletion src/EvalBwd.purs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ evalBwd' (V.Matrix α (MatrixRep (vss × (_ × βi) × (_ × βj)))) (T.Matrix t
in
unsafePartial $
let
V.Int β _ × V.Int β' _ = get x γ0 × get x γ0
V.Int β _ × V.Int β' _ = get x γ0 × get y γ0
in
γ × e × α' × β × β'
γ × e × α' × β × β' = foldl1
Expand Down
12 changes: 6 additions & 6 deletions src/Primitive/Defs.purs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import Eval (apply, apply2)
import EvalBwd (apply2Bwd, applyBwd)
import EvalGraph (apply) as G
import Graph.GraphWriter (new)
import Lattice (class BoundedJoinSemilattice, Raw, bot, botOf, erase, top, (∧), (∨))
import Lattice (class BoundedJoinSemilattice, Raw, bot, botOf, erase, (∧), (∨))
import Partial.Unsafe (unsafePartial)
import Prelude (div, mod) as P
import Primitive (binary, binaryZero, boolean, int, intOrNumber, intOrNumberOrString, number, string, unary, union, union1, unionStr)
Expand Down Expand Up @@ -149,15 +149,15 @@ matrixUpdate =
op _ = throw "Matrix, pair of integers and value expected"

fwd :: OpFwd ((Int × Int) × Raw Val)
fwd (Matrix _ r : Constr _ c (Int _ i : Int _ j : Nil) : v : Nil) | c == cPair =
pure $ ((i × j) × erase (matrixGet i j r)) × Matrix top (matrixPut i j (const v) r)
fwd (Matrix α r : Constr _ c (Int _ i : Int _ j : Nil) : v : Nil) | c == cPair =
pure $ ((i × j) × erase (matrixGet i j r)) × Matrix α (matrixPut i j (const v) r)
fwd _ = throw "Matrix, pair of integers and value expected"

bwd :: Partial => OpBwd ((Int × Int) × Raw Val)
bwd ((((i × j) × v) × Matrix _ r')) =
Matrix bot (matrixPut i j (const (botOf v)) r')
bwd ((((i × j) × v) × Matrix α r)) =
Matrix α (matrixPut i j (const (botOf v)) r)
: Constr bot cPair (Int bot i : Int bot j : Nil)
: matrixGet i j r'
: matrixGet i j r
: Nil

dict_difference :: ForeignOp
Expand Down
10 changes: 5 additions & 5 deletions test/Specs.purs
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,11 @@ bwd_cases =
\3, 11, 15, 2, 9"
, δv: matrixElement 2 2 neg
}
-- , { file: "dtw/compute-dtw"
-- , bwd_expect_file: "dtw/compute-dtw.expect"
-- , fwd_expect: "((1, 1) : ((2, 2) : ((2, 3) : ((3, 4) : ((4, 5) : ((5, 6) : ((5, 7) : [])))))))"
-- , δv: listElement 1 neg
-- }
, { file: "dtw/compute-dtw"
, bwd_expect_file: "dtw/compute-dtw.expect"
, fwd_expect: "((1, 1) : (⸨(⸨2⸩, ⸨2⸩)⸩ : ((2, 3) : ((3, 4) : ((4, 5) : ((5, 6) : ((5, 7) : [])))))))"
, δv: listElement 1 neg
}
]

graphics_cases :: Array TestWithDatasetSpec
Expand Down
14 changes: 9 additions & 5 deletions test/Util.purs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ type AffError m a = MonadAff m => MonadError Error m => m a
type EffectError m a = MonadEffect m => MonadError Error m => m a

logging :: Boolean
logging = true
logging = false

logAs :: forall m. MonadEffect m => String -> String -> m Unit
logAs tag s = log $ tag <> ": " <> s
Expand Down Expand Up @@ -96,8 +96,8 @@ testTrace s γα spec@{ δv } = do
γ = erase <$> γα
benchmark (method <> "-Eval") $ \_ -> traceGC γ e

let v𝔹 = δv (botOf v)
γ𝔹 × e𝔹 × _ <- do
let v𝔹 = δv (botOf v)
unless (isGraphical v𝔹) $
when logging (logAs "Selection for bwd" (prettyP v𝔹))
benchmark (method <> "-Bwd") $ \_ -> pure (eval.bwd v𝔹)
Expand All @@ -106,13 +106,17 @@ testTrace s γα spec@{ δv } = do
let s𝔹 = desug𝔹.bwd e𝔹
v𝔹' <- do
let e𝔹' = desug𝔹.fwd s𝔹
PrettyShow e𝔹' `shouldSatisfy "fwd ⚬ bwd round-trip (desugar)"` (unwrap >>> (_ >= e𝔹))
benchmark (method <> "-Fwd") $ \_ -> pure (eval.fwd (γ𝔹 × e𝔹' × top))
PrettyShow v𝔹' `shouldSatisfy "fwd ⚬ bwd round-trip (eval)"` (unwrap >>> (_ >= v𝔹))

let
v𝔹_top = topOf v
γ𝔹_top × e𝔹_top × _ = eval.bwd v𝔹_top
v𝔹_top' = eval.fwd (γ𝔹_top × e𝔹_top × top)
PrettyShow v𝔹_top' `shouldSatisfy "fwd ⚬ bwd round-tripping property"` (unwrap >>> (_ >= v𝔹_top))
s𝔹_top = desug𝔹.bwd e𝔹_top
e𝔹_top' = desug𝔹.fwd s𝔹_top
v𝔹_top' = eval.fwd (γ𝔹_top × e𝔹_top' × top)
PrettyShow v𝔹_top' `shouldSatisfy "fwd ⚬ bwd round-trip (eval ⚬ desugar)"` (unwrap >>> (_ >= v𝔹_top))

validate method spec s𝔹 v𝔹'

Expand All @@ -134,7 +138,7 @@ testGraph s gconfig spec@{ δv } benchmarking = do
let v𝔹' = select𝔹s vα αs_out'

validate method spec (desug𝔹.bwd e𝔹) v𝔹'
αs_out `shouldSatisfy "fwd ⚬ bwd round-tripping property"` (flip subset αs_out')
αs_out `shouldSatisfy "fwd ⚬ bwd round-trip"` (flip subset αs_out')
recordGraphSize g

when benchmarking do
Expand Down

0 comments on commit e2a70e2

Please sign in to comment.