Skip to content

Commit

Permalink
select without unsafeCoerce
Browse files Browse the repository at this point in the history
  • Loading branch information
kleinreact committed Nov 21, 2024
1 parent 2d6287f commit de50812
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions clash-prelude/src/Clash/Sized/Vector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ import qualified Data.Foldable as F
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import Data.Singletons (TyFun,Apply,type (@@))
import GHC.TypeLits (CmpNat, KnownNat, Nat, type (+), type (-), type (*),
import GHC.TypeLits (KnownNat, Nat, type (+), type (-), type (*),
type (^), type (<=), natVal)
import GHC.Base (Int(I#),Int#,isTrue#)
import GHC.Generics hiding (Fixity (..))
Expand All @@ -138,7 +138,6 @@ import qualified Data.String.Interpolate as I
import qualified Prelude as P
import Test.QuickCheck
(Arbitrary(arbitrary, shrink), CoArbitrary(coarbitrary))
import Unsafe.Coerce (unsafeCoerce)

import Clash.Annotations.Primitive
(Primitive(InlineYamlPrimitive), HDL(..), dontTranslate, hasBlackBox)
Expand Down Expand Up @@ -1687,18 +1686,24 @@ at n xs = head $ snd $ splitAt n xs
-- 2 :> 4 :> 6 :> Nil
-- >>> select d1 d2 d3 (1:>2:>3:>4:>5:>6:>7:>8:>Nil)
-- 2 :> 4 :> 6 :> Nil
select :: (CmpNat (i + s) (s * n) ~ 'GT)
select :: forall i s n f a. s * n + 1 <= i + s
=> SNat f
-> SNat s
-> SNat n
-> Vec (f + i) a
-> Vec n a
select f s n xs = select' (toUNat n) $ drop f xs
where
select' :: UNat n -> Vec i a -> Vec n a
select' UZero _ = Nil
select' (USucc n') vs@(x `Cons` _) = x `Cons`
select' n' (drop s (unsafeCoerce vs))
where
select' :: forall m j b. (s * m + 1 <= j + s) => UNat m -> Vec j b -> Vec m b
select' m vs = case m of
UZero -> Nil
USucc UZero -> head @(j - 1) vs `Cons` Nil
USucc m'@(USucc _) -> case deduce @(s * (m - 1) + 1) @j Proxy Proxy of
Dict -> head @(j - 1) vs `Cons` select' m' (drop @s @(j - s) s vs)

deduce :: e + s <= k + s => p e -> p k -> Dict (e <= k)
deduce _ _ = Dict

-- See: https://github.com/clash-lang/clash-compiler/pull/2511
{-# CLASH_OPAQUE select #-}
{-# ANN select hasBlackBox #-}
Expand All @@ -1708,7 +1713,7 @@ select f s n xs = select' (toUNat n) $ drop f xs
--
-- >>> selectI d1 d2 (1:>2:>3:>4:>5:>6:>7:>8:>Nil) :: Vec 2 Int
-- 2 :> 4 :> Nil
selectI :: (CmpNat (i + s) (s * n) ~ 'GT, KnownNat n)
selectI :: (1 <= s, s * n + 1 <= i + s, KnownNat n)
=> SNat f
-> SNat s
-> Vec (f + i) a
Expand Down

0 comments on commit de50812

Please sign in to comment.