jax-unirep

Reimplementation of the UniRep protein featurization model in JAX.

The UniRep model was developed in George Church's lab, see the original publication here (bioRxiv) or here (Nature Methods), as well as the repository containing the original model.

The idea to reimplement the TF-based model in the much lighter JAX framework was first proposed by Eric Ma, who also developed an initial version of it inside his functional deep-learning library, fundl.

This repo is a self-contained version of the UniRep model (so far only the 1900 hidden-unit mLSTM), adapted and extended from fundl.

Installation

Ensure that your compute environment allows you to run JAX code. (A modern Linux or macOS with a GLIBC>=2.23 is probably necessary.)

For now, jax-unirep is available by pip installing from source.

Installation from GitHub:

pip install git+https://github.com/ElArkk/jax-unirep.git

Usage

Getting UniReps

To generate representations of protein sequences, pass a list of sequences as strings or a single sequence to jax_unirep.get_reps. It will return a tuple consisting of the following representations for each sequence:

  • h_avg: Average hidden state of the mLSTM over the whole sequence.
  • h_final: Final hidden state of the mLSTM.
  • c_final: Final cell state of the mLSTM.

From the original paper, h_avg is considered the "representation" (or "rep") of the protein sequence.

Only valid amino acid sequence letters belonging to the set:

MRHKDESTNQCUGPAVIFYWLOXZBJ

are allowed as inputs to get_reps. They may be passed in as a single string or an iterable of strings, and need not be of the same length.
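As an illustrative pre-check (not part of the jax-unirep API; the helper name below is hypothetical), sequences can be validated against this alphabet before calling get_reps:

```python
# Hypothetical pre-check, not part of jax-unirep's API:
# reject sequences containing letters outside the allowed alphabet.
ALLOWED = set("MRHKDESTNQCUGPAVIFYWLOXZBJ")

def validate_sequences(sequences):
    """Raise ValueError if any sequence contains a disallowed letter."""
    if isinstance(sequences, str):
        sequences = [sequences]
    for seq in sequences:
        bad = set(seq) - ALLOWED
        if bad:
            raise ValueError(f"Invalid letters {sorted(bad)} in sequence {seq!r}")
    return sequences

validate_sequences(["ASDF", "YJKAL"])  # passes without error
```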

In Python code, for a single sequence:

from jax_unirep import get_reps

sequence = "ASDFGHJKL"

# h_avg is the canonical "reps"
h_avg, h_final, c_final = get_reps(sequence)

And for multiple sequences:

from jax_unirep import get_reps

sequences = ["ASDF", "YJKAL", "QQLAMEHALQP"]

# h_avg is the canonical "reps"
h_avg, h_final, c_final = get_reps(sequences)

# each of the arrays will be of shape (len(sequences), 1900),
# with the correct order of sequences preserved

Evotuning

The original paper introduces the concept of 'evolutionary fine-tuning', in which the pre-trained mLSTM weights are fine-tuned through weight updates, using homologous sequences of a protein of interest as input. This feature is also available in jax-unirep. Given a set of starter weights for the mLSTM (defaulting to the weights from the paper) and a set of sequences, the weights are fine-tuned so that the test-set loss on the next-amino-acid prediction task is minimized. Two functions with differing levels of control are available.

The evotune function uses optuna under the hood to automatically find:

  1. the optimal number of epochs to train for, and
  2. the optimal learning rate,

given a set of sequences. The returned study object contains all the information about the training process of each trial, while evotuned_params contains the fine-tuned mLSTM and dense weights from the trial with the lowest test-set loss.
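A minimal usage sketch, assuming evotune is importable from the package top level and that it accepts an n_trials keyword controlling how many hyperparameter combinations optuna samples (both are assumptions; check the function docstring for the exact signature):

```python
from jax_unirep import evotune  # assumed top-level import

# Toy homologous sequences of a protein of interest.
sequences = ["HASTA", "VISTA", "ALAVA", "LIMED"]

# `n_trials` is an assumed keyword: the number of optuna trials
# used to search over epochs and learning rate.
study, evotuned_params = evotune(sequences=sequences, n_trials=20)
```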

If you want to directly fine-tune the weights for a fixed number of epochs while using a fixed learning rate, you should use the fit function instead. The fit function has further customization options, such as different batching strategies. Please see the function docstring for more information.
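A corresponding sketch for fit, with a fixed number of epochs and a fixed learning rate (the n_epochs and step_size keyword names are assumptions; the function docstring is authoritative):

```python
from jax_unirep import fit  # assumed top-level import

sequences = ["HASTA", "VISTA", "ALAVA", "LIMED"]

# Fine-tune for a fixed number of epochs at a fixed learning rate;
# both keyword names are assumptions, not verified against the docstring.
evotuned_params = fit(sequences=sequences, n_epochs=20, step_size=1e-4)
```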

You can find example usages of both evotune and fit here.

If you want to pass a set of mLSTM and dense weights that were dumped in an earlier run, create params as follows:

from jax_unirep.utils import load_params

params = load_params(folderpath="path/to/params/folder")

If you want to start from randomly initialized mLSTM and dense weights instead:

from jax_unirep.evotuning import init_fun
from jax.random import PRNGKey

_, params = init_fun(PRNGKey(0), input_shape=(-1, 10))

The weights used in the 10-dimensional embedding of the input sequences always default to the weights from the paper, since they do not get updated during evotuning.

UniRep stax

We implemented the mLSTM layers in such a way that they are compatible with jax.experimental.stax. This means that they can easily be plugged into a stax.serial model, e.g. to train both the mLSTM and a top-model at once:

from jax.experimental import stax
from jax.experimental.stax import Dense, Relu

from jax_unirep.layers import mLSTM1900, mLSTM1900_AvgHidden

init_fun, apply_fun = stax.serial(
    mLSTM1900(),
    mLSTM1900_AvgHidden(),
    Dense(512), Relu(),
    Dense(1)
)
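A usage sketch of the composed model. The (-1, 10) input shape follows the 10-dimensional embedding convention used with init_fun above; the toy input shape of (25, 10) (one embedded sequence of length 25) is an assumption about what the mLSTM layer expects:

```python
import jax.numpy as jnp
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu
from jax.random import PRNGKey

from jax_unirep.layers import mLSTM1900, mLSTM1900_AvgHidden

init_fun, apply_fun = stax.serial(
    mLSTM1900(), mLSTM1900_AvgHidden(), Dense(512), Relu(), Dense(1)
)

# Initialize all parameters at once; (-1, 10) matches the
# 10-dimensional input embedding used elsewhere in the repo.
_, params = init_fun(PRNGKey(42), input_shape=(-1, 10))

# Forward pass on a toy embedded sequence (shape is an assumption).
preds = apply_fun(params, jnp.ones((25, 10)))
```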

Have a look at the JAX documentation and examples for more information about how to implement models in JAX.

More Details

To read more about how we reimplemented the model in JAX, see our write-up, which is available in both HTML and PDF.

License

All model weights are licensed under the terms of the Creative Commons Attribution-NonCommercial 4.0 International License. To view a copy of this license, visit http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

Otherwise the code in this repository is licensed under the terms of GPL v3.
