Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize delayedFold to arbitrary vectors #2804

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 71 additions & 22 deletions clash-prelude/src/Clash/Explicit/Signal/Delayed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Maintainer : Christiaan Baaij <[email protected]>
{-# LANGUAGE Trustworthy #-}

{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=10 #-}
{-# OPTIONS_HADDOCK show-extensions #-}

module Clash.Explicit.Signal.Delayed
Expand All @@ -42,22 +44,25 @@ module Clash.Explicit.Signal.Delayed
)
where

import Prelude ((.), (<$>), (<*>), id, Num(..))
import Prelude ((.), ($), (<$>), id, Num(..), Maybe(..), fmap)

import Control.Applicative (liftA2)
import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import Data.Singletons (Apply, TyFun, type (@@))
import GHC.TypeLits (KnownNat, Nat, type (+), type (^), type (*))
import Data.Type.Equality ((:~:)(Refl))
import GHC.TypeLits (sameNat, Div, Mod, KnownNat, Nat, type (+), type (*), type (<=))
import GHC.TypeLits.Extra (CLog)

import Clash.Magic (clashCompileError)
import Clash.Sized.Vector
import Clash.Signal.Delayed.Internal
(DSignal(..), dfromList, dfromList_lazy, fromSignal, toSignal,
unsafeFromSignal, antiDelay, feedback, forward)
import qualified Clash.Signal.Delayed.Bundle as D

import Clash.Explicit.Signal
(KnownDomain, Clock, Domain, Reset, Signal, Enable, register, delay, bundle, unbundle)
import Clash.Promoted.Nat (SNat (..), snatToInteger)
import Clash.Promoted.Nat (SNat (..), SNatLE (..), compareSNat, snatToInteger)
import Clash.XException (NFDataX)

{- $setup
Expand Down Expand Up @@ -230,12 +235,9 @@ delayI
-> DSignal dom (n+d) a
delayI dflt = delayN (SNat :: SNat d) dflt

data DelayedFold (dom :: Domain) (n :: Nat) (delay :: Nat) (a :: Type) (f :: TyFun Nat Type) :: Type
type instance Apply (DelayedFold dom n delay a) k = DSignal dom (n + (delay*k)) a

-- | Tree fold over a 'Vec' of 'DSignal's with a combinatorial function,
-- and delaying @delay@ cycles after each application.
-- Values at times 0..(delay*k)-1 are set to a default.
-- Values at times 0..(delay * CLog 2 n)-1 are set to a default.
--
-- @
-- countingSignals :: Vec 4 (DSignal dom 0 Int)
Expand All @@ -248,11 +250,12 @@ type instance Apply (DelayedFold dom n delay a) k = DSignal dom (n + (delay*k))
-- >>> printX $ sampleN 8 (delayedFold d2 (-1) (*) enableGen systemClockGen countingSignals)
-- [-1,-1,1,1,0,1,16,81]
delayedFold
:: forall dom n delay k a
:: forall dom d delay n a
. ( NFDataX a
, KnownDomain dom
, KnownNat delay
, KnownNat k )
, KnownNat n
, 1 <= n )
=> SNat delay
-- ^ Delay applied after each step
-> a
Expand All @@ -261,14 +264,60 @@ delayedFold
-- ^ Fold operation to apply
-> Enable dom
-> Clock dom
-> Vec (2^k) (DSignal dom n a)
-- ^ Vector input of size 2^k
-> DSignal dom (n + (delay * k)) a
-- ^ Output Signal delayed by (delay * k)
delayedFold _ dflt op ena clk = dtfold (Proxy :: Proxy (DelayedFold dom n delay a)) id go
where
go :: SNat l
-> DelayedFold dom n delay a @@ l
-> DelayedFold dom n delay a @@ l
-> DelayedFold dom n delay a @@ (l+1)
go SNat x y = delayI dflt ena clk (op <$> x <*> y)
-> Vec n (DSignal dom d a)
-- ^ Vector input of size @n@
-> DSignal dom (d + delay * CLog 2 n) a
-- ^ Output Signal delayed by @delay * CLog 2 n@
delayedFold SNat initial f ena clk inps = case sameNat (SNat @1) (SNat @n) of
Just Refl -> head inps
_ -> case (modProof, strictlyPosDivRu, divMulProof) of
(SNatLE, SNatLE, Just Refl) ->
case sameNat (SNat @(1 + CLog 2 (n `Div` 2 + n `Mod` 2))) (SNat @(CLog 2 n)) of
Just Refl -> delayedFold (SNat @delay) initial f ena clk newLayer
where
newLayer = D.unbundle $
step @(n `Div` 2) @(n `Mod` 2) @d @delay (SNat @(n `Div` 2)) initial f ena clk (D.bundle inps)
_ -> clashCompileError
"delayedFold0: absurd, report this to the clash-compiler team: https://github.com/clash-lang/clash-compiler/issues"
_ -> clashCompileError
"delayedFold1: absurd, report this to the clash-compiler team: https://github.com/clash-lang/clash-compiler/issues"
where
modProof = compareSNat (SNat @(n `Mod` 2)) (SNat @1)
strictlyPosDivRu = compareSNat (SNat @1) (SNat @(n `Div` 2 + n `Mod` 2))
divMulProof = sameNat (SNat @n) (SNat @(2 * (n `Div` 2) + n `Mod` 2))

-- | A single layer of the pipelined fold
step :: forall (m :: Nat) (p :: Nat) (d :: Nat) (delay :: Nat) (dom :: Domain) (a :: Type).
KnownNat p
=> KnownNat delay
=> KnownDomain dom
=> p <= 1
=> NFDataX a
=> SNat m
-> a
-> (a -> a -> a)
-> Enable dom
-> Clock dom
-> DSignal dom d (Vec (2 * m + p) a)
-> DSignal dom (d + delay) (Vec (m + p) a)
step SNat initial f ena clk inps =
let
layerCalc :: DSignal dom d (Vec (2 * m) a) -> DSignal dom d (Vec m a)
layerCalc = fmap (map applyF . unconcatI)

applyF :: Vec 2 a -> a
applyF (a `Cons` b `Cons` _) = f a b
in
case (sameNat (SNat @p) (SNat @0), sameNat (SNat @p) (SNat @1)) of
-- Size of the input vector is even
(Just Refl, Nothing) ->
delayI (repeat initial) ena clk (layerCalc inps)
-- Size of the input vector is odd
(Nothing, Just Refl) ->
delayI (repeat initial) ena clk $
liftA2
(++)
(singleton . head <$> inps)
(layerCalc (tail <$> inps))
_ -> clashCompileError
"delayedFold step: absurd, report this to the clash-compiler team: https://github.com/clash-lang/clash-compiler/issues"
14 changes: 8 additions & 6 deletions clash-prelude/src/Clash/Signal/Delayed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ module Clash.Signal.Delayed
where

import GHC.TypeLits
(KnownNat, type (^), type (+), type (*))
(KnownNat, type (+), type (*), type (<=))
import GHC.TypeLits.Extra (CLog)

import Clash.Signal.Delayed.Internal
(DSignal(..), dfromList, dfromList_lazy, fromSignal, toSignal,
Expand Down Expand Up @@ -192,7 +193,7 @@ delayI dflt = hideClock (hideEnable (E.delayI dflt))

-- | Tree fold over a 'Vec' of 'DSignal's with a combinatorial function,
-- and delaying @delay@ cycles after each application.
-- Values at times 0..(delay*k)-1 are set to a default.
-- Values at times 0..(delay * CLog 2 n)-1 are set to a default.
--
-- @
-- countingSignals :: Vec 4 (DSignal dom 0 Int)
Expand All @@ -205,20 +206,21 @@ delayI dflt = hideClock (hideEnable (E.delayI dflt))
-- >>> printX $ sampleN @System 8 (toSignal (delayedFold d2 (-1) (*) countingSignals))
-- [-1,-1,1,1,0,1,16,81]
delayedFold
:: forall dom n delay k a
:: forall dom d delay n a
. ( HiddenClock dom
, HiddenEnable dom
, NFDataX a
, KnownNat delay
, KnownNat k )
, KnownNat n
, 1 <= n)
=> SNat delay
-- ^ Delay applied after each step
-> a
-- ^ Initial value
-> (a -> a -> a)
-- ^ Fold operation to apply
-> Vec (2^k) (DSignal dom n a)
-> Vec n (DSignal dom d a)
-- ^ Vector input of size 2^k
-> DSignal dom (n + (delay * k)) a
-> DSignal dom (d + (delay * CLog 2 n)) a
-- ^ Output Signal delayed by (delay * k)
delayedFold d dflt f = hideClock (hideEnable (E.delayedFold d dflt f))
2 changes: 1 addition & 1 deletion clash-prelude/src/Clash/Signal/Delayed/Bundle.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import GHC.TypeLits (KnownNat)
import Prelude hiding (head, map, tail)

import Clash.Signal.Internal (Domain)
import Clash.Signal.Delayed (DSignal, toSignal, unsafeFromSignal)
import Clash.Signal.Delayed.Internal (DSignal, toSignal, unsafeFromSignal)
import qualified Clash.Signal.Bundle as B

import Clash.Sized.BitVector (Bit, BitVector)
Expand Down
Loading