
Efficient, universal, standalone Jacobian backend #203

Open
wiseodd opened this issue Jul 4, 2024 · 2 comments

@wiseodd (Collaborator) commented Jul 4, 2024

  • functorch: memory blowup due to vmap (see the sketch below)
  • asdl/asdfghjkl: can't backprop through the Jacobians, so they can't be used for continuous Bayesian optimization (BO)
  • BackPACK: requires its inflexible extension mechanism

We need a Jacobian backend that is:

  • memory efficient (think about LLM applications!)
  • backpropable
  • applicable to general types of models (like functorch)
  • scalable to large numbers of outputs
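
For context, a minimal sketch (not the library's actual code) of the functorch-style full-Jacobian computation that causes the memory blowup; `model` (a classifier mapping R^d -> R^k) and the input batch `X` are assumptions for illustration:

import torch
from torch.func import functional_call, jacrev, vmap

params = dict(model.named_parameters())  # assumed model: R^d -> R^k

def f(p, x):
    # per-sample forward pass with the parameters passed explicitly
    return functional_call(model, p, (x.unsqueeze(0),)).squeeze(0)

# One (batch, k, *param.shape) tensor per parameter: the full per-sample Jacobian
# is materialized, i.e. O(batch * k * p) memory for p parameters.
J = vmap(jacrev(f), in_dims=(None, 0))(params, X)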
@wiseodd added the enhancement label Jul 4, 2024
@wiseodd self-assigned this Jul 4, 2024
@aleximmer (Owner) commented

In my experience, asdl works well for differentiable Jacobian computation. However, I would hope that curvlinops eventually offers such functionality as well, since asdl is not actively maintained and, specifically, the differentiability aspect is only available in branches that have not been merged. It is also no longer present in asdfghjkl, where differentiability broke for some reason I couldn't figure out.

@wiseodd added this to the 0.3 milestone Jul 8, 2024
@wiseodd (Collaborator, Author) commented Jul 10, 2024

Result of a discussion with @f-dangel:

Consider J(x) S J(x)^T: \R^k \to \R^k, i.e. the map v \mapsto J(x) S J(x)^T v for v \in \R^k.

Computation with vmap (for small k, or when differentiability is needed)

Then vmap(J(x) S J(x)^T)(I), with I the k x k identity, gives us var(f(x)). Note that this vmaps over the number of classes, unlike the current Jacobian implementation.

def model_fn_params_only(params_dict):
    # buffers_dict and the input x are assumed to be in scope; output f(x) has shape (k,)
    return torch.func.functional_call(self.model, (params_dict, buffers_dict), (x,))

# vjp returns the output f(x) together with a function v -> J(x)^T v
f_x, vjp_fn = torch.func.vjp(model_fn_params_only, self.params_dict)

def JSJT(v):
    (u,) = vjp_fn(v)    # u = J(x)^T v, a dict shaped like the parameters
    u = S @ u           # schematic: apply S as a matvec in (flattened) parameter space
    # jvp maps the parameter-space vector back to output space: J(x) S J(x)^T v
    _, out = torch.func.jvp(model_fn_params_only, (self.params_dict,), (u,))
    return out

I = torch.eye(f_x.shape[-1])          # k x k identity
func_var = torch.func.vmap(JSJT)(I)   # J(x) S J(x)^T, i.e. var(f(x))
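
A possible refinement, not from the discussion above: torch.func.linearize returns a reusable JVP function evaluated at self.params_dict, so the forward pass is not re-evaluated for every column of I (at the cost of extra memory for saved intermediates). Here `S_matvec` is a hypothetical parameter-space matvec with S; `vjp_fn` and `I` are as in the sketch above:

_, jvp_fn = torch.func.linearize(model_fn_params_only, self.params_dict)

def JSJT_lin(v):
    (u,) = vjp_fn(v)          # u = J(x)^T v
    u = S_matvec(u)           # hypothetical helper: u = S J(x)^T v
    return jvp_fn(u)          # J(x) S J(x)^T v

func_var = torch.stack([JSJT_lin(v).detach() for v in I])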

Computation with a for-loop (for large k, when differentiability is not needed)

  • Stack the per-column results into a (k, k) tensor.
  • This is more memory efficient than computing J(x) S J(x)^T explicitly
    from J(x) and S, since we only store a (k, 1) or (p, 1) tensor at a time.

func_var = torch.stack([JSJT(v).detach() for v in I])

If we only care about the diagonal of func_var:

func_var_diag = torch.stack([JSJT(v).detach()[i] for (i, v) in enumerate(I)])
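
A possible further saving for the diagonal only, not from the thread: since [J(x) S J(x)^T]_{ii} = u_i^T S u_i with u_i = J(x)^T e_i, the JVP can be skipped entirely. `S_matvec` is again a hypothetical parameter-space matvec with S; `vjp_fn` and `I` are as in the vmap sketch above:

def JSJT_diag_entry(v):
    (u,) = vjp_fn(v)      # u = J(x)^T e_i, a dict shaped like the parameters
    Su = S_matvec(u)      # hypothetical helper: S u
    return sum((u[k] * Su[k]).sum() for k in u)  # u^T S u

func_var_diag = torch.stack([JSJT_diag_entry(v).detach() for v in I])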

Sampling

Sampling f(x) can be done cheaply (see the Laurence paper).
For LLMs this might be better, because we don't really care about the explicit J(x) S J(x)^T,
but only about the resulting \int softmax(f(x)) N(f(x) | f_\theta(x), J(x) S J(x)^T) df(x).
I.e., the cost now scales with the number of samples instead of k.
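
A hedged sketch of the sampling route, not from the thread: with a factorization S = L L^T, a draw f_\theta(x) + J(x) L eps with eps ~ N(0, I_p) has covariance J(x) S J(x)^T, so only JVPs with sampled tangents are needed. `S_sqrt_matvec` (applying L in parameter space) is a hypothetical helper; `model_fn_params_only` and `f_x` are as in the vmap sketch above:

def sample_f(n_samples):
    samples = []
    for _ in range(n_samples):
        eps = {k: torch.randn_like(p) for k, p in self.params_dict.items()}  # eps ~ N(0, I_p)
        u = S_sqrt_matvec(eps)  # hypothetical: u = L eps, so J(x) u ~ N(0, J(x) S J(x)^T)
        _, delta = torch.func.jvp(model_fn_params_only, (self.params_dict,), (u,))
        samples.append(f_x + delta)
    return torch.stack(samples)

# Monte Carlo estimate of \int softmax(f(x)) N(f(x) | f_\theta(x), J(x) S J(x)^T) df(x)
probs = torch.softmax(sample_f(32), dim=-1).mean(dim=0)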

Further things

All of the above can be optimized further, depending on the form of S.
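
For example (an assumption about a likely form of S, not stated in the thread), if S is diagonal, the parameter-space matvec used above reduces to an elementwise product:

def S_matvec_diag(u, s_diag):
    # u and s_diag are dicts keyed like self.params_dict
    return {k: u[k] * s_diag[k] for k in u}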
