Add continuous value (plDDT and PPL) support for Curriculum Learning #58
Conversation
Hi @Leo-T-Zang - nice work!
Had a couple questions/comments before approving:
- What is the role of the new fields in the yaml files (e.g., ppl_category, plddt_category)? They do not seem to be used elsewhere in the code
- We probably want to generate a sequence of amino acids here instead :)
- I would add two small scripts in dataset.py to compute the perplexity metric based on ESM2 and the pLDDT based on ESMFold, with a separate test to ensure they return correct values. In practice we will then run these scripts on the training datasets to pre-compute the values needed by the CL scheme you implemented
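To make the perplexity half of that request concrete, here is a minimal sketch. It deliberately stops short of model inference: `sequence_perplexity` (a hypothetical helper name, not existing repository code) assumes you have already extracted per-residue log-probabilities from ESM2 logits via a log-softmax, and only folds them into a pseudo-perplexity.

```python
import math

def sequence_perplexity(token_log_probs):
    """Pseudo-perplexity from per-residue log-probabilities.

    token_log_probs: list of log p(x_i | context) for each residue,
    assumed to come from a model such as ESM2. Returns
    exp(-mean log-likelihood); lower means the model finds the
    sequence easier, which is what the curriculum would sort on.
    """
    if not token_log_probs:
        raise ValueError("empty sequence")
    avg_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(avg_nll)
```

For instance, a uniform probability of 0.25 per residue gives a perplexity of exactly 4, which makes the helper easy to unit-test as suggested above.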
Another comment as I'm re-reading the test routine: we should check that the variable of interest (e.g., ppl) is ordered throughout training across mini-batches (not within a mini-batch).
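A sketch of that check, assuming each mini-batch is a list of example dicts carrying the precomputed value (`check_curriculum_order` and the `key` argument are hypothetical names): it compares consecutive batch means, so ordering within a batch is free to vary.

```python
def check_curriculum_order(batches, key="ppl", ascending=True):
    """Verify the CL variable is ordered across mini-batches.

    batches: iterable of mini-batches, each a non-empty list of dicts.
    Compares the mean of `key` between consecutive batches rather than
    requiring any within-batch ordering.
    """
    means = [sum(ex[key] for ex in b) / len(b) for b in batches]
    pairs = zip(means, means[1:])
    if ascending:
        return all(a <= b for a, b in pairs)
    return all(a >= b for a, b in pairs)
```

A test could build a toy easy-to-hard batch sequence and assert the check passes, then assert it fails on the reversed sequence.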
@pascalnotin I have a question. My PR was computing length within the batch_set_curriculum_learning_column function, but this PR assumes the CL column is already precomputed and stored before batch_set_curriculum_learning_column is called. For example, my sequence-length strategy computes the negative of the length inside the batch_set function here:
whereas Leo's PR relies on perplexity/plDDT being precomputed (see below). The code line below indexes the ppl/plDDT column under the assumption it has already been computed.
To bring my PR in line with Leo's, we may need to make this assumption (the CL column is precomputed) apply to all strategies. How do you want to go about it?
Personally, I think it makes sense for the time being to go with S2, because it makes the computation apparent, whereas assuming the values already exist could create reproducibility issues (although the latter is computationally more feasible). I'm all ears for what both @pascalnotin and @Leo-T-Zang have to say. With S2 as it stands we would recompute the metric every time; it would make sense to compute it only once and store the precomputed dataset somewhere like HF/Dropbox/Google Drive, so we wouldn't need to recompute these metrics each run. Also, my apologies for generating random DNA sequences rather than amino acids!
I can make this change in a subsequent PR
That's a good point. I assumed things would be different for sequence length vs. other CL strategies, because the length of the input can be computed straightforwardly, whereas other strategies are more involved (i.e., they require calls to separate models such as Tranception to compute the perplexity or ESMFold to compute the plDDT). But these more complex strategies can also be handled via the batch_set function (as per your point on S2), so it's probably easier to go with that.
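A hedged sketch of what that unified dispatch might look like. The `cl_column` output name and the dict-of-columns batch layout are assumptions for illustration, not the actual repository code: cheap strategies compute their value on the fly, while expensive ones read a precomputed column and fail loudly if it is missing.

```python
def batch_set_curriculum_learning_column(batch, strategy):
    """Sketch: populate the curriculum-learning sort key for one batch.

    `batch` is assumed to be a dict mapping column names to
    equal-length lists. Cheap strategies (sequence length) are
    computed in place; expensive ones (ppl, plddt) must already
    exist as precomputed columns.
    """
    if strategy == "sequence_length":
        # Negative of the length, as in the sequence-length strategy
        # discussed above.
        batch["cl_column"] = [-len(seq) for seq in batch["sequence"]]
    elif strategy in ("ppl", "plddt"):
        if strategy not in batch:
            raise KeyError(f"column '{strategy}' must be precomputed")
        batch["cl_column"] = list(batch[strategy])
    else:
        raise ValueError(f"unknown curriculum strategy: {strategy}")
    return batch
```

Routing both cases through the same function keeps the batch_set entry point uniform while still allowing the expensive metrics to be precomputed offline.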
Merging this PR based on our pre-Christmas discussion: we will pre-compute the values used for Curriculum Learning and store them together with the overarching cluster mapping file. This mapping file will thus contain the cluster name, the cluster representative sequence, a pointer to the file location on disk where all sequences in that cluster are stored, and the plDDT and PPL of the cluster representative sequence computed with a pretrained model (e.g., ESM2 or Tranception).
@Leo-T-Zang @talkhanz @jamaliki for reference ^^
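For illustration, a sketch of that mapping-file layout using only the standard library. Every field name and value here is hypothetical, chosen only to mirror the columns listed in the comment above; the real file format and paths may differ.

```python
import csv
import io

# Hypothetical field names mirroring the mapping-file contents above.
FIELDS = [
    "cluster_name",
    "representative_sequence",
    "sequence_file_path",
    "plddt",
    "ppl",
]

def write_mapping(rows, fh):
    """Write cluster-mapping rows as CSV (sketch, not the real format)."""
    writer = csv.DictWriter(fh, fieldnames=FIELDS)
    writer.writeheader()
    writer.writerows(rows)

# Example row with made-up values.
buf = io.StringIO()
write_mapping(
    [{
        "cluster_name": "cluster_0001",
        "representative_sequence": "MKTAYIAKQR",
        "sequence_file_path": "data/clusters/cluster_0001.fasta",
        "plddt": 87.3,
        "ppl": 6.42,
    }],
    buf,
)
```

With this shape, the curriculum loader only ever reads the plDDT/PPL columns back, never recomputing them at training time.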
Both plDDT and PPL values should be pre-computed and saved within the dataset.