
Better Symbolic Algebra Library #47

Open
Tracked by #29
jafioti opened this issue Apr 9, 2024 · 28 comments

Comments

@jafioti
Owner

jafioti commented Apr 9, 2024

Currently luminal uses a small symbolic algebra library I wrote, in src/shape/symbolic.rs, to handle expressions.

The pro is that it's very simple and easy to reason about. The con is that it's very bad at simplifying complex expressions, so we get index expressions that look like this:

#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device float* input0 [[buffer(0)]], device float* input1 [[buffer(1)]], device float* input2 [[buffer(2)]], device float *out [[buffer(3)]], device uint& n_elements [[buffer(4)]], uint idx [[thread_position_in_grid]], device int& s [[buffer(5)]], device int& t [[buffer(6)]], device int& p [[buffer(7)]]) {
    if (idx < n_elements) {
        float intermediate0 = (((int)((((int)idx/(128*t))%s) < s) != 0) ? (float)input1[(((((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))%2)+(((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))/2)%64))+((((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))/128)%s)*64))+((((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))/(128*s))%32)*(64*s)))] + (float)input2[((((((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))%2)-1)+(((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))/2)%64))+((((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))/128)%s)*64))+((((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))/(128*s))%32)*(64*s)))] : 0.0);
        out[idx] = (float)(intermediate0 * (float)input0[((((int)idx%128)+((((int)idx/128)%(p+max((int)s, (int)0)))*128))+((((int)idx/(((128*(p+max((int)s, (int)0)))*s)*4))%8)*(128*(p+max((int)s, (int)0)))))]);
    }
}

We need to balance power and simplicity. I don't want to write a huge symbolic algebra library in the core, but it needs to be more capable than the current system. Ideally we can 80/20 this and get 80% of the simplifications with 20% of the code. Doesn't need to be perfect.

So the options are:

  • Use an existing library (rusymbols or savage)
  • Keep iterating on the current symbolic.rs
  • Spin out the symbolic library to a new crate, luminal_symbolic

I think the last one is the correct approach. Symbolic libraries can get very complex, but all that complexity can be effectively bottlenecked by the Expression type. And they can be unit tested quite well. Having a separate crate will allow for more complexity headroom in the simplification logic, while keeping the core of luminal clean, and allowing us to write extensive unit tests to ensure the library works.

Why not use another crate? I haven't found a crate that has all the needed ops (max, less than, mod, etc.) while having very few other ops (which is needed to keep functions like expr_to_metal_string simple).
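To make the op-set argument concrete, here is a minimal sketch of an expression type restricted to the ops named above plus basic arithmetic, with a renderer in the spirit of expr_to_metal_string. The `Expr` enum and `to_metal` function are hypothetical stand-ins, not luminal's actual Expression or expr_to_metal_string. The point is that every variant needs its own rendering arm, so each extra op a general-purpose crate brings along is more codegen to write and maintain.

```rust
/// Hypothetical minimal expression type: only the ops index math needs.
/// (Not luminal's actual `Expression`.)
#[allow(dead_code)]
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Num(i64),
    Var(char),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Div(Box<Expr>, Box<Expr>),
    Mod(Box<Expr>, Box<Expr>),
    Max(Box<Expr>, Box<Expr>),
    Lt(Box<Expr>, Box<Expr>),
}

/// Render an expression as a Metal (C-like) integer expression.
/// One match arm per op: every extra op adds another case here.
fn to_metal(e: &Expr) -> String {
    match e {
        Expr::Num(n) => n.to_string(),
        Expr::Var(c) => c.to_string(),
        Expr::Add(a, b) => format!("({}+{})", to_metal(a), to_metal(b)),
        Expr::Mul(a, b) => format!("({}*{})", to_metal(a), to_metal(b)),
        Expr::Div(a, b) => format!("({}/{})", to_metal(a), to_metal(b)),
        Expr::Mod(a, b) => format!("({}%{})", to_metal(a), to_metal(b)),
        Expr::Max(a, b) => format!("max((int){}, (int){})", to_metal(a), to_metal(b)),
        Expr::Lt(a, b) => format!("(int)({}<{})", to_metal(a), to_metal(b)),
    }
}

fn main() {
    // (i / 128) % s, the shape of many terms in the kernel above
    let e = Expr::Mod(
        Box::new(Expr::Div(Box::new(Expr::Var('i')), Box::new(Expr::Num(128)))),
        Box::new(Expr::Var('s')),
    );
    println!("{}", to_metal(&e)); // ((i/128)%s)
}
```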

@jafioti
Owner Author

jafioti commented Apr 9, 2024

@jafioti
Owner Author

jafioti commented Apr 13, 2024

I've fed the index expression above into sympy (likely one of the most mature symbolic algebra libraries out there), and it wasn't really able to minimize it:
[image]
This tells me we need to construct these equations better, rather than naively building them and hoping the symbolic algebra library can reduce them.

@jafioti
Owner Author

jafioti commented Apr 28, 2024

Dimension combination has been added, which slightly improves the generated equations. Adding ranges to the terms will allow expressions like i % 15, where i has a min of 0 and a max of 14 (i.e. i < 15), to be reduced to just i.
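A minimal sketch of the range idea, assuming each variable carries an inclusive [min, max] bound and all divisors are positive constants. The `Bounds` type, `Expr` enum, and `simplify` function are hypothetical, not luminal's API. The two rules shown are `x % c -> x` and `x / c -> 0` whenever 0 <= x < c.

```rust
/// Inclusive range of values a (sub-)expression can take (hypothetical helper type).
#[derive(Clone, Copy, Debug, PartialEq)]
struct Bounds {
    min: i64,
    max: i64,
}

/// Hypothetical tiny expression type: constants, bounded variables,
/// and division/modulo by a positive constant.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Num(i64),
    Var(&'static str, Bounds),
    Div(Box<Expr>, i64),
    Mod(Box<Expr>, i64),
}

/// Conservative bounds of an expression (assumes divisors are > 0).
fn bounds(e: &Expr) -> Bounds {
    match e {
        Expr::Num(n) => Bounds { min: *n, max: *n },
        Expr::Var(_, b) => *b,
        Expr::Div(x, c) => {
            // Truncating division by a positive constant is monotonic.
            let (b, c) = (bounds(x), *c);
            Bounds { min: b.min / c, max: b.max / c }
        }
        Expr::Mod(x, c) => {
            let (b, c) = (bounds(x), *c);
            if b.min >= 0 && b.max < c {
                b // already in [0, c): the modulo is a no-op
            } else if b.min >= 0 {
                Bounds { min: 0, max: c - 1 }
            } else {
                // `%` keeps the sign of the dividend in Rust and Metal.
                Bounds { min: -(c - 1), max: c - 1 }
            }
        }
    }
}

/// Bottom-up simplification using the ranges.
fn simplify(e: Expr) -> Expr {
    match e {
        Expr::Div(x, c) => {
            let x = simplify(*x);
            let b = bounds(&x);
            if b.min >= 0 && b.max < c { Expr::Num(0) } else { Expr::Div(Box::new(x), c) }
        }
        Expr::Mod(x, c) => {
            let x = simplify(*x);
            let b = bounds(&x);
            if b.min >= 0 && b.max < c { x } else { Expr::Mod(Box::new(x), c) }
        }
        other => other,
    }
}

fn main() {
    // i ranges over 0..=14, so i % 15 == i and i / 15 == 0.
    let i = Expr::Var("i", Bounds { min: 0, max: 14 });
    println!("{:?}", simplify(Expr::Mod(Box::new(i.clone()), 15)));
    println!("{:?}", simplify(Expr::Div(Box::new(i), 15)));
}
```

With a max of 15 instead of 14 the modulo has to be kept, since 15 % 15 is 0, not 15.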

@NewBornRustacean
Contributor

Hello @jafioti!
I've been looking into this issue. Is there any progress?

As I understand it, you're planning to build a new, distilled symbolic computation crate (did I get that right?).
If so, would the desired output be a reduced (simpler, more evaluated) form of the input equation?

@jafioti
Owner Author

jafioti commented May 7, 2024

Yeah that's correct. The goal is to reduce the expressions to a minimum mathematically equivalent form so it's most efficient to compute many times over.

I've been working on this for the past few days and hope to push soon.

@jafioti
Owner Author

jafioti commented May 7, 2024

I'm using cas-rs for simplification. The remaining issue is that it doesn't support Mod, which we need; I'm looking at implementing that now.

@jafioti
Owner Author

jafioti commented May 7, 2024

This is being worked on in the cas branch

@genos

genos commented May 7, 2024

I've had good luck with egg before, especially if your needs are a little more specialized/bespoke than, say, cas-rs.

@jafioti
Owner Author

jafioti commented May 7, 2024

@genos Does egg do symbolic algebra reductions? I didn't see much example code / documentation

@genos

genos commented May 7, 2024

@jafioti You write the simplification rules yourself, like in the docs.rs example. The rest of those docs were helpful to me in a previous project, though they required some digging. The website is a good resource as well.

@YichengDWu

Where is this complicated indexing coming from?

@jafioti
Copy link
Owner Author

jafioti commented May 8, 2024

The shape tracker can do zero cost movements like transpose or slicing, and then it generates an indexing expression to convert logical indexes to physical indexes.
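A simplified sketch of such a logical-to-physical mapping, assuming a view is just a logical shape plus one physical stride per dimension over a contiguous buffer. The `View` type and its methods are hypothetical, not luminal's ShapeTracker. Each logical dimension contributes a `((idx / inner) % dim) * stride` term, the same pattern that repeats throughout the kernel at the top of this issue.

```rust
/// Hypothetical strided view over a contiguous buffer (not luminal's ShapeTracker).
#[derive(Clone, Debug)]
struct View {
    shape: Vec<usize>,   // logical shape
    strides: Vec<usize>, // physical stride for each logical dimension
}

impl View {
    /// Contiguous row-major view of the given shape.
    fn contiguous(shape: &[usize]) -> Self {
        let mut strides = vec![1; shape.len()];
        for d in (0..shape.len().saturating_sub(1)).rev() {
            strides[d] = strides[d + 1] * shape[d + 1];
        }
        View { shape: shape.to_vec(), strides }
    }

    /// Zero-cost transpose: permute shape and strides, no data is moved.
    fn permute(&self, axes: &[usize]) -> Self {
        View {
            shape: axes.iter().map(|&a| self.shape[a]).collect(),
            strides: axes.iter().map(|&a| self.strides[a]).collect(),
        }
    }

    /// Map a flat logical index to a physical buffer offset.
    fn physical_index(&self, idx: usize) -> usize {
        let mut inner = 1; // product of the logical dims to the right of d
        let mut offset = 0;
        for d in (0..self.shape.len()).rev() {
            offset += (idx / inner) % self.shape[d] * self.strides[d];
            inner *= self.shape[d];
        }
        offset
    }

    /// The same mapping as an unsimplified symbolic string, one term per dim.
    fn index_expr(&self) -> String {
        let mut inner = 1;
        let mut terms = Vec::new();
        for d in (0..self.shape.len()).rev() {
            terms.push(format!("(((idx/{})%{})*{})", inner, self.shape[d], self.strides[d]));
            inner *= self.shape[d];
        }
        terms.reverse();
        terms.join("+")
    }
}

fn main() {
    // Transpose a 2x3 buffer into a 3x2 view.
    let t = View::contiguous(&[2, 3]).permute(&[1, 0]);
    let offsets: Vec<usize> = (0..6).map(|i| t.physical_index(i)).collect();
    println!("{:?}", offsets); // [0, 3, 1, 4, 2, 5]
    println!("{}", t.index_expr()); // (((idx/2)%3)*1)+(((idx/1)%2)*3)
}
```

The `/1` and `*1` factors are the trivial cases a simplifier should strip. In trackers of this kind, when movement ops cannot be merged into one set of strides, these term sums get fed into each other, which is how deeply nested expressions arise.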

@YichengDWu

You can mostly get rid of them with nested views. See what I did here: tinygrad/tinygrad#3988

@jafioti
Owner Author

jafioti commented May 8, 2024

@genos egg is amazing! I spent yesterday and today looking at it, and switched over to it from cas-rs since it's much more flexible and can easily reach the same level of reduction cas-rs did, but in a more robust way. Thanks for the suggestion!

There are still a few bugs with it (conv2d still fails for some reason), and compile times are longer since it's not that fast, but it's definitely the right approach in this case.

@jafioti
Owner Author

jafioti commented May 9, 2024

You can mostly get rid of them with nested views. See what I did here: tinygrad/tinygrad#3988

@YichengDWu
This looks very interesting, did you base it off a paper or something? I'd like to read more

@YichengDWu

CUTLASS/CuTe

@genos

genos commented May 10, 2024

compile times are longer since it's not that fast

@jafioti two things come to mind concerning optimizing egg usage from previous experience:

  1. If you don't mind the extra dep (though it's transitively required by egg), building the SExp directly rather than using .parse() can offer a speed up.
  2. Carefully monitoring the rules you create and their usage, and trimming down as much as possible may help. For instance, full commutativity and associativity can blow up the search space a lot, though I admit it's hard to live without them.
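A standard counting argument gives a rough feel for point 2 (an illustration, not a measurement of egg). Under commutativity and associativity, every parenthesization and operand ordering of a sum $a_1 + \dots + a_n$ is equivalent, and the number of such binary expression trees is

$$
T(n) = n!\,C_{n-1} = \frac{(2n-2)!}{(n-1)!}, \qquad T(4) = 120, \qquad T(8) = 17{,}297{,}280
$$

where $C_{n-1}$ is the Catalan number. An e-graph shares structure rather than storing each tree, but after full saturation it still holds an e-class for the sum of every subset of two or more operands, i.e. $2^n - n - 1$ of them.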

@jafioti
Owner Author

jafioti commented May 11, 2024

@genos Are there any examples of constructing an SExp and producing a RecExpr? Is there a way to build a RecExpr directly?

@genos

genos commented May 12, 2024

@jafioti not that I recall; I went through the code (linked from the docs.rs for egg) by hand and found what .parse does.

@jafioti
Owner Author

jafioti commented May 27, 2024

Egg is great; I think we'll stick with it for the expressions. It seems much better than other CAS libraries. The only thing remaining before I close this issue is more efficient conversion from luminal::Expression to and from egg's RecExpr.

@genos

genos commented May 27, 2024

Glad my suggestion was helpful! If you want more speed, I still recommend building SExprs by hand. If I get some free time before you manage it, I’ll see if I can be of use.

@jafioti
Owner Author

jafioti commented May 27, 2024

@genos I've added the code to build SExprs, .parse is no longer used. The next step is to build egg RecExpr directly, to avoid going through SExpr altogether. The other thing that needs to be done is work out a more efficient way to go from RecExpr -> luminal::Expression, because right now that function creates a ton of vectors which are directly consumed.

@genos

genos commented May 28, 2024

@jafioti it turns out my memory was a bit hazy; we used egg at first to get as much simplification as we could, but eventually chucked it and instead simplified everything by hand because it was faster. We may not have known what simplifications were needed, however, without first turning to egg.

@genos

genos commented May 30, 2024

@jafioti it occurs to me that with the way expressions are represented in luminal (with vectors), you've got an efficient way to compare expressions by size: just look at their .len(). As such, simplifying by hand (though perhaps tedious to write and test) will be quite speedy. A recursive function simplify(e: Expression, fuel: usize) -> Expression, similar to what we wrote in the above PR, with fuel ticking down towards zero and greedily applying the first entry in a list of possible simplifications that shrinks the size of your expression, should do the trick.
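A minimal sketch of that greedy approach, under stated assumptions: a hypothetical boxed `Expr` tree instead of luminal's RPN vector, node count as the size measure, and three hand-written rules. Unlike the `fuel: usize` signature above, this version threads fuel as `&mut usize` so the budget is shared across the whole recursion. It illustrates the technique described, not code from the referenced PR.

```rust
/// Hypothetical boxed expression tree (luminal itself stores RPN vectors).
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Num(i64),
    Var(&'static str),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

impl Expr {
    /// Node count: the analogue of `.len()` on an RPN vector.
    fn size(&self) -> usize {
        match self {
            Expr::Num(_) | Expr::Var(_) => 1,
            Expr::Add(a, b) | Expr::Mul(a, b) => 1 + a.size() + b.size(),
        }
    }
}

/// A rewrite rule tried at the root of an expression.
type Rule = fn(&Expr) -> Option<Expr>;

/// x + 0 -> x and 0 + x -> x
fn add_zero(e: &Expr) -> Option<Expr> {
    match e {
        Expr::Add(a, b) if **b == Expr::Num(0) => Some((**a).clone()),
        Expr::Add(a, b) if **a == Expr::Num(0) => Some((**b).clone()),
        _ => None,
    }
}

/// x * 1 -> x and 1 * x -> x
fn mul_one(e: &Expr) -> Option<Expr> {
    match e {
        Expr::Mul(a, b) if **b == Expr::Num(1) => Some((**a).clone()),
        Expr::Mul(a, b) if **a == Expr::Num(1) => Some((**b).clone()),
        _ => None,
    }
}

/// Fold an operation whose operands are both constants.
fn const_fold(e: &Expr) -> Option<Expr> {
    match e {
        Expr::Add(a, b) => match (&**a, &**b) {
            (Expr::Num(x), Expr::Num(y)) => Some(Expr::Num(x + y)),
            _ => None,
        },
        Expr::Mul(a, b) => match (&**a, &**b) {
            (Expr::Num(x), Expr::Num(y)) => Some(Expr::Num(x * y)),
            _ => None,
        },
        _ => None,
    }
}

/// Rules in priority order: the first one that shrinks the expression wins.
const RULES: &[Rule] = &[add_zero, mul_one, const_fold];

/// Simplify children first, then repeatedly apply the first rule whose
/// result is strictly smaller. Each rewrite spends one unit of fuel.
fn simplify(e: Expr, fuel: &mut usize) -> Expr {
    let mut e = match e {
        Expr::Add(a, b) => Expr::Add(Box::new(simplify(*a, fuel)), Box::new(simplify(*b, fuel))),
        Expr::Mul(a, b) => Expr::Mul(Box::new(simplify(*a, fuel)), Box::new(simplify(*b, fuel))),
        leaf => leaf,
    };
    while *fuel > 0 {
        let size = e.size();
        match RULES.iter().find_map(|rule| rule(&e).filter(|cand| cand.size() < size)) {
            Some(smaller) => {
                e = smaller;
                *fuel -= 1;
            }
            None => break,
        }
    }
    e
}

fn main() {
    // (x * 1) + (2 * 3)  ==>  x + 6, using two rewrites
    let e = Expr::Add(
        Box::new(Expr::Mul(Box::new(Expr::Var("x")), Box::new(Expr::Num(1)))),
        Box::new(Expr::Mul(Box::new(Expr::Num(2)), Box::new(Expr::Num(3)))),
    );
    let mut fuel = 10;
    println!("{:?} (fuel left: {})", simplify(e, &mut fuel), fuel);
}
```

Because every accepted rewrite strictly shrinks the expression, the loop terminates even without fuel; the fuel only caps how much work a single call can do.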

@jafioti
Owner Author

jafioti commented May 31, 2024

@genos Yeah, that's somewhat similar to what we had, but I really liked the idea of an e-graph based solution with simple rewrite rules that can compose together to do complex rewrites. It isn't obvious to me that greedily choosing the next rewrite would wind up with simplifications nearly as good as the ones e-graphs yield.

Or do you mean recursively searching down a tree of every possible rewrite, so that the shortest equation bubbles up?
Another issue we have is that it's harder to do rewrites in reverse Polish notation (though definitely possible).

@genos

genos commented May 31, 2024

Agreed, e-graphs seem much more likely to find good simplifications than simplistic greedy search. And I thought I was the only one having trouble working with the RPN setup! 😆

@jafioti
Owner Author

jafioti commented May 31, 2024

I ended up using RPN just because it was super simple to store: I want ShapeTracker to be stored on the stack, so Expression had to be stored on the stack too (no recursive types). I'm thinking of switching it to be more like the postfix representation egg uses, where terms in the RPN can reference other terms so that common subexpressions are shared, which should greatly speed up the translation from egg to luminal.
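A minimal sketch of that layout, assuming a hypothetical `Node` enum whose children are indices of earlier entries in the same vector (similar in spirit to egg's RecExpr, but not its API). A hash map deduplicates identical nodes as they are added, so a repeated subexpression such as `idx/(128*t)` is stored once while the expression stays a flat, non-recursive list. `Vec` and `HashMap` are used for brevity; keeping it on the stack as described above would need a fixed-capacity array instead.

```rust
use std::collections::HashMap;

/// A node whose children are indices of earlier nodes in the same vector.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Node {
    Num(i64),
    Var(char),
    Add(usize, usize),
    Mul(usize, usize),
    Div(usize, usize),
    Mod(usize, usize),
}

/// Flat expression store with hash-consing (common subexpressions are shared).
#[derive(Default)]
struct ExprArena {
    nodes: Vec<Node>,
    memo: HashMap<Node, usize>,
}

impl ExprArena {
    /// Add a node, returning the index of an identical existing node if there is one.
    fn add(&mut self, node: Node) -> usize {
        if let Some(&id) = self.memo.get(&node) {
            return id;
        }
        self.nodes.push(node);
        let id = self.nodes.len() - 1;
        self.memo.insert(node, id);
        id
    }

    /// Evaluate node `root`. Children always precede parents, so one forward
    /// pass computes every node exactly once, shared nodes included.
    fn eval(&self, root: usize, vars: &HashMap<char, i64>) -> i64 {
        let mut vals: Vec<i64> = Vec::with_capacity(self.nodes.len());
        for node in &self.nodes {
            let v = match *node {
                Node::Num(n) => n,
                Node::Var(c) => vars[&c],
                Node::Add(a, b) => vals[a] + vals[b],
                Node::Mul(a, b) => vals[a] * vals[b],
                Node::Div(a, b) => vals[a] / vals[b],
                Node::Mod(a, b) => vals[a] % vals[b],
            };
            vals.push(v);
        }
        vals[root]
    }
}

/// Hypothetical helper: builds idx/(128*t) and returns its node index.
fn idx_div(ar: &mut ExprArena) -> usize {
    let i = ar.add(Node::Var('i'));
    let c = ar.add(Node::Num(128));
    let t = ar.add(Node::Var('t'));
    let m = ar.add(Node::Mul(c, t));
    ar.add(Node::Div(i, m))
}

fn main() {
    let mut ar = ExprArena::default();
    // ((idx/(128*t)) % s) + ((idx/(128*t)) * 128)
    let d1 = idx_div(&mut ar);
    let s = ar.add(Node::Var('s'));
    let lhs = ar.add(Node::Mod(d1, s));
    let d2 = idx_div(&mut ar); // reuses the nodes built for d1
    let c = ar.add(Node::Num(128));
    let rhs = ar.add(Node::Mul(d2, c));
    let root = ar.add(Node::Add(lhs, rhs));
    let vars = HashMap::from([('i', 300), ('t', 2), ('s', 3)]);
    println!("{} nodes, value {}", ar.nodes.len(), ar.eval(root, &vars)); // 9 nodes, value 129
}
```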

@asukaminato0721

In egg, it's possible to define your own cost function. By default it's AstSize, but it can be any other function.

For example, if we want to eliminate some ops, we can set their cost to a huge number.

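```rust
// `Math` is the language defined with `define_language!` in egg's math test;
// it is not exported by the egg crate itself.
use egg::{CostFunction, Id, Language};

/// Cost function that makes `Diff` and `Integral` nodes very expensive, so
/// extraction prefers equivalent expressions that avoid them.
pub struct MathCostFn;

impl CostFunction<Math> for MathCostFn {
    type Cost = usize;

    fn cost<C>(&mut self, enode: &Math, mut costs: C) -> Self::Cost
    where
        C: FnMut(Id) -> Self::Cost,
    {
        // Per-node cost: penalize the ops we want to eliminate.
        let op_cost = match enode {
            Math::Diff(..) => 1000,
            Math::Integral(..) => 1000,
            _ => 1,
        };
        // Total = this node's cost plus the best cost of each child e-class.
        enode.fold(op_cost, |sum, i| sum + costs(i))
    }
}
```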

(Taken from egg's tests.)


5 participants