irenezhang30/fax

 
 

fax: fixed-point jax

Implicit and competitive differentiation in JAX.

Our "competitive differentiation" approach uses Competitive Gradient Descent (CGD) to solve the equality-constrained nonlinear program associated with the fixed-point problem. A standalone implementation of CGD is provided in fax/competitive/cga.py, and the equality-constrained solver derived from it is available as fax.constrained.cga_lagrange_min or fax.constrained.cga_ecp. Implicit differentiation based on Christianson's two-phase reverse accumulation algorithm is available through fax.implicit.two_phase_solver.
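
To give a feel for what two-phase reverse accumulation computes, here is a minimal sketch in plain JAX. It does not use the fax API and is not the library's implementation. The names fixed_point and fixed_point_vjp, the map f, and the matrix A are assumptions chosen for illustration. The forward phase iterates x = f(x, theta) to a fixed point. The reverse phase iterates the adjoint equation w = v + w df/dx to convergence and then returns w df/dtheta, which is the gradient of v . x* with respect to theta.

```python
import jax
import jax.numpy as jnp


def fixed_point(f, theta, x0, tol=1e-6, max_iter=1000):
    """Forward phase: iterate x <- f(x, theta) until the update is small."""
    x = x0
    for _ in range(max_iter):
        x_next = f(x, theta)
        if jnp.max(jnp.abs(x_next - x)) < tol:
            return x_next
        x = x_next
    return x


def fixed_point_vjp(f, theta, x_star, v, tol=1e-6, max_iter=1000):
    """Reverse phase: solve w = v + w @ df/dx by iteration, return w @ df/dtheta."""
    # vjp_x(w) computes w @ df/dx evaluated at the fixed point.
    _, vjp_x = jax.vjp(lambda x: f(x, theta), x_star)
    w = v
    for _ in range(max_iter):
        w_next = v + vjp_x(w)[0]
        converged = jnp.max(jnp.abs(w_next - w)) < tol
        w = w_next
        if converged:
            break
    # Pull the converged adjoint back through the parameters.
    _, vjp_theta = jax.vjp(lambda t: f(x_star, t), theta)
    return vjp_theta(w)[0]


# Hypothetical contraction mapping (an assumption for this sketch):
# A is scaled so that the iteration x <- tanh(A x + theta) converges.
A = 0.3 * jnp.array([[0.5, -0.2], [0.1, 0.4]])


def f(x, theta):
    return jnp.tanh(A @ x + theta)


theta = jnp.array([0.2, -0.1])
x_star = fixed_point(f, theta, jnp.zeros(2))

# Gradient of the scalar loss sum(x_star) with respect to theta.
v = jnp.ones(2)
grad_theta = fixed_point_vjp(f, theta, x_star, v)
```

Both phases rely on f being a contraction near the fixed point. Without that, neither iteration is guaranteed to converge, and a linear solve for the adjoint would be the safer choice.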

See fax/constrained/constrained_test.py for examples. Please note that the API is subject to change.

Installation

pip install jax-fixedpoint

References

Citing competitive differentiation:

@inproceedings{bacon2019optrl,
  author={Pierre-Luc Bacon and Florian Schaefer and Clement Gehring and Animashree Anandkumar and Emma Brunskill},
  title={A Lagrangian Method for Inverse Problems in Reinforcement Learning},
  booktitle={NeurIPS Optimization Foundations for Reinforcement Learning Workshop},
  year={2019},
  url={http://lis.csail.mit.edu/pubs/bacon-optrl-2019.pdf},
  keywords={Optimization, Reinforcement Learning, Lagrangian}
}

Citing this repo:

@misc{gehring2019fax,
  author = {Clement Gehring and Pierre-Luc Bacon and Florian Schaefer},
  title = {{FAX: differentiating fixed point problems in JAX}},
  note = {Available at: https://github.com/gehring/fax},
  year = {2019}
}
