Automatic operation fusion #169

Open
Bihaqo opened this issue Jan 9, 2019 · 1 comment
Bihaqo commented Jan 9, 2019

For example, instead of writing a separate function project_matmul, we could implement something like project_matmul = fuse(lambda a, b: project(a, b), lambda b, c: matmul(b, c)).

To do that, let's express each operation as a sequence of recurrent steps of the form

def recurrent(a, b):
  # Carry a running result through the cores; 'rule' is operation-specific.
  res = 1.0
  for a_core, b_core in zip(a.tt_cores, b.tt_cores):
    res = einsum('rule', a_core, b_core, res)
  return res

and as a sequence of independent steps of the form

def independent(a, b):
  # Each result core is built from one pair of input cores, with no carried state.
  res_cores = []
  for a_core, b_core in zip(a.tt_cores, b.tt_cores):
    res_cores.append(einsum('rule', a_core, b_core))
  return res_cores

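For concreteness, here is the recurrent pattern with the placeholder rule filled in for the dot product of two TT-tensors. This is a minimal NumPy sketch, assuming cores of shape (r_left, mode_dim, r_right); it is not the actual t3f implementation:

import numpy as np

def tt_flat_inner(a_cores, b_cores):
  # Recurrent pattern for <a, b>: carry an (r_a x r_b) matrix through the cores.
  res = np.ones((1, 1))  # boundary TT-ranks are 1
  for a_core, b_core in zip(a_cores, b_cores):
    # Contract the running matrix with the next pair of cores.
    res = np.einsum('pq,pis,qit->st', res, a_core, b_core)
  return res[0, 0]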
Then we can automatically concatenate the einsums of the individual operations into a single big einsum (per core) and, by using opt_einsum, guarantee that the resulting einsum will be fast.
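As an illustration of the fused form, here is a sketch of x^t A y computed with a single einsum per core, letting opt_einsum pick the contraction order. The core layouts ((r_l, i, r_r) for TT-vectors, (r_l, i, j, r_r) for TT-matrices) and the function name are assumptions of the sketch, not existing t3f code:

import numpy as np
import opt_einsum as oe

def fused_quadratic_form(x_cores, A_cores, y_cores):
  # One fused einsum per core instead of a matmul pass followed by a
  # flat_inner pass; oe.contract chooses the cheapest contraction path.
  res = np.ones((1, 1, 1))  # all boundary TT-ranks are 1
  for x_core, A_core, y_core in zip(x_cores, A_cores, y_cores):
    res = oe.contract('pqr,pis,qiju,rjv->suv', res, x_core, A_core, y_core)
  return res[0, 0, 0]

# Example: three cores, mode size 4, TT-rank 3.
x = [np.random.rand(1, 4, 3), np.random.rand(3, 4, 3), np.random.rand(3, 4, 1)]
y = [np.random.rand(1, 4, 3), np.random.rand(3, 4, 3), np.random.rand(3, 4, 1)]
A = [np.random.rand(1, 4, 4, 3), np.random.rand(3, 4, 4, 3), np.random.rand(3, 4, 4, 1)]
print(fused_quadratic_form(x, A, y))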

Off the top of my head, we can support any combination of

  1. matmul(A, B)
  2. add a + b
  3. elementwise product a * b (a per-core sketch follows this list)
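As an example of the independent pattern, here is a per-core sketch of the elementwise product, again assuming the (r_left, mode_dim, r_right) core layout rather than the existing t3f code:

import numpy as np

def tt_elementwise_cores(a_cores, b_cores):
  # Independent pattern for a * b: each result core is built from one pair
  # of input cores via an outer product over the rank dimensions.
  res_cores = []
  for a_core, b_core in zip(a_cores, b_cores):
    core = np.einsum('aib,cid->acibd', a_core, b_core)
    r1, r2, i, r3, r4 = core.shape
    res_cores.append(core.reshape(r1 * r2, i, r3 * r4))
  return res_cores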

Additionally, as the last operation of the combination, we can support

  1. dot product a^t b
  2. gram matrix G_ij = ai^t bj
  3. projection on the tangent space P_x y
  4. trace (sketched after this list)
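For instance, the trace fits the recurrent pattern with a rank vector as the running state; a minimal sketch, assuming square mode sizes and (r_l, i, j, r_r) TT-matrix cores:

import numpy as np

def tt_matrix_trace(A_cores):
  # Final-operation pattern: each step takes the diagonal of the (i, j)
  # block of the core and carries a vector over the TT-rank forward.
  res = np.ones(1)
  for A_core in A_cores:
    res = np.einsum('p,piiq->q', res, A_core)
  return res[0]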

By combining these ops we can, for example, automatically get fast versions of

  1. x^t A y (already implemented as a separate fast operation)
  2. ||A B||
  3. A B x
  4. P_x A y (already implemented)
  5. ||a * b||
  6. P_x A B y
  7. ||A + B||
  8. P_x (a * b)
  9. x^t A B y
  10. ||(Ax) * (By)||
  11. trace(A^T B A)

Does anyone need this?

Bihaqo commented Jan 27, 2019

A potential way to design the API:

with t3f.Fuse() as f:
  Ax = t3f.matmul(A, x)
  xAx = t3f.flat_inner(x, Ax)
  fast_xAx = f.optimize(xAx)
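One way f.optimize could work under the hood is to compile the recorded operations into a precompiled per-core contraction, for example via opt_einsum's contract_expression. The snippet below only sketches that building block with illustrative shapes; it is not an implementation of Fuse:

import numpy as np
import opt_einsum as oe

# Precompile the fused per-core contraction for x^t A x once, then reuse it.
r_x, r_A, n = 3, 4, 10
fused_step = oe.contract_expression(
    'pqr,pis,qiju,rjv->suv',
    (r_x, r_A, r_x),   # running result
    (r_x, n, r_x),     # x core
    (r_A, n, n, r_A),  # A core
    (r_x, n, r_x))     # x core again

res = fused_step(np.random.rand(r_x, r_A, r_x),
                 np.random.rand(r_x, n, r_x),
                 np.random.rand(r_A, n, n, r_A),
                 np.random.rand(r_x, n, r_x))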
