Skip to content

Commit

Permalink
Implement pragma for declaring nat optimizations
Browse files Browse the repository at this point in the history
Co-authored-by: MarcelineVQ <[email protected]>
  • Loading branch information
fabianhjr and MarcelineVQ committed May 3, 2020
1 parent 45d02a1 commit cd106f6
Show file tree
Hide file tree
Showing 13 changed files with 493 additions and 56 deletions.
24 changes: 17 additions & 7 deletions libs/prelude/Prelude.idr
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,9 @@ for = flip traverse
-- NATS ---
-----------

%builtinNatZero Z
%builtinNatSucc S

||| Natural numbers: unbounded, unsigned integers which can be pattern matched.
public export
data Nat =
Expand All @@ -702,13 +705,27 @@ data Nat =

%name Nat k, j, i

%builtinIntegerToNat integerToNat
%builtinNatToInteger natToInteger

public export
integerToNat : Integer -> Nat
integerToNat x
= if intToBool (prim__lte_Integer x 0)
then Z
else S (assert_total (integerToNat (prim__sub_Integer x 1)))

public export
natToInteger : Nat -> Integer
natToInteger Z = 0
natToInteger (S k) = 1 + natToInteger k
-- integer (+) may be non-linear in second
-- argument

%builtinNatAdd plus
%builtinNatSub minus
%builtinNatMul mult

-- Define separately so we can spot the name when optimising Nats
||| Add two natural numbers.
||| @ x the number to case-split on
Expand Down Expand Up @@ -752,13 +769,6 @@ Ord Nat where
compare (S k) Z = GT
compare (S j) (S k) = compare j k

public export
natToInteger : Nat -> Integer
natToInteger Z = 0
natToInteger (S k) = 1 + natToInteger k
-- integer (+) may be non-linear in second
-- argument

-----------
-- PAIRS --
-----------
Expand Down
7 changes: 3 additions & 4 deletions src/Compiler/Common.idr
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,7 @@ natHackNames : List Name
natHackNames
= [UN "prim__add_Integer",
UN "prim__sub_Integer",
UN "prim__mul_Integer",
NS ["Prelude"] (UN "natToInteger"),
NS ["Prelude"] (UN "integerToNat")]
UN "prim__mul_Integer"]

export
fastAppend : List String -> String
Expand Down Expand Up @@ -196,7 +194,8 @@ getCompileData tm
= do defs <- get Ctxt
sopts <- getSession
let ns = getRefs (Resolved (-1)) tm
natHackNames' <- traverse toResolvedNames natHackNames
builtins <- getBuiltins
natHackNames' <- traverse toResolvedNames (natHackNames ++ getDefinedBuiltinNames builtins)
-- make an array of Bools to hold which names we've found (quicker
-- to check than a NameMap!)
asize <- getNextEntry
Expand Down
52 changes: 35 additions & 17 deletions src/Compiler/CompileExpr.idr
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Core.Context
import Core.Env
import Core.Name
import Core.Normalise
import Core.Options
import Core.TT
import Core.Value

Expand Down Expand Up @@ -136,20 +137,36 @@ mkDropSubst i es rest (x :: xs)
-- NOTE: Make sure that names mentioned here are listed in 'natHackNames' in
-- Common.idr, so that they get compiled, as they won't be spotted by the
-- usual calls to 'getRefs'.
natHack : CExp vars -> CExp vars
natHack (CCon fc (NS ["Prelude"] (UN "Z")) _ []) = CPrimVal fc (BI 0)
natHack (CCon fc (NS ["Prelude"] (UN "S")) _ [k])
= CApp fc (CRef fc (UN "prim__add_Integer")) [CPrimVal fc (BI 1), k]
natHack (CApp fc (CRef _ (NS ["Prelude"] (UN "natToInteger"))) [k]) = k
natHack (CApp fc (CRef _ (NS ["Prelude"] (UN "integerToNat"))) [k]) = k
natHack (CApp fc (CRef fc' (NS ["Prelude"] (UN "plus"))) args)
= CApp fc (CRef fc' (UN "prim__add_Integer")) args
natHack (CApp fc (CRef fc' (NS ["Prelude"] (UN "mult"))) args)
= CApp fc (CRef fc' (UN "prim__mul_Integer")) args
natHack (CApp fc (CRef fc' (NS ["Nat", "Data"] (UN "minus"))) args)
= CApp fc (CRef fc' (UN "prim__sub_Integer")) args
natHack (CLam fc x exp) = CLam fc x (natHack exp)
natHack t = t
natHack : PrimBuiltinNames -> CExp vars -> CExp vars
natHack primBuiltinNames (CCon fc name t [])
= if Just name == builtinNatZero primBuiltinNames
then CPrimVal fc (BI 0)
else (CCon fc name t [])
natHack primBuiltinNames (CCon fc name t [k])
= if Just name == builtinNatSucc primBuiltinNames
then CApp fc (CRef fc (UN "prim__add_Integer")) [CPrimVal fc (BI 1), k]
else (CCon fc name t [k])
natHack primBuiltinNames (CApp fc (CRef fc' name) [k])
= if Just name == builtinNatToInteger primBuiltinNames
then k
else if Just name == builtinIntegerToNat primBuiltinNames
then k
else (CApp fc (CRef fc' name) [k])
natHack primBuiltinNames (CApp fc (CRef fc' name) args)
= if Just name == builtinNatAdd primBuiltinNames
then CApp fc (CRef fc' (UN "prim__add_Integer")) args
else if Just name == builtinNatSub primBuiltinNames
then CApp fc (CRef fc' (UN "prim__sub_Integer")) args
else if Just name == builtinNatMul primBuiltinNames
then CApp fc (CRef fc' (UN "prim__mul_Integer")) args
else if Just name == builtinNatDiv primBuiltinNames
then CApp fc (CRef fc' (UN "prim__div_Integer")) args
else if Just name == builtinNatMod primBuiltinNames
then CApp fc (CRef fc' (UN "prim__mod_Integer")) args
else (CApp fc (CRef fc' name) args)
-- TODO: Eq, LT, LTE, GT, GTE
natHack p (CLam fc x exp) = CLam fc x (natHack p exp)
natHack _ t = t

isNatCon : Name -> Bool
isNatCon (NS ["Prelude"] (UN "Z")) = True
Expand Down Expand Up @@ -247,15 +264,16 @@ mutual
(f, args) =>
do args' <- traverse (toCExp tags n) args
defs <- get Ctxt
builtins <- getBuiltins
Arity a <- numArgs defs f
| NewTypeBy arity pos =>
do let res = applyNewType arity pos !(toCExpTm tags n f) args'
pure $ natHack res
pure $ natHack builtins res
| EraseArgs arity epos =>
do let res = eraseConArgs arity epos !(toCExpTm tags n f) args'
pure $ natHack res
pure $ natHack builtins res
let res = expandToArity a !(toCExpTm tags n f) args'
pure $ natHack res
pure $ natHack builtins res

mutual
conCases : {auto c : Ref Ctxt Defs} ->
Expand Down
66 changes: 58 additions & 8 deletions src/Core/Binary.idr
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import Data.Buffer
-- TTC files can only be compatible if the version number is the same
export
ttcVersion : Int
ttcVersion = 23
ttcVersion = 24

export
checkTTCVersion : String -> Int -> Int -> Core ()
Expand Down Expand Up @@ -111,10 +111,23 @@ HasNames e => HasNames (TTCFile e) where
fullRW gam (Just (MkRewriteNs e r))
= pure $ Just $ MkRewriteNs !(full gam e) !(full gam r)

fullPrim : Context -> PrimNames -> Core PrimNames
fullPrim gam (MkPrimNs mi ms mc)
= pure $ MkPrimNs !(full gam mi) !(full gam ms) !(full gam mc)
fullPrimCast : Context -> PrimCastNames -> Core PrimCastNames
fullPrimCast gam (MkPrimCastNames i str c)
= pure $ MkPrimCastNames !(full gam i) !(full gam str) !(full gam c)

fullPrimBuiltin : Context -> PrimBuiltinNames -> Core PrimBuiltinNames
fullPrimBuiltin gam (MkPrimBuiltinNames natZero natSucc natToInteger integerToNat
natAdd natSub natMul natDiv natRem natEq natLT natLTE natGT natGTE)
= pure $ MkPrimBuiltinNames !(full gam natZero) !(full gam natSucc) !(full gam natToInteger)
!(full gam integerToNat) !(full gam natAdd) !(full gam natSub) !(full gam natMul)
!(full gam natDiv) !(full gam natRem) !(full gam natEq) !(full gam natLT)
!(full gam natLTE) !(full gam natGT) !(full gam natGTE)

fullPrim : Context -> PrimNames -> Core PrimNames
fullPrim gam (MkPrimNames primCastNames primBuiltinNames)
= do castNames <- fullPrimCast gam primCastNames
builtinNames <- fullPrimBuiltin gam primBuiltinNames
pure $ MkPrimNames castNames builtinNames

-- I don't think we ever actually want to call this, because after we read
-- from the file we're going to add them to learn what the resolved names
Expand Down Expand Up @@ -148,9 +161,23 @@ HasNames e => HasNames (TTCFile e) where
resolvedRW gam (Just (MkRewriteNs e r))
= pure $ Just $ MkRewriteNs !(resolved gam e) !(resolved gam r)

resolvedPrimCast : Context -> PrimCastNames -> Core PrimCastNames
resolvedPrimCast gam (MkPrimCastNames i str c)
= pure $ MkPrimCastNames !(resolved gam i) !(resolved gam str) !(resolved gam c)

resolvedPrimBuiltin : Context -> PrimBuiltinNames -> Core PrimBuiltinNames
resolvedPrimBuiltin gam (MkPrimBuiltinNames natZero natSucc natToInteger integerToNat
natAdd natSub natMul natDiv natRem natEq natLT natLTE natGT natGTE)
= pure $ MkPrimBuiltinNames !(resolved gam natZero) !(resolved gam natSucc) !(resolved gam natToInteger)
!(resolved gam integerToNat) !(resolved gam natAdd) !(resolved gam natSub) !(resolved gam natMul)
!(resolved gam natDiv) !(resolved gam natRem) !(resolved gam natEq) !(resolved gam natLT)
!(resolved gam natLTE) !(resolved gam natGT) !(resolved gam natGTE)

resolvedPrim : Context -> PrimNames -> Core PrimNames
resolvedPrim gam (MkPrimNs mi ms mc)
= pure $ MkPrimNs !(resolved gam mi) !(resolved gam ms) !(resolved gam mc)
resolvedPrim gam (MkPrimNames primCastNames primBuiltinNames)
= do castNames <- resolvedPrimCast gam primCastNames
builtinNames <- resolvedPrimBuiltin gam primBuiltinNames
pure $ MkPrimNames castNames builtinNames


asName : List String -> Maybe (List String) -> Name -> Name
Expand Down Expand Up @@ -320,11 +347,34 @@ updateRewrite r
put Ctxt (record { options->rewritenames $= (r <+>) } defs)

export
updatePrimNames : PrimNames -> PrimNames -> PrimNames
updatePrimNames p
updatePrimCastNames : PrimCastNames -> PrimCastNames -> PrimCastNames
updatePrimCastNames p
= record { fromIntegerName $= ((fromIntegerName p) <+>),
fromStringName $= ((fromStringName p) <+>),
fromCharName $= ((fromCharName p) <+>) }
export
updatePrimBuiltinNames : PrimBuiltinNames -> PrimBuiltinNames -> PrimBuiltinNames
updatePrimBuiltinNames p
= record { builtinNatZero $= ((builtinNatZero p) <+>),
builtinNatSucc $= ((builtinNatSucc p) <+>),
builtinNatToInteger $= ((builtinNatToInteger p) <+>),
builtinIntegerToNat $= ((builtinIntegerToNat p) <+>),
builtinNatAdd $= ((builtinNatAdd p) <+>),
builtinNatSub $= ((builtinNatSub p) <+>),
builtinNatMul $= ((builtinNatMul p) <+>),
builtinNatDiv $= ((builtinNatDiv p) <+>),
builtinNatMod $= ((builtinNatMod p) <+>),
builtinNatEq $= ((builtinNatEq p) <+>),
builtinNatLT $= ((builtinNatLT p) <+>),
builtinNatLTE $= ((builtinNatLTE p) <+>),
builtinNatGT $= ((builtinNatGT p) <+>),
builtinNatGTE $= ((builtinNatGTE p) <+>) }

export
updatePrimNames : PrimNames -> PrimNames -> PrimNames
updatePrimNames p
= record { primCastNames $= updatePrimCastNames . primCastNames $ p,
primBuiltinNames $= updatePrimBuiltinNames . primBuiltinNames $ p }

export
updatePrims : {auto c : Ref Ctxt Defs} ->
Expand Down
Loading

0 comments on commit cd106f6

Please sign in to comment.