-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsolve.rs
339 lines (298 loc) · 16 KB
/
solve.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
//! Solve ast-based expressions by converting to another form.
use std::env;
///
/// the tests in this module use command line options to show or hide diagrams.
/// -a show AST (problem statement)
/// -r show result (BDD, ANF, etc)
///
/// note that you need to use two '--' parameters to pass arguments to a test.
///
/// syntax:
/// cargo test pattern -- test_engine_args -- actual_args
/// example:
/// cargo test nano_bdd -- --nocapture -- -a -r
///
/// (the --nocapture is an optional argument to the test engine. it turns off
/// capturing of stdout so that you can see debug lines from the solver)
use std::{collections::HashSet, time::SystemTime};
use crate::{apl, ops};
use crate::base::Base;
use crate::nid::NID;
use crate::vid::VID;
use crate::ops::Ops;
use crate::reg::Reg;
use crate::{GraphViz, ast::{ASTBase, RawASTBase}, int::{GBASE,BInt,BaseBit}};
/// protocol used by solve.rs. These allow the base to prepare itself for different steps
/// in a substitution solver.
pub trait SubSolver {
  /// Initialize the solver by constructing the node corresponding to the final
  /// virtual variable in the expression. Return its nid.
  fn init(&mut self, top: VID)->NID { NID::from_vid(top) }
  /// tell the implementation to perform a substitution step.
  /// context NIDs are passed in and out so the implementation
  /// itself doesn't have to remember it.
  fn subst(&mut self, ctx:NID, vid:VID, ops:&Ops)->NID;
  /// fetch a solution, (if one exists).
  /// The default implementation materializes *all* solutions via `get_all`,
  /// which can be arbitrarily expensive -- implementations should override it.
  fn get_one(&self, ctx:NID, nvars:usize)->Option<Reg> {
    println!("Warning: default SubSolver::get_one() calls get_all(). Override this!");
    self.get_all(ctx, nvars).iter().next().cloned() }
  /// fetch all solutions
  fn get_all(&self, ctx:NID, nvars:usize)->HashSet<Reg>;
  /// a status message for the progress report
  fn status(&self)->String { String::new() }
  /// Dump the current internal state for inspection by some external process.
  /// Generally this means writing to a graphviz (*.dot) file.
  fn dump(&self, _step: usize, _nid: NID) { }
  // !! these are defined here but never overwritten in the trait (used by solver) [fix this]
  fn init_stats(&mut self) { }
  fn print_stats(&mut self) { }}
/// Blanket impl: any `Base` can act as a `SubSolver` for simple dyadic expressions.
impl<B:Base> SubSolver for B {
/// Translate one RPN definition into base operations and substitute the
/// result for variable `v` in the context node `ctx`.
fn subst(&mut self, ctx:NID, v:VID, ops:&Ops) ->NID {
// only simple dyadic definitions of the shape [x0, x1, FUN] are handled here.
let def = match ops {
Ops::RPN(x) => if x.len() == 3 {
match x[2].to_fun().unwrap() {
ops::AND => self.and(x[0], x[1]),
ops::XOR => self.xor(x[0], x[1]),
ops::VEL => self.or(x[0], x[1]), // VEL maps to inclusive or
_ => panic!("don't know how to translate {:?}", ops)}}
else { todo!("SubSolver impl for Base can only handle simple dyadic ops for now.") }};
//_ => { todo!("SubSolver impl for Base can only handle RPN for now")}};
self.sub(v, def, ctx)}
fn get_all(&self, ctx:NID, nvars:usize)->HashSet<Reg> { self.solution_set(ctx, nvars) }
// delegate explicitly to the Base versions (the trait defaults are no-ops)
fn init_stats(&mut self) { Base::init_stats(self) }
fn print_stats(&mut self) { Base::print_stats(self) }
}
/// Observer hooks invoked by `solve()` so callers can track solver progress.
pub trait Progress<S:SubSolver> {
/// called once, before the main substitution loop begins
fn on_start(&mut self, _ctx:&DstNid) { } // println!("INITIAL ctx: {:?}", ctx)
/// called after each substitution step, with per-step timing in `millis`
fn on_step(&mut self, src:&RawASTBase, dest: &mut S, step:usize, millis:u128, oldtop:DstNid, newtop:DstNid);
/// called once, after the loop finishes
fn on_done(&mut self, src:&RawASTBase, dest: &mut S, newtop:DstNid); }
/// Default `Progress` implementation: records the start time and prints
/// the total elapsed time when solving completes.
pub struct ProgressReport<'a> {
/// wall-clock time when solving began (set by `on_start`)
pub start: std::time::SystemTime,
/// accumulated per-step time in milliseconds (only used by the currently commented-out `on_step` body)
pub millis: u128,
/// whether to periodically emit graphviz diagrams (referenced only by the commented-out `on_step` body)
pub save_dot: bool,
/// whether to periodically persist the destination (referenced only by the commented-out `on_step` body)
pub save_dest: bool,
/// filename prefix for any saved artifacts
pub prefix: &'a str }
/// these are wrappers so the type system can help us keep the src and dest nids separate
/// a nid referring to a node in the source (AST) base
#[derive(Clone, Copy, Debug, PartialEq)] pub struct SrcNid { pub n: NID }
/// a nid referring to a node in the destination (solver) base
#[derive(Clone, Copy, Debug, PartialEq)] pub struct DstNid { pub n: NID }
impl<S:SubSolver> Progress<S> for ProgressReport<'_> {
/// remember when solving started so `on_done` can report total elapsed time
fn on_start(&mut self, _ctx:&DstNid) { self.start = std::time::SystemTime::now(); }
/// per-step reporting is currently disabled; the old body is preserved in the
/// comment below for eventual re-enablement.
fn on_step(&mut self, _src:&RawASTBase, _dest: &mut S, _step:usize, _millis:u128, _oldtop:DstNid, _newtop:DstNid) { }
/*
self.millis += millis;
let DstNid{ n: new } = newtop;
println!("{:4}, {:8} ms, {:45?} → {:45?}, {:45?}",
step, millis, oldtop.n,
if new.vid().is_vir() {
format!("{:?}", src.get_ops(NID::ixn(new.vid().vir_ix()))) }
else { format!("{:?}", new)},
newtop.n);
// dest.show_named(newtop.n, format!("step-{}", step).as_str());
if step.trailing_zeros() >= 3 { // every so often, save the state
// !! TODO: expected number of steps only works if sort_by_cost was called.
{ let expected_steps = src.len() as f64;
let percent_done = 100.0 * (step as f64) / expected_steps;
println!("\n# newtop: {:?} step:{}/{} ({:.2}%)",
newtop.n.vid(), step, src.len(), percent_done); }
if self.save_dest {
println!("TODO: save_dest for SwapSolver instead of Base")
// dest.tag(new, "top".to_string()); dest.tag(NID::var(step as u32), "step".to_string());
// TODO: remove the 'bdd' suffix
// dest.save(format!("{}-{:04}.bdd", self.prefix, step).as_str()).expect("failed to save");
}}
if step.trailing_zeros() >= 5 { println!("step, millis, change, newtop"); }
if self.save_dot && (step.trailing_zeros() >= 5) || (step==446)
{ // on really special occasions, output a diagram
let note = &dest.status();
let path = Path::new("."); // todo
let ops = &Ops::RPN(vec![]); // todo
dest.dump(path, note, step, oldtop.n, newtop.n.vid(), ops, newtop.n); }}
*/
/// report the total wall-clock time since `on_start`
fn on_done(&mut self, _src:&RawASTBase, _dest: &mut S, _newtop:DstNid) {
println!("total time: {} ms", self.start.elapsed().unwrap().as_millis() ) }}
fn default_bitmask(_src:&RawASTBase, v:VID) -> u64 { v.bitmask() }
/// This function renumbers the NIDs so that nodes with higher IDs "cost" more.
/// Sorting your AST this way dramatically reduces the cost of converting to
/// another form. (For example, the test_tiny benchmark drops from 5282 steps to 111 for BddBase)
pub fn sort_by_cost(src:&RawASTBase, top:SrcNid)->(RawASTBase,SrcNid) {
  // garbage-collect: keep only the nodes reachable from `top`
  let (mut packed, kept) = src.repack(vec![top.n]);
  // tag the top node so we can find it again after permuting
  packed.tag(kept[0], "-top-".to_string());
  // m:mask (which input vars are required?); c:cost (in steps before we can calculate)
  let (_masks, costs) = packed.masks_and_costs(default_bitmask);
  // grade up by cost: perm[new idx] = old idx
  let perm = apl::gradeup(&costs);
  let sorted = packed.permute(&perm);
  let n = sorted.get("-top-").expect("what? I just put it there.");
  (sorted, SrcNid{n}) }
/// map a nid from the source to a (usually virtual) variable in the destination
pub fn convert_nid(sn:SrcNid)->DstNid {
  let n = sn.n;
  // constants pass through unchanged
  if n.is_const() { return DstNid{ n } }
  // translate the raw (uninverted) part of the nid:
  //   plain input variables map to themselves; ixn (AST node index) nids
  //   map to virtual variables with the same index.
  let base =
    if n.is_vir() { panic!("what? should never be a VIR in the source."); }
    else if n.is_var() { n.raw() }
    else if n.is_ixn() { NID::vir(n.idx() as u32) }
    else { todo!("convert_nid({:?})", n) };
  // re-apply the inversion bit if the source nid carried one
  DstNid{ n: if n.is_inv() { !base } else { base } }}
/// replace node in destination with its definition from source
fn refine_one(dst: &mut dyn SubSolver, v:VID, src:&RawASTBase, d:DstNid)->DstNid {
  // println!("refine_one({:?})", d)
  // fetch the AST definition corresponding to virtual variable v
  let ops = src.get_ops(NID::ixn(v.vir_ix()));
  // rewrite each operand into destination terms; function nids pass through as-is
  let rpn: Vec<NID> = ops.to_rpn()
    .map(|x0| if x0.is_fun() { *x0 } else { convert_nid(SrcNid{n:*x0}).n })
    .collect();
  DstNid{ n: dst.subst(d.n, v, &Ops::RPN(rpn)) }}
/// This is the core algorithm for solving by substitution. We are given a (presumably empty)
/// destination (the `SubSolver`), a source ASTBase (`src0`), and a source nid (`sn`),
/// pointing to a node inside the ASTBase.
///
/// The source nids we encounter are indices into the ASTBase. We begin by sorting/rewriting
/// the ASTBase in terms of "cost", so that a node at index k is only dependent on nodes
/// with indices < k. We also filter out any nodes that are not actually used (for example,
/// there may be nodes in the middle of the AST that are expensive to calculate on their own,
/// but get canceled out later on (perhaps by XORing with itself, or ANDing with 0) -- there's
/// no point including these at all as we work backwards.
///
/// After this sorting and filtering, we map each nid in the AST to a `VID::vir` with
/// the corresponding index. We then initialize `dst` with the highest vid (the one
/// corresponding to the topmost/highest cost node in the AST).
///
/// We then replace each VID in turn with its definition. The definition of each intermediate
/// node is always in terms of either other AST nodes (mapped to `VID::vir` in the destination,
/// or actual input variables (`VID::var`), which are added to the destination directly).
///
/// The dependency ordering ensures that we never re-introduce a node after substitution,
/// so the number of substitution steps is equal to the number of AST nodes.
///
/// Of course, the cost of each substitution is likely to increase as the destination
/// becomes more and more detailed. Depending on the implementation, this cost may even
/// grow exponentially. However, the hope is that by working "backward" from the final
/// result, we will have access to the maximal number of constraints, and there
/// will be opportunities to streamline and cancel out even more nodes. The hope is that
/// no matter how slow this process is, it will be less slow that trying to fully solve
/// each intermediate node by working "forward".
pub fn solve<S:SubSolver>(dst:&mut S, src0:&RawASTBase, sn:NID)->DstNid {
// AST nids don't contain VIR nodes (they "are" vir nodes).
// If it's already a const or a VID::var, though, there's nothing to do.
if sn.is_lit() { DstNid{n:sn} }
else {
// NOTE(review): this init() result is discarded, and init() is called again
// below with the top vir -- confirm whether this first call is actually needed.
dst.init(sn.vid());
// renumber and garbage collect, leaving only the AST nodes reachable from sn
let (src, top) = sort_by_cost(src0, SrcNid{n:sn});
// step is just a number that counts downward.
let mut step:usize = top.n.idx();
// !! These lines were a kludge to allow storing the step number in the dst,
// with the idea of persisting the destination to disk to resume later.
// The current solvers are so slow that I'm not actually using them for
// anything but testing, though, so I don't need this yet.
// TODO: re-enable the ability to save and load the destination mid-run.
// let step_node = dst.get(&"step".to_string()).unwrap_or_else(||NID::var(0));
// let mut step:usize = step_node.vid().var_ix();
// v is the next virtual variable to replace.
let mut v = VID::vir(step as u32);
// The context is the evolving top-level node in the destination.
// It begins with just the vir representing the top node in the AST.
let mut ctx = DstNid{n: dst.init(v)};
// This just lets us record timing info. TODO: pr probably should be an input parameter.
let mut pr = ProgressReport{ start: SystemTime::now(), save_dot: false, save_dest: false, prefix:"x", millis: 0 };
<dyn Progress<S>>::on_start(&mut pr, &ctx);
// main loop: keep substituting until the context collapses to a constant or a plain variable.
while !(ctx.n.is_var() || ctx.n.is_const()) {
let now = std::time::SystemTime::now();
let old = ctx; ctx = refine_one(dst, v, &src, ctx);
let millis = now.elapsed().expect("elapsed?").as_millis();
pr.on_step(&src, dst, step, millis, old, ctx);
// count v down toward vir(0); stop once it has been substituted.
if step == 0 { break } else { step -= 1; v=VID::vir(step as u32) }}
pr.on_done(&src, dst, ctx);
ctx}}
/// Build the AST bits for the factoring problems used by the tests below:
/// `lt` asserts x < y (so each factor pair is counted once), and `eq`
/// asserts x*y == k (with the product widened to T1).
fn multiplication_bits<T0:BInt, T1:BInt>(k:usize)->(BaseBit, BaseBit) {
GBASE.with(|gb| gb.replace(ASTBase::empty())); // reset on each test
GBASE.with(|gb| gb.replace(ASTBase::empty())); // reset on each test
let (y, x) = (T0::def("y", 0), T0::def("x", T0::n())); let lt = x.lt(&y);
let xy:T1 = x.times(&y); let k = T1::new(k); let eq = xy.eq(&k);
(lt,eq) }
/// This is an example solver used by the tests and benchmarks.
/// It finds all pairs of type T0 that multiply to k as a T1.
/// dest is the solver that does the work.
/// Panics (via `assert_eq!`) if the solver's answers differ from `expected`.
pub fn find_factors<T0:BInt, T1:BInt, S:SubSolver>(dest:&mut S, k:usize, expected:Vec<(u64,u64)>) {
  let (lt, eq) = multiplication_bits::<T0,T1>(k);
  // command-line flags (after the second '--') control diagram output:
  let mut show_ast = false; // let mut show_res = false;
  for arg in env::args() { match arg.as_str() {
    "-a" => { show_ast = true }
    "-r" => { /*show_res = true*/ }
    _ => {} }}
  if show_ast {
    GBASE.with(|gb| { gb.borrow().show_named(lt.clone().n, "lt") });
    GBASE.with(|gb| { gb.borrow().show_named(eq.clone().n, "eq") }); }
  // the problem statement: both constraints must hold
  let top:BaseBit = lt & eq;
  assert!(top.n.is_ixn(), "top nid seems to be a literal. (TODO: handle these already solved cases)");
  let gb = GBASE.with(|gb| gb.replace(ASTBase::empty())); // swap out the thread-local one
  let src = gb.raw_ast();
  if show_ast { src.show_named(top.n, "ast"); }
  // --- now we have the ast, so solve ----
  dest.init_stats();
  let answer:DstNid = solve(dest, src, top.n);
  // if show_res { dest.show_named(answer.n, "result") }
  type Factors = (u64,u64);
  // unpack a solution register into a (y, x) pair: the low T0::n() bits
  // are x, the next T0::n() bits are y.
  let to_factors = |r:&Reg|->Factors {
    let t = r.as_usize();
    let x = t & ((1<<T0::n())-1);
    let y = t >> T0::n();
    (y as u64, x as u64) };
  let actual_regs:HashSet<Reg> = dest.get_all(answer.n, 2*T0::n() as usize);
  let actual:HashSet<Factors> = actual_regs.iter().map(to_factors).collect();
  // (was `expected.iter().map(|&(x,y)| (x, y)).collect()` -- a needless identity map)
  let expect:HashSet<Factors> = expected.into_iter().collect();
  assert_eq!(actual, expect);
  dest.print_stats(); }
/// nano test case for BDD: factor (*/2 3)=6 into two bitpairs. The only answer is 2,3.
#[test] pub fn test_nano_bdd() {
  use crate::bdd::BddBase;
  use crate::int::{X2, X4};
  let mut base = BddBase::new();
  find_factors::<X2, X4, BddBase>(&mut base, 6, vec![(2,3)]); }
/// nano test case for ANF: factor (*/2 3)=6 into two bitpairs. The only answer is 2,3.
#[test] pub fn test_nano_anf() {
  use crate::anf::ANFBase;
  use crate::int::{X2, X4};
  let mut base = ANFBase::new();
  find_factors::<X2, X4, ANFBase>(&mut base, 6, vec![(2,3)]); }
/// nano test case for swap solver: factor (*/2 3)=6 into two bitpairs. The only answer is 2,3.
#[test] pub fn test_nano_swap() {
  use crate::swap::SwapSolver;
  use crate::int::{X2, X4};
  let mut solver = SwapSolver::new();
  find_factors::<X2, X4, SwapSolver>(&mut solver, 6, vec![(2,3)]); }
/// tiny test case: factor (*/2 3 5 7)=210 into 2 nibbles. The only answer is 14,15.
#[test] pub fn test_tiny_bdd() {
  use crate::bdd::BddBase;
  use crate::int::{X4, X8};
  let mut base = BddBase::new();
  find_factors::<X4, X8, BddBase>(&mut base, 210, vec![(14,15)]); }
/// tiny test case: factor (*/2 3 5 7)=210 into 2 nibbles. The only answer is 14,15.
#[test] pub fn test_tiny_anf() {
  use crate::anf::ANFBase;
  use crate::int::{X4, X8};
  let mut base = ANFBase::new();
  find_factors::<X4, X8, ANFBase>(&mut base, 210, vec![(14,15)]); }
/// tiny test case: factor (*/2 3 5 7)=210 into 2 nibbles. The only answer is 14,15.
#[test] pub fn test_tiny_swap() {
  use crate::swap::SwapSolver;
  use crate::int::{X4, X8};
  let mut solver = SwapSolver::new();
  find_factors::<X4, X8, SwapSolver>(&mut solver, 210, vec![(14,15)]); }
/// multi: factor (*/2 3 5)=30 into 2 nibbles. There are three answers.
#[test] pub fn test_multi_bdd() {
  use crate::bdd::BddBase;
  use crate::int::{X4, X8};
  let mut base = BddBase::new();
  find_factors::<X4, X8, BddBase>(&mut base, 30, vec![(2,15), (3,10), (5,6)]); }
/// multi: factor (*/2 3 5)=30 into 2 nibbles. There are three answers.
#[test] pub fn test_multi_anf() {
  use crate::anf::ANFBase;
  use crate::int::{X4, X8};
  let mut base = ANFBase::new();
  find_factors::<X4, X8, ANFBase>(&mut base, 30, vec![(2,15), (3,10), (5,6)]); }
/// same as tiny test, but multiply 2 bytes to get 210. There are 8 distinct answers.
/// this was intended as a unit test but is *way* too slow.
/// (11m17.768s on rincewind (hex-core Intel i7-8700K @ 3.70 GHz with 16GB ram) as of 6/16/2020)
/// (that's with debug information and no optimizations enabled in rustc)
#[cfg(feature="slowtests")]
#[test] pub fn test_small_bdd() {
  // `crate::` prefix added for consistency with the other tests: the bare
  // `use {bdd::BddBase, ...}` form only resolves on the 2015 edition and
  // fails to compile on 2018+ when the slowtests feature is enabled.
  use crate::{bdd::BddBase, int::{X8,X16}};
  let expected = vec![(1,210), (2,105), ( 3,70), ( 5,42),
                      (6, 35), (7, 30), (10,21), (14,15)];
  find_factors::<X8, X16, BddBase>(&mut BddBase::new(), 210, expected); }
/// same test using the swap solver
/// `time cargo test --lib --features slowtests test_small_swap`
/// timing on rincewind is 5m13.901s as of 4/23/2021, so the swap
/// solver running on 1 core is more than 2x faster than old solver on 6!
#[cfg(feature="slowtests")]
#[test] pub fn test_small_swap() {
  // `crate::` prefix added for consistency with the other tests: the bare
  // `use {swap::SwapSolver, ...}` form only resolves on the 2015 edition and
  // fails to compile on 2018+ when the slowtests feature is enabled.
  use crate::{swap::SwapSolver, int::{X8,X16}};
  let expected = vec![(1,210), (2,105), ( 3,70), ( 5,42),
                      (6, 35), (7, 30), (10,21), (14,15)];
  find_factors::<X8, X16, SwapSolver>(&mut SwapSolver::new(), 210, expected); }