diff --git a/primer/src/Primer/Eval/Redex.hs b/primer/src/Primer/Eval/Redex.hs index 76039cf5f..95cd4d3b4 100644 --- a/primer/src/Primer/Eval/Redex.hs +++ b/primer/src/Primer/Eval/Redex.hs @@ -661,10 +661,26 @@ viewRedex opts tydefs globals dir = \case InlineGlobal{gvar, def, orig} orig@(viewLets -> Just (bindings, expr)) | opts.pushMulti || null (NonEmpty.tail bindings) + , letBinders <- foldMap' (S.singleton . letBindingName . snd) bindings , S.disjoint (getBoundHereDn expr) - (foldMap' (S.singleton . letBindingName . snd) bindings <> setOf (folded % _2 % _freeVarsLetBinding) bindings) -> + (letBinders <> setOf (folded % _2 % _freeVarsLetBinding) bindings) + , -- prefer to elide if possible + allLetsUsed (fmap snd bindings) expr -> pure $ PushLet{bindings, expr, orig} + where + -- Fold right-to-left calculating free var set and whether each + -- binder has been referenced + allLetsUsed ls b = + snd $ + foldr + ( \l (fvs, allUsed) -> + let n = letBindingName l + rhs = setOf _freeVarsLetBinding l + in (S.delete n fvs `S.union` rhs, allUsed && n `S.member` fvs) + ) + (freeVars b, True) + ls orig@(Let _ var rhs body) | Var _ (LocalVarRef var') <- body , var == var' -> @@ -800,10 +816,26 @@ viewRedexType opts = \case | Just (bindingsWithID, intoTy) <- viewLetsTy origTy , opts.pushMulti || null (NonEmpty.tail bindingsWithID) , (bindings, bindingIDs) <- NonEmpty.unzip bindingsWithID + , letBinders <- foldMap' (S.singleton . letTypeBindingName) bindings + , -- prefer to elide if possible + allLetsUsed bindings intoTy , S.disjoint (S.map unLocalName $ getBoundHereDnTy intoTy) - (foldMap' (S.singleton . letTypeBindingName) bindings <> setOf (folded % _freeVarsLetTypeBinding) bindings) -> + (letBinders <> setOf (folded % _freeVarsLetTypeBinding) bindings) -> purer $ PushLetType{bindings, intoTy, origTy, bindingIDs} + where + -- Fold right-to-left calculating free var set and whether each + -- binder has been referenced + allLetsUsed ls b = + snd $ + foldr + ( \l (fvs, allUsed) -> + let n = letTypeBindingName l + rhs = setOf _freeVarsLetTypeBinding l + in (S.delete n fvs `S.union` rhs, allUsed && n `S.member` fvs) + ) + (S.map unLocalName $ freeVarsTy b, True) + ls orig@(TLet _ v s body) | TVar _ var <- body , v == var -> diff --git a/primer/test/Tests/EvalFull.hs b/primer/test/Tests/EvalFull.hs index ea5f73e33..a73632e46 100644 --- a/primer/test/Tests/EvalFull.hs +++ b/primer/test/Tests/EvalFull.hs @@ -90,6 +90,10 @@ import Primer.Primitives.DSL (pfun) import Primer.Test.App ( runAppTestM, ) +import Primer.Test.Expected ( + Expected (defMap, expectedResult, expr, maxID), + mapEven, + ) import Primer.Test.TestM ( evalTestM, ) @@ -201,17 +205,16 @@ unit_7 = s <- evalFullTest maxID mempty mempty 100 Syn e s <~==> Right e --- Temporarily disabled for performance reasons --- unit_8 :: Assertion --- unit_8 = --- let n = 10 --- e = mapEven n --- in do --- evalFullTest (maxID e) builtinTypes (defMap e) 1000 Syn (expr e) >>= \case --- Left (TimedOut _) -> pure () --- x -> assertFailure $ show x --- s <- evalFullTest (maxID e) builtinTypes (defMap e) 2000 Syn (expr e) --- s <~==> Right (expectedResult e) +unit_8 :: Assertion +unit_8 = + let n = 10 + e = mapEven n + in do + evalFullTest (maxID e) builtinTypes (defMap e) 1000 Syn (expr e) >>= \case + Left (TimedOut _) -> pure () + x -> assertFailure $ show x + s <- evalFullTest (maxID e) builtinTypes (defMap e) 2000 Syn (expr e) + s <~==> Right (expectedResult e) -- A worker/wrapper'd map unit_9 :: Assertion @@ -722,25 +725,23 @@ unit_type_preservation_BETA_regression = lam "x" $ let_ "c" (lvar "x" `ann` tcon tNat) (letType "a" (tvar "b" `tapp` tcon tBool) (emptyHole `ann` tvar "a")) `ann` (tvar "b" `tapp` tcon tBool) - -- Push the lets - -- Λb. λx. (((let c = (x : Nat) in (lettype a = b Bool in ?)) : (lettype a = b Bool in a)) : (b Bool)) + -- Elide a let + -- Λb. λx. ((lettype a = b Bool in (? : a)) : (b Bool)) expectA7 <- lAM "b" $ lam "x" $ - ( let_ "c" (lvar "x" `ann` tcon tNat) (letType "a" (tvar "b" `tapp` tcon tBool) emptyHole) - `ann` tlet "a" (tvar "b" `tapp` tcon tBool) (tvar "a") - ) + letType "a" (tvar "b" `tapp` tcon tBool) (emptyHole `ann` tvar "a") `ann` (tvar "b" `tapp` tcon tBool) - -- Inline a let - -- Λb. λx. (((let c = (x : Nat) in (lettype a = b Bool in ?)) : (b Bool)) : (b Bool)) + -- Push the let + -- Λb. λx. (((lettype a = b Bool in ?) : (lettype a = b Bool in a)) : (b Bool)) expectA8 <- lAM "b" $ lam "x" $ - ( let_ "c" (lvar "x" `ann` tcon tNat) (letType "a" (tvar "b" `tapp` tcon tBool) emptyHole) - `ann` (tvar "b" `tapp` tcon tBool) + ( letType "a" (tvar "b" `tapp` tcon tBool) emptyHole + `ann` tlet "a" (tvar "b" `tapp` tcon tBool) (tvar "a") ) `ann` (tvar "b" `tapp` tcon tBool) - -- Elide a pointless let + -- Inline a let -- Λb. λx. (((lettype a = b Bool in ?) : (b Bool)) : (b Bool)) expectA9 <- lAM "b" $ @@ -772,9 +773,9 @@ unit_type_preservation_BETA_regression = lAM "b" $ letType "a" (tcon tChar) (gvar foo `aPP` (tvar "b" `tapp` tcon tBool)) `ann` tlet "b" (tcon tChar) (tcon tNat) - -- Drop annotation, push lettype to leaves and then elide all lettypes + -- Drop annotation and elide lettype -- Λb. foo @(b Bool) - expectB7 <- lAM "b" $ gvar foo `aPP` (tvar "b" `tapp` tcon tBool) + expectB3 <- lAM "b" $ gvar foo `aPP` (tvar "b" `tapp` tcon tBool) -- Note that the reduction of eA and eB take slightly -- different paths: we do not remove the annotation in eA -- because it has an occurrence of a type variable and is thus @@ -796,7 +797,7 @@ unit_type_preservation_BETA_regression = , expectA10 ] ) - , (eB, [(1, expectB1), (7, expectB7)]) + , (eB, [(1, expectB1), (3, expectB3)]) ) sA n = evalFullTest maxID builtinTypes mempty n Chk exprA sB n = evalFullTest maxID builtinTypes mempty n Chk exprB @@ -1457,7 +1458,7 @@ unit_prim_partial_map = -- and then in two steps (expand @go@, push stack of let+letrec) -- @λxs. let α=Char, β=Char, f=toUpper in (letrec go : List α -> List β; go = λxs.RHS in RHS : List α -> List β) -- we carry around the subst for α,β and f, using α,β inside annotation and f in RHS each time expand the letrec - s <- evalFullTest maxID builtinTypes (gs <> prims) 203 Syn e + s <- evalFullTest maxID builtinTypes (gs <> prims) 169 Syn e s <~==> Right r -- Test that handleEvalFullRequest will reduce imported terms