Skip to content

Commit

Permalink
Give evaluator access to inscope let-bindings
Browse files Browse the repository at this point in the history
Without it, Clash goes into an infinite loop on T1354A in
combination with:

clash-lang/ghc-typelits-knownnat#47
  • Loading branch information
christiaanb committed Aug 28, 2023
1 parent 1366748 commit 8470b9a
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 15 deletions.
5 changes: 3 additions & 2 deletions clash-lib/src/Clash/Core/Evaluator/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,20 @@ import Clash.Pretty (ClashPretty(..), fromPretty, showDoc)
whnf'
:: Evaluator
-> BindingMap
-> VarEnv Term
-> TyConMap
-> PrimHeap
-> Supply
-> InScopeSet
-> Bool
-> Term
-> (PrimHeap, PureHeap, Term)
whnf' eval bm tcm ph ids is isSubj e =
whnf' eval bm lh tcm ph ids is isSubj e =
toResult $ whnf eval tcm isSubj m
where
toResult x = (mHeapPrim x, mHeapLocal x, mTerm x)

m = Machine ph gh emptyVarEnv [] ids is e
m = Machine ph gh lh [] ids is e
gh = mapVarEnv bindingTerm bm

-- | Evaluate to WHNF given an existing Heap and Stack
Expand Down
2 changes: 1 addition & 1 deletion clash-lib/src/Clash/Core/Pretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ pprPrecCast prec e ty1 ty2 = do
pprPrecLetrec :: Monad m => Rational -> Bool -> [(Id, Term)] -> Term -> m ClashDoc
pprPrecLetrec prec isRec xes body = do
let bndrs = fst <$> xes
body' <- annotate (AnnContext $ LetBody bndrs) <$> pprPrec noPrec body
body' <- annotate (AnnContext $ LetBody xes) <$> pprPrec noPrec body
xes' <- mapM (\(x,e) -> do
x' <- pprBndr LetBind x
e' <- pprPrec noPrec e
Expand Down
4 changes: 2 additions & 2 deletions clash-lib/src/Clash/Core/Term.hs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ data CoreContext
-- ^ Function position of a type application
| LetBinding Id [Id]
-- ^ RHS of a Let-binder with the sibling LHS'
| LetBody [Id]
| LetBody [LetBinding]
-- ^ Body of a Let-binding with the bound LHS'
| LamBody Id
-- ^ Body of a lambda-term with the abstracted variable
Expand Down Expand Up @@ -303,7 +303,7 @@ instance Eq CoreContext where
-- NB: we do not see inside the argument here
(TyAppC, TyAppC) -> True
(LetBinding i is, LetBinding i' is') -> i == i' && is == is'
(LetBody is, LetBody is') -> is == is'
(LetBody is, LetBody is') -> map fst is == map fst is'
(LamBody i, LamBody i') -> i == i'
(TyLamBody tv, TyLamBody tv') -> tv == tv'
(CaseAlt p, CaseAlt p') -> p == p'
Expand Down
8 changes: 8 additions & 0 deletions clash-lib/src/Clash/Core/VarEnv.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ module Clash.Core.VarEnv
, delVarEnvList
, unionVarEnv
, unionVarEnvWith
, differenceVarEnv
-- ** Element-wise operations
-- *** Mapping
, mapVarEnv
Expand Down Expand Up @@ -227,6 +228,13 @@ unionVarEnvWith
-> VarEnv a
unionVarEnvWith = UniqMap.unionWith

-- | Filter the first varenv to only contain keys which are not in the second varenv.
differenceVarEnv
:: VarEnv a
-> VarEnv a
-> VarEnv a
differenceVarEnv = UniqMap.difference

-- | Create an environment given a list of var-value pairs
mkVarEnv
:: [(Var a,b)]
Expand Down
2 changes: 1 addition & 1 deletion clash-lib/src/Clash/Normalize/Transformations/DEC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ collectGlobals' is0 substitution seen e@(collectArgsTicks -> (fun, args@(_:_), t
let (ids1,ids2) = splitSupply ids
uniqSupply Lens..= ids2
gh <- Lens.use globalHeap
let eval = (Lens.view Lens._3) . whnf' evaluate bndrs tcm gh ids1 is0 False
let eval = (Lens.view Lens._3) . whnf' evaluate bndrs mempty tcm gh ids1 is0 False
let eTy = inferCoreTypeOf tcm e
untran <- isUntranslatableType False eTy
case untran of
Expand Down
4 changes: 2 additions & 2 deletions clash-lib/src/Clash/Normalize/Transformations/EtaExpand.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ etaExpansionTL (TransformContext is0 ctx) (Lam bndr e) = do
return $ Lam bndr e'

etaExpansionTL (TransformContext is0 ctx) (Let (NonRec i x) e) = do
let ctx' = TransformContext (extendInScopeSet is0 i) (LetBody [i] : ctx)
let ctx' = TransformContext (extendInScopeSet is0 i) (LetBody [(i,x)] : ctx)
e' <- etaExpansionTL ctx' e
case stripLambda e' of
(bs@(_:_),e2) -> do
Expand All @@ -81,7 +81,7 @@ etaExpansionTL (TransformContext is0 ctx) (Let (NonRec i x) e) = do

etaExpansionTL (TransformContext is0 ctx) (Let (Rec xes) e) = do
let bndrs = map fst xes
ctx' = TransformContext (extendInScopeSetList is0 bndrs) (LetBody bndrs : ctx)
ctx' = TransformContext (extendInScopeSetList is0 bndrs) (LetBody xes : ctx)
e' <- etaExpansionTL ctx' e
case stripLambda e' of
(bs@(_:_),e2) -> do
Expand Down
2 changes: 1 addition & 1 deletion clash-lib/src/Clash/Rewrite/Combinators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ allR trans (TransformContext is c) (Cast e ty1 ty2) =

allR trans (TransformContext is c) (Letrec xes e) = do
xes' <- traverse rewriteBind xes
e' <- trans (TransformContext is' (LetBody bndrs:c)) e
e' <- trans (TransformContext is' (LetBody xes:c)) e
return (Letrec xes' e')
where
bndrs = map fst xes
Expand Down
19 changes: 14 additions & 5 deletions clash-lib/src/Clash/Rewrite/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ import Clash.Core.Var
import Clash.Core.VarEnv
(InScopeSet, extendInScopeSet, extendInScopeSetList, mkInScopeSet,
uniqAway, uniqAway', mapVarEnv, eltsVarEnv, unitVarSet, emptyVarEnv,
mkVarEnv, eltsVarSet, elemVarEnv, lookupVarEnv, extendVarEnv, elemVarSet)
mkVarEnv, eltsVarSet, elemVarEnv, lookupVarEnv, extendVarEnv, elemVarSet,
differenceVarEnv)
import Clash.Data.UniqMap (UniqMap)
import qualified Clash.Data.UniqMap as UniqMap
import Clash.Debug
Expand Down Expand Up @@ -730,19 +731,27 @@ whnfRW
-> Term
-> Rewrite extra
-> RewriteMonad extra Term
whnfRW isSubj ctx@(TransformContext is0 _) e rw = do
whnfRW isSubj ctx@(TransformContext is0 hist) e rw = do
tcm <- Lens.view tcCache
bndrs <- Lens.use bindings
eval <- Lens.view evaluator
ids <- Lens.use uniqSupply
let (ids1,ids2) = splitSupply ids
uniqSupply Lens..= ids2
gh <- Lens.use globalHeap
let lh = localBinders mempty hist

case whnf' eval bndrs tcm gh ids1 is0 isSubj e of
case whnf' eval bndrs lh tcm gh ids1 is0 isSubj e of
(!gh1,ph,v) -> do
globalHeap Lens..= gh1
bindPureHeap tcm ph rw ctx v
bindPureHeap tcm (ph `differenceVarEnv` lh) rw ctx v
where
localBinders acc [] = acc
localBinders !acc (h:hs) = case h of
-- LetBinding _ ls -> localBinders (acc <> mkVarEnv ls) hs
LetBody ls -> localBinders (acc <> mkVarEnv ls) hs
_ -> localBinders acc hs

{-# SCC whnfRW #-}

-- | Binds variables on the PureHeap over the result of the rewrite
Expand Down Expand Up @@ -791,7 +800,7 @@ bindPureHeap tcm heap rw ctx0@(TransformContext is0 hist) e = do
where
heapIds = map fst bndrs
is1 = extendInScopeSetList is0 heapIds
ctx = TransformContext is1 (LetBody heapIds : hist)
ctx = TransformContext is1 (LetBody bndrs : hist)

bndrs = map toLetBinding $ UniqMap.toList heap

Expand Down
2 changes: 1 addition & 1 deletion clash-term/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ instance Diff Term where
(Letrec bnds body, LetBinding i' _) ->
Letrec (mapBindings i' bnds) body
(Letrec bnds t, LetBody is) ->
if (fst <$> bnds) == is
if (fst <$> bnds) == (fst <$> is)
then Letrec bnds (go t)
else error "Ctx.LetBody: different bindings"
(Lam i t, LamBody i') ->
Expand Down

0 comments on commit 8470b9a

Please sign in to comment.