Skip to content

Commit

Permalink
feat: eval elides lets as soon as possible
Browse files Browse the repository at this point in the history
Before doing this, we would sometimes build up large redundant
substitutions, and have to push these down to the leaves before eliding
them. This leads to larger terms and extra work.

This enables us to reenable unit_8, as it gives a major performance
increase. It also noticably speeds up unit_9.

Signed-off-by: Ben Price <[email protected]>
  • Loading branch information
brprice committed Aug 10, 2023
1 parent 04a69ff commit 936eb84
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 27 deletions.
36 changes: 34 additions & 2 deletions primer/src/Primer/Eval/Redex.hs
Original file line number Diff line number Diff line change
Expand Up @@ -661,10 +661,26 @@ viewRedex opts tydefs globals dir = \case
InlineGlobal{gvar, def, orig}
orig@(viewLets -> Just (bindings, expr))
| opts.pushMulti || null (NonEmpty.tail bindings)
, letBinders <- foldMap' (S.singleton . letBindingName . snd) bindings
, S.disjoint
(getBoundHereDn expr)
(foldMap' (S.singleton . letBindingName . snd) bindings <> setOf (folded % _2 % _freeVarsLetBinding) bindings) ->
(letBinders <> setOf (folded % _2 % _freeVarsLetBinding) bindings)
, -- prefer to elide if possible
allLetsUsed (fmap snd bindings) expr ->
pure $ PushLet{bindings, expr, orig}
where
-- Fold right-to-left calculating free var set and whether each
-- binder has been referenced
allLetsUsed ls b =
snd $
foldr
( \l (fvs, allUsed) ->
let n = letBindingName l
rhs = setOf _freeVarsLetBinding l
in (S.delete n fvs `S.union` rhs, allUsed && n `S.member` fvs)
)
(freeVars b, True)
ls
orig@(Let _ var rhs body)
| Var _ (LocalVarRef var') <- body
, var == var' ->
Expand Down Expand Up @@ -800,10 +816,26 @@ viewRedexType opts = \case
| Just (bindingsWithID, intoTy) <- viewLetsTy origTy
, opts.pushMulti || null (NonEmpty.tail bindingsWithID)
, (bindings, bindingIDs) <- NonEmpty.unzip bindingsWithID
, letBinders <- foldMap' (S.singleton . letTypeBindingName) bindings
, -- prefer to elide if possible
allLetsUsed bindings intoTy
, S.disjoint
(S.map unLocalName $ getBoundHereDnTy intoTy)
(foldMap' (S.singleton . letTypeBindingName) bindings <> setOf (folded % _freeVarsLetTypeBinding) bindings) ->
(letBinders <> setOf (folded % _freeVarsLetTypeBinding) bindings) ->
purer $ PushLetType{bindings, intoTy, origTy, bindingIDs}
where
-- Fold right-to-left calculating free var set and whether each
-- binder has been referenced
allLetsUsed ls b =
snd $
foldr
( \l (fvs, allUsed) ->
let n = letTypeBindingName l
rhs = setOf _freeVarsLetTypeBinding l
in (S.delete n fvs `S.union` rhs, allUsed && n `S.member` fvs)
)
(S.map unLocalName $ freeVarsTy b, True)
ls
orig@(TLet _ v s body)
| TVar _ var <- body
, v == var ->
Expand Down
51 changes: 26 additions & 25 deletions primer/test/Tests/EvalFull.hs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ import Primer.Primitives.DSL (pfun)
import Primer.Test.App (
runAppTestM,
)
import Primer.Test.Expected (
Expected (defMap, expectedResult, expr, maxID),
mapEven,
)
import Primer.Test.TestM (
evalTestM,
)
Expand Down Expand Up @@ -201,17 +205,16 @@ unit_7 =
s <- evalFullTest maxID mempty mempty 100 Syn e
s <~==> Right e

-- Temporarily disabled for performance reasons
-- unit_8 :: Assertion
-- unit_8 =
-- let n = 10
-- e = mapEven n
-- in do
-- evalFullTest (maxID e) builtinTypes (defMap e) 1000 Syn (expr e) >>= \case
-- Left (TimedOut _) -> pure ()
-- x -> assertFailure $ show x
-- s <- evalFullTest (maxID e) builtinTypes (defMap e) 2000 Syn (expr e)
-- s <~==> Right (expectedResult e)
unit_8 :: Assertion
unit_8 =
let n = 10
e = mapEven n
in do
evalFullTest (maxID e) builtinTypes (defMap e) 1000 Syn (expr e) >>= \case
Left (TimedOut _) -> pure ()
x -> assertFailure $ show x
s <- evalFullTest (maxID e) builtinTypes (defMap e) 2000 Syn (expr e)
s <~==> Right (expectedResult e)

-- A worker/wrapper'd map
unit_9 :: Assertion
Expand Down Expand Up @@ -722,25 +725,23 @@ unit_type_preservation_BETA_regression =
lam "x" $
let_ "c" (lvar "x" `ann` tcon tNat) (letType "a" (tvar "b" `tapp` tcon tBool) (emptyHole `ann` tvar "a"))
`ann` (tvar "b" `tapp` tcon tBool)
-- Push the lets
-- Λb. λx. (((let c = (x : Nat) in (lettype a = b Bool in ?)) : (lettype a = b Bool in a)) : (b Bool))
-- Elide a let
-- Λb. λx. ((lettype a = b Bool in (? : a)) : (b Bool))
expectA7 <-
lAM "b" $
lam "x" $
( let_ "c" (lvar "x" `ann` tcon tNat) (letType "a" (tvar "b" `tapp` tcon tBool) emptyHole)
`ann` tlet "a" (tvar "b" `tapp` tcon tBool) (tvar "a")
)
letType "a" (tvar "b" `tapp` tcon tBool) (emptyHole `ann` tvar "a")
`ann` (tvar "b" `tapp` tcon tBool)
-- Inline a let
-- Λb. λx. (((let c = (x : Nat) in (lettype a = b Bool in ?)) : (b Bool)) : (b Bool))
-- Push the let
-- Λb. λx. (((lettype a = b Bool in ?) : (lettype a = b Bool in a)) : (b Bool))
expectA8 <-
lAM "b" $
lam "x" $
( let_ "c" (lvar "x" `ann` tcon tNat) (letType "a" (tvar "b" `tapp` tcon tBool) emptyHole)
`ann` (tvar "b" `tapp` tcon tBool)
( letType "a" (tvar "b" `tapp` tcon tBool) emptyHole
`ann` tlet "a" (tvar "b" `tapp` tcon tBool) (tvar "a")
)
`ann` (tvar "b" `tapp` tcon tBool)
-- Elide a pointless let
-- Inline a let
-- Λb. λx. (((lettype a = b Bool in ?) : (b Bool)) : (b Bool))
expectA9 <-
lAM "b" $
Expand Down Expand Up @@ -772,9 +773,9 @@ unit_type_preservation_BETA_regression =
lAM "b" $
letType "a" (tcon tChar) (gvar foo `aPP` (tvar "b" `tapp` tcon tBool))
`ann` tlet "b" (tcon tChar) (tcon tNat)
-- Drop annotation, push lettype to leaves and then elide all lettypes
-- Drop annotation and elide lettype
-- Λb. foo @(b Bool)
expectB7 <- lAM "b" $ gvar foo `aPP` (tvar "b" `tapp` tcon tBool)
expectB3 <- lAM "b" $ gvar foo `aPP` (tvar "b" `tapp` tcon tBool)
-- Note that the reduction of eA and eB take slightly
-- different paths: we do not remove the annotation in eA
-- because it has an occurrence of a type variable and is thus
Expand All @@ -796,7 +797,7 @@ unit_type_preservation_BETA_regression =
, expectA10
]
)
, (eB, [(1, expectB1), (7, expectB7)])
, (eB, [(1, expectB1), (3, expectB3)])
)
sA n = evalFullTest maxID builtinTypes mempty n Chk exprA
sB n = evalFullTest maxID builtinTypes mempty n Chk exprB
Expand Down Expand Up @@ -1457,7 +1458,7 @@ unit_prim_partial_map =
-- and then in two steps (expand @go@, push stack of let+letrec)
-- @λxs. let α=Char, β=Char, f=toUpper in (letrec go : List α -> List β; go = λxs.RHS in RHS : List α -> List β)
-- we carry around the subst for α,β and f, using α,β inside annotation and f in RHS each time expand the letrec
s <- evalFullTest maxID builtinTypes (gs <> prims) 203 Syn e
s <- evalFullTest maxID builtinTypes (gs <> prims) 169 Syn e
s <~==> Right r

-- Test that handleEvalFullRequest will reduce imported terms
Expand Down

0 comments on commit 936eb84

Please sign in to comment.