Skip to content

Commit

Permalink
Use less unsafeCoerce to define the promoted nats
Browse files Browse the repository at this point in the history
  • Loading branch information
kleinreact committed Nov 15, 2024
1 parent 10f26ff commit 5260e63
Showing 1 changed file with 56 additions and 3 deletions.
59 changes: 56 additions & 3 deletions clash-prelude/src/Clash/Promoted/Nat.hs
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,26 @@ module Clash.Promoted.Nat
-- * Constraints on natural numbers
, leToPlus
, leToPlusKN
, leZeroToEqZero
)
where

import Data.Constraint (Dict(..), (:-)(..))
import Data.Constraint.Nat (euclideanNat)
import Data.Kind (Type)
import Data.Type.Equality ((:~:)(..))
#if MIN_VERSION_base(4,16,0)
import Data.Type.Ord (OrderingI(..))
#else
import Unsafe.Coerce (unsafeCoerce)
#endif
import GHC.Show (appPrec)
import GHC.TypeLits (KnownNat, Nat, type (+), type (-), type (*),
type (^), type (<=), natVal)
type (^), type (<=), natVal,
#if MIN_VERSION_base(4,16,0)
cmpNat,
#endif
sameNat)
import GHC.TypeLits.Extra (CLog, FLog, Div, Log, Mod, Min, Max)
import GHC.Natural (naturalFromInteger)
import Language.Haskell.TH (appT, conT, litT, numTyLit, sigE)
Expand All @@ -82,7 +95,6 @@ import Language.Haskell.TH.Syntax (Lift (..))
import Language.Haskell.TH.Compat
#endif
import Numeric.Natural (Natural)
import Unsafe.Coerce (unsafeCoerce)

import Clash.Annotations.Primitive (hasBlackBox)
import Clash.XException (ShowX (..), showsPrecXWith)
Expand Down Expand Up @@ -184,11 +196,20 @@ instance KnownNat n => ShowX (UNat n) where
--
-- __NB__: Not synthesizable
toUNat :: forall n . SNat n -> UNat n
#if MIN_VERSION_base(4,16,0)
toUNat p@SNat = case cmpNat (SNat @1) p of
LTI -> USucc (toUNat @(n - 1) (predSNat p))
EQI -> USucc UZero
GTI -> case sameNat p (SNat @0) of
Just Refl -> UZero --
_ -> error "toUNat: impossible: 1 > n and n /= 0 for (n :: Nat)"
#else
toUNat p@SNat = fromI @n (snatToInteger p)
where
fromI :: forall m . Integer -> UNat m
fromI 0 = unsafeCoerce @(UNat 0) @(UNat m) UZero
fromI n = unsafeCoerce @(UNat ((m-1)+1)) @(UNat m) (USucc (fromI @(m-1) (n - 1)))
#endif

-- | Convert a unary-encoded natural number to its singleton representation
--
Expand Down Expand Up @@ -338,10 +359,20 @@ deriving instance Show (SNatLE a b)

-- | Get an ordering relation between two SNats
compareSNat :: forall a b . SNat a -> SNat b -> SNatLE a b
#if MIN_VERSION_base(4,16,0)
compareSNat a@SNat b@SNat = case cmpNat a b of
LTI -> SNatLE
EQI -> SNatLE
GTI -> case cmpNat (succSNat b) a of
LTI -> SNatGT
EQI -> SNatGT
GTI -> error "compareSNat: impossible: a > b and b + 1 > a"
#else
compareSNat a b =
if snatToInteger a <= snatToInteger b
then unsafeCoerce (SNatLE @0 @0)
else unsafeCoerce (SNatGT @1 @0)
#endif

-- | Base-2 encoded natural number
--
Expand Down Expand Up @@ -404,14 +435,28 @@ showBNat = go []
-- | Convert a singleton natural number to its base-2 representation
--
-- __NB__: Not synthesizable
toBNat :: SNat n -> BNat n
toBNat :: forall n. SNat n -> BNat n
#if MIN_VERSION_base(4,16,0)
toBNat s@SNat = case cmpNat (SNat @1) s of
LTI -> case euclideanNat @2 @n of
Sub Dict -> case sameNat (SNat @(n `Mod` 2)) (SNat @0) of
Just Refl -> B0 (toBNat (SNat @(n `Div` 2)))
Nothing -> case sameNat (SNat @(n `Mod` 2)) (SNat @1) of
Just Refl -> B1 (toBNat (SNat @(n `Div` 2)))
Nothing -> error "toBNat: impossible: n mod 2 is either 0 or 1"
EQI -> B1 BT
GTI -> case sameNat s (SNat @0) of
Just Refl -> BT
_ -> error "toBNat: impossible: 1 > n and n /= 0"
#else
toBNat s@SNat = toBNat' (snatToInteger s)
where
toBNat' :: forall m . Integer -> BNat m
toBNat' 0 = unsafeCoerce BT
toBNat' n = case n `divMod` 2 of
(n',1) -> unsafeCoerce (B1 (toBNat' @(Div (m-1) 2) n'))
(n',_) -> unsafeCoerce (B0 (toBNat' @(Div m 2) n'))
#endif

-- | Convert a base-2 encoded natural number to its singleton representation
--
Expand Down Expand Up @@ -560,3 +605,11 @@ leToPlusKN
-> r
leToPlusKN r = r @(n - k)
{-# INLINE leToPlusKN #-}

-- | Change a function that has an argument with an @(n <= 0)@ constraint to a
-- function with an argument that has an @(n ~ 0)@ constraint.
leZeroToEqZero :: forall n r. (KnownNat n, n <= 0) => (n ~ 0 => r) -> r
leZeroToEqZero x = case sameNat (SNat @n) (SNat @0) of
Just Refl -> x
_ -> error "leZeroToEqZero: impossible: n <= 0 implies n ~ 0 for (n :: Nat)"
{-# INLINE leZeroToEqZero #-}

0 comments on commit 5260e63

Please sign in to comment.