Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Jun 22, 2024
1 parent 35b3883 commit 4db3120
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 2 deletions.
2 changes: 1 addition & 1 deletion crates/luminal_cuda/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl CudaFloat for u8 {

fn expr_to_cuda_string(expr: &BigExpression) -> String {
let mut symbols = vec![];
for term in expr.terms.clone() {
for term in expr.clone().simplify().terms {
let new_symbol = match term {
Term::Num(n) => n.to_string(),
Term::Var(c) => {
Expand Down
81 changes: 81 additions & 0 deletions crates/luminal_cuda/src/tests/fp32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -731,3 +731,84 @@ fn test_movement() {

assert_exact(&c.data(), &d_c.as_vec());
}

#[test]
fn test_conv2d() {
let mut cx = Graph::new();

const CH_IN: usize = 5;
const CH_OUT: usize = 2;
const KERNELX: usize = 2;
const KERNELY: usize = 2;
const STRIDEX: usize = KERNELX;
const STRIDEY: usize = KERNELY;
const DILATIONX: usize = 0;
const DILATIONY: usize = 0;
const DIMX_IN: usize = 16;
const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1;
const DIMY_IN: usize = 9;
const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1;

let inp1 = cx.tensor::<R3<CH_IN, DIMX_IN, DIMY_IN>>().set(vec![
8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8., 8.,
5., 7., 6., 2., 8., 9., 5., 0., 3., 1., 1., 8., 4., 1., 1., 5., 6., 9., 3., 2., 9., 4., 7.,
1., 0., 7., 7., 4., 9., 5., 0., 4., 7., 4., 7., 8., 8., 4., 8., 4., 7., 9., 3., 7., 9., 5.,
8., 5., 9., 0., 9., 5., 6., 8., 9., 5., 4., 1., 9., 7., 2., 2., 7., 9., 3., 1., 2., 8., 4.,
0., 8., 0., 5., 6., 7., 7., 4., 3., 4., 6., 8., 3., 7., 8., 8., 7., 1., 5., 1., 8., 0., 1.,
1., 7., 3., 2., 1., 0., 4., 5., 4., 3., 2., 5., 4., 2., 4., 1., 9., 4., 1., 9., 7., 7., 1.,
2., 6., 3., 4., 1., 1., 6., 6., 8., 2., 7., 7., 9., 0., 9., 0., 1., 4., 2., 4., 9., 6., 8.,
6., 1., 6., 3., 8., 3., 4., 5., 0., 2., 1., 8., 2., 2., 8., 7., 0., 7., 7., 3., 4., 5., 0.,
7., 2., 1., 1., 4., 2., 9., 9., 6., 1., 5., 4., 6., 9., 5., 4., 1., 9., 1., 5., 5., 5., 8.,
8., 0., 1., 3., 0., 8., 8., 5., 1., 6., 1., 5., 6., 4., 4., 4., 0., 1., 1., 5., 1., 7., 2.,
3., 5., 5., 4., 9., 1., 3., 7., 6., 7., 1., 5., 3., 8., 6., 6., 6., 7., 3., 2., 2., 8., 1.,
3., 0., 2., 7., 6., 5., 7., 5., 7., 8., 1., 2., 2., 5., 0., 2., 9., 1., 5., 3., 8., 7., 9.,
7., 2., 8., 8., 8., 6., 3., 2., 7., 7., 0., 3., 7., 8., 3., 7., 2., 3., 2., 7., 5., 5., 6.,
0., 9., 0., 9., 9., 1., 8., 7., 9., 6., 8., 7., 5., 4., 9., 5., 6., 3., 2., 8., 3., 0., 6.,
3., 8., 3., 1., 8., 7., 2., 0., 7., 7., 7., 7., 8., 0., 4., 9., 8., 2., 0., 4., 4., 3., 5.,
5., 3., 0., 3., 6., 3., 1., 2., 9., 9., 6., 8., 1., 2., 6., 8., 6., 0., 0., 2., 8., 8., 5.,
0., 5., 9., 0., 8., 1., 1., 3., 5., 9., 3., 5., 8., 6., 3., 2., 9., 4., 8., 3., 9., 5., 2.,
9., 0., 1., 6., 8., 0., 3., 0., 1., 2., 1., 0., 1., 4., 1., 1., 0., 6., 9., 2., 7., 2., 6.,
0., 4., 8., 2., 6., 7., 2., 2., 7., 4., 5., 8., 1., 4., 7., 5., 9., 7., 2., 5., 9., 1., 6.,
1., 7., 9., 5., 6., 9., 3., 5., 1., 6., 1., 3., 3., 9., 3., 9., 0., 1., 8., 1., 9., 8., 5.,
3., 4., 4., 1., 5., 5., 4., 4., 5., 8., 7., 1., 1., 7., 3., 9., 0., 1., 3., 4., 8., 4., 0.,
5., 6., 2., 0., 7., 8., 2., 6., 2., 9., 6., 2., 0., 3., 7., 5., 7., 1., 8., 5., 5., 9., 1.,
0., 3., 5., 7., 5., 3., 2., 8., 6., 3., 0., 5., 8., 5., 7., 8., 8., 2., 9., 0., 1., 8., 6.,
0., 3., 2., 5., 2., 9., 8., 9., 6., 2., 0., 3., 2., 5., 9., 1., 3., 6., 5., 2., 8., 2., 2.,
1., 8., 6., 4., 1., 6., 0., 7., 3., 0., 9., 6., 5., 5., 5., 2., 4., 2., 8., 3., 0., 6., 3.,
8., 8., 4., 9., 4., 7., 0., 3., 5., 1., 4., 6., 0., 0., 5., 9., 7., 8., 6., 7., 0., 6., 7.,
0., 5., 8., 8., 6., 4., 6., 0., 2., 3., 2., 8., 7., 5., 9., 6., 6., 2., 0., 4., 4., 4., 4.,
2., 7., 5., 3., 2., 6., 3., 7., 0., 7., 2., 5., 1., 4., 4., 5., 1., 6., 7., 5., 7., 0., 7.,
8., 4., 7., 3., 9., 1., 7., 5., 6., 1., 0., 2., 0., 0., 5., 5., 8., 8., 7., 3., 7., 2., 9.,
3., 8., 4., 5., 3., 8., 5., 2., 0., 2., 0., 5., 9., 0., 3., 8., 0., 4., 1., 8., 4., 8., 9.,
1., 1., 4., 5., 0., 2., 0., 9., 4., 2., 3., 9., 0., 7., 3., 1., 5., 9., 1., 6., 5., 4., 2.,
1., 2., 1., 1., 4., 7., 2.,
]);

let model = luminal_nn::Conv2D::<CH_IN, CH_OUT, KERNELX, KERNELY>::initialize(&mut cx);
model.weight.set(vec![
0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300,
0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200, 0.1500,
0.0700, -0.0800, 0.1700, 0.1000, -0.0700, 0.1600, -0.1600, -0.1900, -0.0500, -0.2100,
0.0100, -0.2000, 0.2100, -0.0400, -0.1400, 0.1500, 0.0500, -0.1700, 0.1400,
]);

let mut out1 = model
.forward::<DIMX_IN, DIMY_IN, DIMX_OUT, DIMY_OUT>(inp1)
.retrieve();

cx.compile(<(GenericCompiler, CudaCompiler<f32>)>::default(), &mut out1);
cx.execute();

assert_close(
&out1.data(),
&[
3.9600, -0.3300, -1.7800, 4.0400, 1.5300, 0.2900, 2.8700, 3.0000, 0.9600, -1.8700,
4.5900, 3.9700, 1.2800, 1.1800, 3.7800, 2.8500, 0.5500, 0.5600, 3.9800, 1.3200,
-0.7100, -0.6500, 4.3900, 0.4000, 1.0300, 0.9800, 3.1200, 2.7400, 2.5100, 0.1200,
1.8500, 2.0000, -0.7900, 1.0700, -0.3900, -0.8100, -2.5100, -2.9700, 0.2100, 1.8400,
-0.7700, -0.3900, 1.2200, 0.1900, 4.1700, -4.3600, -1.8600, 0.4800, -2.4400, 2.6300,
1.5000, -1.9700, 1.2800, -2.8200, -2.3200, 0.2200, -0.3800, 2.1800, -0.8200, -1.5700,
1.2000, -3.4200, -1.6700, 0.9000,
],
);
}
59 changes: 59 additions & 0 deletions src/generic_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,65 @@ impl Compiler for ArithmeticElimination {
graph.remove_node(out);
graph.safe_remove_node(intermediate, 0);
}

// exp2(log2(x))
let inp = node();
let intermediate = unary::<Exp2>(inp.clone());
let out = unary::<Log2>(intermediate.clone());
let mut s = out.clone().search(graph);
while s.next_match() {
let (inp, intermediate, out) = (s.get(&inp), s.get(&intermediate), s.get(&out));
if graph.no_delete.contains(&intermediate) {
continue;
}
// Carry over outgoing edges
let input_shape = graph
.graph
.edges_connecting(inp, intermediate)
.find_map(|e| e.weight().as_data())
.unwrap()
.2;
if input_shape.is_reshaped() {
// If any output shape is non-contiguous, we need to keep the op for it's contiguous functionality TODO: replace with explicit contiguous op here
if graph
.graph
.edges_connecting(inp, intermediate)
.filter_map(|e| e.weight().as_data())
.any(|(_, _, sh)| sh.is_reshaped())
|| graph
.graph
.edges_connecting(intermediate, out)
.filter_map(|e| e.weight().as_data())
.any(|(_, _, sh)| sh.is_reshaped())
{
continue;
}
for (weight, target) in graph
.graph
.edges_directed(intermediate, petgraph::Direction::Outgoing)
.map(|e| (*e.weight(), e.target()))
.collect::<Vec<_>>()
{
if let Some(weight) = weight.as_data() {
graph.graph.add_edge(
inp,
target,
Dependency::Data {
input_order: weight.0,
output_order: weight.1,
shape: input_shape,
},
);
}
}
} else {
move_outgoing_edge(out, inp, &mut graph.graph);
}
remap(intermediate, inp, &mut ids, graph);
remap(out, inp, &mut ids, graph);
graph.remove_node(out);
graph.safe_remove_node(intermediate, 0);
}
}
}

Expand Down
8 changes: 7 additions & 1 deletion src/shape/symbolic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,12 @@ fn make_rules() -> Vec<Rewrite> {
// Other
rewrite!("distribute"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
rewrite!("factor" ; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"),
rewrite!("group-terms"; "(+ ?a ?a)" => "(* 2 ?a)"),
rewrite!("distribute-mod"; "(* (% ?b ?c) ?a)" => "(% (* ?b ?a) (* ?c ?a))"),
rewrite!("explicit-truncate"; "(* (/ ?a ?b) ?b)" => "(- ?a (% ?a ?b))"),
rewrite!("mul-mod"; "(% (* ?a ?b) ?b)" => "0"),
// rewrite!("mul-distribute"; "(* ?a (% (/ ?b ?c) ?d))" => "(% (/ ?b (* ?c ?a)) (* ?d ?a))"),
// rewrite!("div-mod-mul"; "(% (/ ?a ?b) ?c)" => "(% ?a (* ?b ?c))"),
]
}

Expand All @@ -941,7 +947,7 @@ fn egg_simplify<S: ExpressionStorage>(expr: GenericExpression<S>) -> GenericExpr
// Simplify
let runner = Runner::default()
.with_expr(&expr)
.with_iter_limit(10)
.with_iter_limit(100)
.run(&make_rules());
let extractor = Extractor::new(&runner.egraph, AstSize);
let (_, best) = extractor.find_best(runner.roots[0]);
Expand Down

0 comments on commit 4db3120

Please sign in to comment.