From 5260e638a60e7d7c4ab1a8430b03e1d2361920fd Mon Sep 17 00:00:00 2001 From: Felix Klein Date: Fri, 15 Nov 2024 13:09:30 +0100 Subject: [PATCH] Use less unsafeCoerce to define the promoted nats --- clash-prelude/src/Clash/Promoted/Nat.hs | 59 +++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/clash-prelude/src/Clash/Promoted/Nat.hs b/clash-prelude/src/Clash/Promoted/Nat.hs index a13f99b272..f9c36d7564 100644 --- a/clash-prelude/src/Clash/Promoted/Nat.hs +++ b/clash-prelude/src/Clash/Promoted/Nat.hs @@ -67,13 +67,26 @@ module Clash.Promoted.Nat -- * Constraints on natural numbers , leToPlus , leToPlusKN + , leZeroToEqZero ) where +import Data.Constraint (Dict(..), (:-)(..)) +import Data.Constraint.Nat (euclideanNat) import Data.Kind (Type) +import Data.Type.Equality ((:~:)(..)) +#if MIN_VERSION_base(4,16,0) +import Data.Type.Ord (OrderingI(..)) +#else +import Unsafe.Coerce (unsafeCoerce) +#endif import GHC.Show (appPrec) import GHC.TypeLits (KnownNat, Nat, type (+), type (-), type (*), - type (^), type (<=), natVal) + type (^), type (<=), natVal, +#if MIN_VERSION_base(4,16,0) + cmpNat, +#endif + sameNat) import GHC.TypeLits.Extra (CLog, FLog, Div, Log, Mod, Min, Max) import GHC.Natural (naturalFromInteger) import Language.Haskell.TH (appT, conT, litT, numTyLit, sigE) @@ -82,7 +95,6 @@ import Language.Haskell.TH.Syntax (Lift (..)) import Language.Haskell.TH.Compat #endif import Numeric.Natural (Natural) -import Unsafe.Coerce (unsafeCoerce) import Clash.Annotations.Primitive (hasBlackBox) import Clash.XException (ShowX (..), showsPrecXWith) @@ -184,11 +196,20 @@ instance KnownNat n => ShowX (UNat n) where -- -- __NB__: Not synthesizable toUNat :: forall n . SNat n -> UNat n +#if MIN_VERSION_base(4,16,0) +toUNat p@SNat = case cmpNat (SNat @1) p of + LTI -> USucc (toUNat @(n - 1) (predSNat p)) + EQI -> USucc UZero + GTI -> case sameNat p (SNat @0) of + Just Refl -> UZero -- + _ -> error "toUNat: impossible: 1 > n and n /= 0 for (n :: Nat)" +#else toUNat p@SNat = fromI @n (snatToInteger p) where fromI :: forall m . Integer -> UNat m fromI 0 = unsafeCoerce @(UNat 0) @(UNat m) UZero fromI n = unsafeCoerce @(UNat ((m-1)+1)) @(UNat m) (USucc (fromI @(m-1) (n - 1))) +#endif -- | Convert a unary-encoded natural number to its singleton representation -- @@ -338,10 +359,20 @@ deriving instance Show (SNatLE a b) -- | Get an ordering relation between two SNats compareSNat :: forall a b . SNat a -> SNat b -> SNatLE a b +#if MIN_VERSION_base(4,16,0) +compareSNat a@SNat b@SNat = case cmpNat a b of + LTI -> SNatLE + EQI -> SNatLE + GTI -> case cmpNat (succSNat b) a of + LTI -> SNatGT + EQI -> SNatGT + GTI -> error "compareSNat: impossible: a > b and b + 1 > a" +#else compareSNat a b = if snatToInteger a <= snatToInteger b then unsafeCoerce (SNatLE @0 @0) else unsafeCoerce (SNatGT @1 @0) +#endif -- | Base-2 encoded natural number -- @@ -404,7 +435,20 @@ showBNat = go [] -- | Convert a singleton natural number to its base-2 representation -- -- __NB__: Not synthesizable -toBNat :: SNat n -> BNat n +toBNat :: forall n. SNat n -> BNat n +#if MIN_VERSION_base(4,16,0) +toBNat s@SNat = case cmpNat (SNat @1) s of + LTI -> case euclideanNat @2 @n of + Sub Dict -> case sameNat (SNat @(n `Mod` 2)) (SNat @0) of + Just Refl -> B0 (toBNat (SNat @(n `Div` 2))) + Nothing -> case sameNat (SNat @(n `Mod` 2)) (SNat @1) of + Just Refl -> B1 (toBNat (SNat @(n `Div` 2))) + Nothing -> error "toBNat: impossible: n mod 2 is either 0 or 1" + EQI -> B1 BT + GTI -> case sameNat s (SNat @0) of + Just Refl -> BT + _ -> error "toBNat: impossible: 1 > n and n /= 0" +#else toBNat s@SNat = toBNat' (snatToInteger s) where toBNat' :: forall m . Integer -> BNat m @@ -412,6 +456,7 @@ toBNat s@SNat = toBNat' (snatToInteger s) toBNat' n = case n `divMod` 2 of (n',1) -> unsafeCoerce (B1 (toBNat' @(Div (m-1) 2) n')) (n',_) -> unsafeCoerce (B0 (toBNat' @(Div m 2) n')) +#endif -- | Convert a base-2 encoded natural number to its singleton representation -- @@ -560,3 +605,11 @@ leToPlusKN -> r leToPlusKN r = r @(n - k) {-# INLINE leToPlusKN #-} + +-- | Change a function that has an argument with an @(n <= 0)@ constraint to a +-- function with an argument that has an @(n ~ 0)@ constraint. +leZeroToEqZero :: forall n r. (KnownNat n, n <= 0) => (n ~ 0 => r) -> r +leZeroToEqZero x = case sameNat (SNat @n) (SNat @0) of + Just Refl -> x + _ -> error "leZeroToEqZero: impossible: n <= 0 implies n ~ 0 for (n :: Nat)" +{-# INLINE leZeroToEqZero #-}