-- ConstFolding.lean
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Expr
/-! Constant folding for primitives that have special runtime support. -/
namespace Lean.Compiler
/-- Build the expression `lcProof p`: the constant `lcProof` (with no universe
parameters) applied to the proposition expression `p`. -/
def mkLcProof (p : Expr) : Expr :=
  let lcProofConst := mkConst ``lcProof
  mkApp lcProofConst p
/-- Type of a binary folding function: given a `Bool` flag (passed as
`beforeErasure` by `foldBinOp`) and the two operand expressions, optionally
produce the folded expression. -/
abbrev BinFoldFn := Bool → Expr → Expr → Option Expr
/-- Type of a unary folding function: given a `Bool` flag (passed as
`beforeErasure` by `foldUnOp`) and the operand expression, optionally produce
the folded expression. -/
abbrev UnFoldFn := Bool → Expr → Option Expr
/-- Build the simple name `UInt<k>` for the given number `k`.

Note: the parameter is called `nbytes`, but `NumScalarTypeInfo` passes its
`nbits` field here, so the number is effectively a bit width (`UInt8`, `UInt16`, …). -/
def mkUIntTypeName (nbytes : Nat) : Name :=
  Name.mkSimple s!"UInt{nbytes}"
/-- Description of a fixed-width unsigned scalar type used by the constant folder.
Only `nbits` is mandatory; every other field has a default derived from it. -/
structure NumScalarTypeInfo where
  /-- Width of the type in bits. -/
  nbits : Nat
  /-- Name of the type; defaults to `UInt<nbits>`. -/
  id : Name := mkUIntTypeName nbits
  /-- Name of the `ofNat` function of the type (`<id>.ofNat`). -/
  ofNatFn : Name := id.mkStr "ofNat"
  /-- Name of the `toNat` function of the type (`<id>.toNat`). -/
  toNatFn : Name := id.mkStr "toNat"
  /-- Number of distinct values of the type, `2 ^ nbits`. -/
  size : Nat := 2 ^ nbits
/-- All scalar types known to the constant folder: the 8/16/32/64-bit unsigned
integers followed by `USize`, whose width is `System.Platform.numBits`. -/
def numScalarTypes : List NumScalarTypeInfo :=
  let fixedWidth : List NumScalarTypeInfo :=
    ([8, 16, 32, 64] : List Nat).map fun bits => ({ nbits := bits } : NumScalarTypeInfo)
  let usize : NumScalarTypeInfo := { id := ``USize, nbits := System.Platform.numBits }
  fixedWidth ++ [usize]
/-- Return `true` iff `fn` is the `ofNat` function of one of the types in `numScalarTypes`. -/
def isOfNat (fn : Name) : Bool :=
  numScalarTypes.any (·.ofNatFn == fn)
/-- Return `true` iff `fn` is the `toNat` function of one of the types in `numScalarTypes`. -/
def isToNat (fn : Name) : Bool :=
  numScalarTypes.any (·.toNatFn == fn)
/-- Search the given list for the first scalar type whose `ofNatFn` is `fn`.
Returns `none` when no entry matches. -/
def getInfoFromFn (fn : Name) : List NumScalarTypeInfo → Option NumScalarTypeInfo :=
  fun infos => infos.find? fun info => info.ofNatFn == fn
/-- Determine the scalar type of a literal expression of the form `C.ofNat arg`:
if the head is a constant applied to one argument, look that constant up among
the `ofNat` functions of `numScalarTypes`. Any other expression yields `none`. -/
def getInfoFromVal : Expr → Option NumScalarTypeInfo := fun e =>
  match e with
  | .app (.const headFn _) _ => getInfoFromFn headFn numScalarTypes
  | _ => none
/-- Extract the natural number carried by a numeric literal expression.

* A raw natural-number literal yields its value.
* An application `f arg`, where `f` is a constant accepted by `isOfNat`, yields
  whatever `arg` yields (the value is *not* reduced to the type's range here).
* Everything else yields `none`. -/
@[export lean_get_num_lit]
def getNumLit : Expr → Option Nat
  | .lit (.natVal value) => some value
  | .app (.const headFn _) arg =>
    match isOfNat headFn with
    | true  => getNumLit arg
    | false => none
  | _ => none
/-- Build the literal `info.ofNatFn k`, where `k` is the raw natural-number
literal `n % info.size` (i.e. `n` wrapped into the range of the type). -/
def mkUIntLit (info : NumScalarTypeInfo) (n : Nat) : Expr :=
  let wrapped := mkRawNatLit (n % info.size)
  mkApp (mkConst info.ofNatFn) wrapped
/-- Build a 32-bit unsigned literal for `n` (wrapped modulo `2 ^ 32` by `mkUIntLit`). -/
def mkUInt32Lit (n : Nat) : Expr :=
  let uint32 : NumScalarTypeInfo := { nbits := 32 }
  mkUIntLit uint32 n
/--
Fold a binary operation applied to two fixed-width unsigned integer literals.

* `fn` computes the result on natural numbers. It receives the scalar type
  description, the `beforeErasure` flag and the two operand values.
* `a₁` and `a₂` are the operand expressions. Both must be numeric literals
  (see `getNumLit`) and `a₁` must be headed by one of the `ofNat` functions in
  `numScalarTypes` (see `getInfoFromVal`); otherwise the result is `none`.

The operands are reduced modulo `info.size` before `fn` is applied, and the
result is wrapped again by `mkUIntLit`.
-/
def foldBinUInt (fn : NumScalarTypeInfo → Bool → Nat → Nat → Nat) (beforeErasure : Bool) (a₁ a₂ : Expr) : Option Expr := do
  let n₁ ← getNumLit a₁
  let n₂ ← getNumLit a₂
  let info ← getInfoFromVal a₁
  -- `getNumLit` returns the raw argument of `ofNat`, which may lie outside the
  -- range of the type (e.g. `UInt8.ofNat 300` denotes `44`). Operations such as
  -- division, modulo and subtraction are not compatible with reducing only the
  -- final result, so normalize the operands first.
  let v₁ := n₁ % info.size
  let v₂ := n₂ % info.size
  return mkUIntLit info (fn info beforeErasure v₁ v₂)
/-- Fold addition of two unsigned integer literals. -/
def foldUIntAdd := foldBinUInt fun _ _ x y => x + y
/-- Fold multiplication of two unsigned integer literals. -/
def foldUIntMul := foldBinUInt fun _ _ x y => x * y
/-- Fold division of two unsigned integer literals. -/
def foldUIntDiv := foldBinUInt fun _ _ x y => x / y
/-- Fold the modulo operation on two unsigned integer literals. -/
def foldUIntMod := foldBinUInt fun _ _ x y => x % y
/-- Fold subtraction of two unsigned integer literals. The difference is computed
as `x + (size - y)` so that it stays a natural number; `foldBinUInt` wraps the
result back into range via `mkUIntLit`. -/
def foldUIntSub := foldBinUInt fun info _ x y => x + (info.size - y)
/-- Folding functions for unsigned integer operations, keyed by the operation's
name suffix. `uintBinFoldFns` prefixes each suffix with every scalar type name. -/
def preUIntBinFoldFns : List (Name × BinFoldFn) := [
  (`add, foldUIntAdd),
  (`mul, foldUIntMul),
  (`div, foldUIntDiv),
  (`mod, foldUIntMod),
  (`sub, foldUIntSub)
]
/-- Binary folding table for all scalar types: for each type in `numScalarTypes`
(in order) and each entry of `preUIntBinFoldFns` (in order), an entry keyed by
`<type id> ++ <suffix>`. -/
def uintBinFoldFns : List (Name × BinFoldFn) :=
  numScalarTypes.foldr
    (fun info rest =>
      let entries := preUIntBinFoldFns.map fun (suffix, foldFn) => (info.id ++ suffix, foldFn)
      entries ++ rest)
    []
/-- Fold a binary operation on natural numbers: when both operands are numeric
literals (see `getNumLit`), return the raw literal for `fn` applied to their
values; otherwise return `none`. -/
def foldNatBinOp (fn : Nat → Nat → Nat) (a₁ a₂ : Expr) : Option Expr := do
  let lhs ← getNumLit a₁
  let rhs ← getNumLit a₂
  let result := fn lhs rhs
  return mkRawNatLit result
/-- Fold `Nat` addition of two literals. The `Bool` flag is ignored. -/
def foldNatAdd (_ : Bool) := foldNatBinOp (· + ·)
/-- Fold `Nat` multiplication of two literals. The `Bool` flag is ignored. -/
def foldNatMul (_ : Bool) := foldNatBinOp (· * ·)
/-- Fold `Nat` division of two literals. The `Bool` flag is ignored. -/
def foldNatDiv (_ : Bool) := foldNatBinOp (· / ·)
/-- Fold the `Nat` modulo operation on two literals. The `Bool` flag is ignored. -/
def foldNatMod (_ : Bool) := foldNatBinOp (· % ·)
/-- Exclusive upper bound on the exponent for which `foldNatPow` folds a power.
TODO: add an option for controlling this limit. -/
def natPowThreshold : Nat := 256
/-- Fold `Nat` exponentiation of two literals. Folding is refused (`none`) when
the exponent is not strictly below `natPowThreshold`. The `Bool` flag is ignored. -/
def foldNatPow (_ : Bool) (a₁ a₂ : Expr) : Option Expr := do
  let base ← getNumLit a₁
  let exponent ← getNumLit a₂
  guard (exponent < natPowThreshold)
  return mkRawNatLit (base ^ exponent)
/-- Build the expression `@Eq.{1} Nat a b`.

The type constant is written with a double-backtick name literal (``` ``Nat ```),
like in `mkNatLt` and `mkNatLe`, so that the name is checked at compile time. -/
def mkNatEq (a b : Expr) : Expr :=
  mkAppN (mkConst ``Eq [levelOne]) #[mkConst ``Nat, a, b]
/-- Build the expression `@LT.lt.{0} Nat Nat.lt a b`.
NOTE(review): the second argument is the constant `Nat.lt`, not an `LT Nat`
instance constant — confirm this is intended. -/
def mkNatLt (a b : Expr) : Expr :=
  mkApp4 (mkConst ``LT.lt [levelZero]) (mkConst ``Nat) (mkConst ``Nat.lt) a b
/-- Build the expression `@LE.le.{0} Nat Nat.le a b`.
NOTE(review): the second argument is the constant `Nat.le`, not an `LE Nat`
instance constant — confirm this is intended. -/
def mkNatLe (a b : Expr) : Expr :=
  mkApp4 (mkConst ``LE.le [levelZero]) (mkConst ``Nat) (mkConst ``Nat.le) a b
/-- Turn the boolean outcome `r` of deciding `pred` into an expression.

* When `beforeErasure` is `false`, the result is the constant `Bool.true` or `Bool.false`.
* When `beforeErasure` is `true`, the result is built with `mkDecIsTrue` /
  `mkDecIsFalse` from `pred` and the placeholder proof `mkLcProof pred`. -/
def toDecidableExpr (beforeErasure : Bool) (pred : Expr) (r : Bool) : Expr :=
  if beforeErasure then
    let proof := mkLcProof pred
    if r then mkDecIsTrue pred proof else mkDecIsFalse pred proof
  else if r then
    mkConst ``Bool.true
  else
    mkConst ``Bool.false
/-- Fold a decidable binary predicate on natural-number literals.

`mkPred` builds the proposition expression from the two operand expressions,
`fn` decides it on the literal values, and `toDecidableExpr` packages the
outcome according to `beforeErasure`. Returns `none` unless both operands are
numeric literals. -/
def foldNatBinPred (mkPred : Expr → Expr → Expr) (fn : Nat → Nat → Bool)
    (beforeErasure : Bool) (a₁ a₂ : Expr) : Option Expr := do
  let lhs ← getNumLit a₁
  let rhs ← getNumLit a₂
  let outcome := fn lhs rhs
  let prop := mkPred a₁ a₂
  return toDecidableExpr beforeErasure prop outcome
/-- Fold a decidable equality test on two `Nat` literals. -/
def foldNatDecEq := foldNatBinPred mkNatEq (fun x y => decide (x = y))
/-- Fold a decidable strict less-than test on two `Nat` literals. -/
def foldNatDecLt := foldNatBinPred mkNatLt (fun x y => decide (x < y))
/-- Fold a decidable less-or-equal test on two `Nat` literals. -/
def foldNatDecLe := foldNatBinPred mkNatLe (fun x y => decide (x ≤ y))
/-- Fold a boolean-valued binary test on natural-number literals: returns the
constant `Bool.true` or `Bool.false` according to `fn` applied to the literal
values, or `none` unless both operands are numeric literals. -/
def foldNatBinBoolPred (fn : Nat → Nat → Bool) (a₁ a₂ : Expr) : Option Expr := do
  let lhs ← getNumLit a₁
  let rhs ← getNumLit a₂
  let ctor := if fn lhs rhs then ``Bool.true else ``Bool.false
  return mkConst ctor
/-- Fold boolean equality of two `Nat` literals (registered for `Nat.beq`).
The `Bool` flag is ignored. -/
def foldNatBeq := fun _ : Bool => foldNatBinBoolPred (fun a b => a == b)
/-- Fold boolean less-or-equal of two `Nat` literals (registered for `Nat.ble`),
so it must compare with `≤`. The `Bool` flag is ignored. -/
def foldNatBle := fun _ : Bool => foldNatBinBoolPred (fun a b => a ≤ b)
/-- Fold boolean strict less-than of two `Nat` literals (registered for `Nat.blt`),
so it must compare with `<`. The `Bool` flag is ignored. -/
def foldNatBlt := fun _ : Bool => foldNatBinBoolPred (fun a b => a < b)
/-- Binary folding table for `Nat` primitives: arithmetic, decidable
comparisons, and boolean comparisons, keyed by the primitive's name. -/
def natFoldFns : List (Name × BinFoldFn) :=
  let arithmetic : List (Name × BinFoldFn) := [
    (``Nat.add, foldNatAdd),
    (``Nat.mul, foldNatMul),
    (``Nat.div, foldNatDiv),
    (``Nat.mod, foldNatMod),
    (``Nat.pow, foldNatPow)
  ]
  let decidablePreds : List (Name × BinFoldFn) := [
    (``Nat.decEq, foldNatDecEq),
    (``Nat.decLt, foldNatDecLt),
    (``Nat.decLe, foldNatDecLe)
  ]
  let boolPreds : List (Name × BinFoldFn) := [
    (``Nat.beq, foldNatBeq),
    (``Nat.blt, foldNatBlt),
    (``Nat.ble, foldNatBle)
  ]
  arithmetic ++ decidablePreds ++ boolPreds
/-- Recognize a boolean literal: the constant `Bool.true` yields `some true`,
the constant `Bool.false` yields `some false`, anything else yields `none`. -/
def getBoolLit : Expr → Option Bool
  | .const declName _ =>
    if declName == ``Bool.true then some true
    else if declName == ``Bool.false then some false
    else none
  | _ => none
/-- Fold `strictAnd a₁ a₂` when at least one operand is a boolean literal.

A literal `true` operand is dropped (the other operand is the result); a
literal `false` operand is itself the result. The first operand is inspected
before the second. Returns `none` when neither operand is a literal. The `Bool`
flag is ignored. -/
def foldStrictAnd (_ : Bool) (a₁ a₂ : Expr) : Option Expr :=
  match getBoolLit a₁, getBoolLit a₂ with
  | some lhs, _    => some (if lhs then a₂ else a₁)
  | none, some rhs => some (if rhs then a₁ else a₂)
  | none, none     => none
/-- Fold `strictOr a₁ a₂` when at least one operand is a boolean literal.

A literal `true` operand is itself the result; a literal `false` operand is
dropped (the other operand is the result). The first operand is inspected
before the second. Returns `none` when neither operand is a literal. The `Bool`
flag is ignored. -/
def foldStrictOr (_ : Bool) (a₁ a₂ : Expr) : Option Expr :=
  match getBoolLit a₁, getBoolLit a₂ with
  | some lhs, _    => some (if lhs then a₁ else a₂)
  | none, some rhs => some (if rhs then a₂ else a₁)
  | none, none     => none
/-- Binary folding table for the strict boolean connectives. -/
def boolFoldFns : List (Name × BinFoldFn) := [
  (``strictOr, foldStrictOr),
  (``strictAnd, foldStrictAnd)
]
/-- Complete binary folding table: boolean, unsigned-integer and `Nat`
entries, in that order. -/
def binFoldFns : List (Name × BinFoldFn) :=
  boolFoldFns ++ (uintBinFoldFns ++ natFoldFns)
/-- Fold `Nat.succ` applied to a numeric literal into the literal for `n + 1`.
The `Bool` flag is ignored. -/
def foldNatSucc (_ : Bool) (a : Expr) : Option Expr :=
  (getNumLit a).map fun n => mkRawNatLit (n + 1)
/-- Fold `Char.ofNat` applied to a numeric literal into a 32-bit unsigned literal.

Folding is only performed after erasure (`beforeErasure = false`). A valid
code point folds to itself; any other number folds to `0`. -/
def foldCharOfNat (beforeErasure : Bool) (a : Expr) : Option Expr := do
  guard (!beforeErasure)
  let n ← getNumLit a
  -- `n.toUInt32` wraps modulo `2 ^ 32`, so without the range check a number
  -- such as `2 ^ 32 + 65` would pass the validity test and fold to `65`
  -- instead of `0`.
  if n < UInt32.size ∧ isValidChar n.toUInt32 then
    return mkUInt32Lit n
  else
    return mkUInt32Lit 0
/-- Fold a `toNat` conversion applied to a scalar literal into a raw `Nat` literal.

`getNumLit` returns the raw argument of `ofNat`, which may lie outside the
range of the scalar type (e.g. `UInt8.ofNat 300` denotes `44`). When the scalar
type of `a` can be determined, the value is therefore reduced modulo its size;
otherwise the value is returned unchanged. The `Bool` flag is ignored. -/
def foldToNat (_ : Bool) (a : Expr) : Option Expr := do
  let n ← getNumLit a
  match getInfoFromVal a with
  | some info => return mkRawNatLit (n % info.size)
  | none      => return mkRawNatLit n
/-- Unary folding table mapping the `toNat` function of every scalar type to
`foldToNat`. Entries appear in the reverse order of `numScalarTypes`. -/
def uintFoldToNatFns : List (Name × UnFoldFn) :=
  let entries : List (Name × UnFoldFn) :=
    numScalarTypes.map fun info => (info.toNatFn, foldToNat)
  entries.reverse
/-- Complete unary folding table: `Nat.succ`, `Char.ofNat`, then the `toNat`
conversions of the scalar types. -/
def unFoldFns : List (Name × UnFoldFn) :=
  (``Nat.succ, foldNatSucc) :: (``Char.ofNat, foldCharOfNat) :: uintFoldToNatFns
/-- Look up the binary folding function registered for `fn` in `binFoldFns`. -/
def findBinFoldFn (fn : Name) : Option BinFoldFn :=
  List.lookup fn binFoldFns
/-- Look up the unary folding function registered for `fn` in `unFoldFns`. -/
def findUnFoldFn (fn : Name) : Option UnFoldFn :=
  List.lookup fn unFoldFns
/-- Try to fold the binary application `f a b`.

Succeeds only when `f` is a constant with an entry in `binFoldFns` and that
entry's folding function succeeds on `a` and `b`; otherwise returns `none`.
`beforeErasure` is forwarded to the folding function. -/
@[export lean_fold_bin_op]
def foldBinOp (beforeErasure : Bool) (f : Expr) (a : Expr) (b : Expr) : Option Expr :=
  match f with
  | .const fn _ => (findBinFoldFn fn).bind fun foldFn => foldFn beforeErasure a b
  | _ => none
/-- Try to fold the unary application `f a`.

Succeeds only when `f` is a constant with an entry in `unFoldFns` and that
entry's folding function succeeds on `a`; otherwise returns `none`.
`beforeErasure` is forwarded to the folding function. -/
@[export lean_fold_un_op]
def foldUnOp (beforeErasure : Bool) (f : Expr) (a : Expr) : Option Expr :=
  match f with
  | .const fn _ => (findUnFoldFn fn).bind fun foldFn => foldFn beforeErasure a
  | _ => none
end Lean.Compiler