Reimplementation of the UniRep protein featurization model in JAX.
The UniRep model was developed in George Church's lab; see the original publication (bioRxiv preprint and Nature Methods paper), as well as the repository containing the original model.
The idea to reimplement the TF-based model in the much lighter JAX framework was conceived by Eric Ma, who also developed a first version of it inside his functional deep-learning library fundl.
This repo is a self-contained version of the UniRep model (so far only the 1900 hidden-unit mLSTM), adapted and extended from fundl.
Ensure that your compute environment allows you to run JAX code. (A modern Linux or macOS with GLIBC >= 2.23 is probably necessary.)
For now, jax-unirep is available by pip installing from source. Installation from GitHub:

```bash
pip install git+https://github.com/ElArkk/jax-unirep.git
```
To generate representations of protein sequences, pass a single sequence or a list of sequences as strings to `jax_unirep.get_reps`. It will return a tuple consisting of the following representations for each sequence:

- `h_avg`: Average hidden state of the mLSTM over the whole sequence.
- `h_final`: Final hidden state of the mLSTM.
- `c_final`: Final cell state of the mLSTM.

From the original paper, `h_avg` is considered the "representation" (or "rep") of the protein sequence.
Only valid amino acid sequence letters belonging to the set `MRHKDESTNQCUGPAVIFYWLOXZBJ` are allowed as inputs to `get_reps`. They may be passed in as a single string or an iterable of strings, and need not necessarily be of the same length.
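If you want to screen sequences before featurizing them, a simple check against this alphabet could look like the following (a hypothetical helper, not part of the jax-unirep API):

```python
# Hypothetical helper -- not part of the jax-unirep API.
VALID_LETTERS = set("MRHKDESTNQCUGPAVIFYWLOXZBJ")

def validate_sequences(sequences):
    """Raise a ValueError if any sequence contains a disallowed letter."""
    for seq in sequences:
        invalid = set(seq) - VALID_LETTERS
        if invalid:
            raise ValueError(f"Invalid letters {sorted(invalid)} in {seq!r}")
```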
In Python code, for a single sequence:
```python
from jax_unirep import get_reps

sequence = "ASDFGHJKL"

# h_avg is the canonical "reps"
h_avg, h_final, c_final = get_reps(sequence)
```
And for multiple sequences:
```python
from jax_unirep import get_reps

sequences = ["ASDF", "YJKAL", "QQLAMEHALQP"]

# h_avg is the canonical "reps"
h_avg, h_final, c_final = get_reps(sequences)

# Each of the arrays will be of shape (len(sequences), 1900),
# with the correct order of sequences preserved.
```
In the original paper the concept of 'evolutionary finetuning' is introduced, where the pre-trained mLSTM weights get fine-tuned through weight updates, using homologous protein sequences of a given protein of interest as input. This feature is available in jax-unirep as well. Given a set of starter weights for the mLSTM (defaulting to the weights from the paper) as well as a set of sequences, the weights get fine-tuned in such a way that test set loss in the 'next-aa prediction task' is minimized.
There are two functions with differing levels of control available.

The `evotune` function uses `optuna` under the hood to automatically find, given a set of sequences:

- the optimal number of epochs to train for, and
- the optimal learning rate.

The `study` object will contain all the information about the training process of each trial, and `evotuned_params` will contain the fine-tuned mLSTM and dense weights from the trial with the lowest test set loss.
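As a minimal sketch (assuming `evotune` is exported at the package top level, that all tuning options are left at their defaults, and that it returns the `study` and `evotuned_params` objects described above; the keyword name is an assumption, so check the `evotune` docstring):

```python
from jax_unirep import evotune

sequences = ["ASDF", "YJKAL", "QQLAMEHALQP"]

# `study` records every optuna trial; `evotuned_params` holds the
# fine-tuned mLSTM and dense weights from the best trial.
study, evotuned_params = evotune(sequences=sequences)
```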
If you want to directly fine-tune the weights for a fixed number of epochs while using a fixed learning rate, you should use the `fit` function instead. The `fit` function has further customization options, such as different batching strategies; please see the function docstring for more information.
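A rough sketch of such a call (the parameter names and return value here are assumptions; consult the `fit` docstring for the exact signature):

```python
from jax_unirep import fit

sequences = ["ASDF", "YJKAL", "QQLAMEHALQP"]

# Hypothetical call: fine-tune for a fixed number of epochs at a fixed
# learning rate, starting from the paper weights (params=None).
evotuned_params = fit(sequences=sequences, n_epochs=20, params=None)
```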
You can find example usages of both `evotune` and `fit` in the repository's examples.
If you want to pass a set of mLSTM and dense weights that were dumped in an earlier run, create params as follows:

```python
from jax_unirep.utils import load_params

params = load_params(folderpath="path/to/params/folder")
```
If you want to start from randomly initialized mLSTM and dense weights instead:

```python
from jax_unirep.evotuning import init_fun
from jax.random import PRNGKey

_, params = init_fun(PRNGKey(0), input_shape=(-1, 10))
```
The weights used in the 10-dimensional embedding of the input sequences always default to the weights from the paper, since they do not get updated during evotuning.
We implemented the mLSTM layers in such a way that they are compatible with `jax.experimental.stax`. This means that they can easily be plugged into a `stax.serial` model, e.g. to train both the mLSTM and a top-model at once:
```python
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu

from jax_unirep.layers import mLSTM1900, mLSTM1900_AvgHidden

# Note: stax's Relu is a ready-made (init_fun, apply_fun) layer,
# so it is passed in directly rather than called.
init_fun, apply_fun = stax.serial(
    mLSTM1900(),
    mLSTM1900_AvgHidden(),
    Dense(512),
    Relu,
    Dense(1),
)
```
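The returned `init_fun` can then be used to initialize parameters for the whole stack, in the same way as in the evotuning example above (a sketch; the `(-1, 10)` input shape matches the 10-dimensional embedding of the input sequences):

```python
from jax.random import PRNGKey

# Initialize parameters for the combined mLSTM + top-model stack.
_, params = init_fun(PRNGKey(0), input_shape=(-1, 10))
```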
Have a look at the documentation and examples for more information about how to implement a model in JAX.
To read more about how we reimplemented the model in JAX, see our write-up, which is available in both HTML and PDF form.
All the model weights are licensed under the terms of Creative Commons Attribution-NonCommercial 4.0 International License. To view a copy of this license, visit http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
Otherwise the code in this repository is licensed under the terms of GPL v3.