Skip to content

Commit

Permalink
fixes for version 1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
andrefaure committed Jan 20, 2024
1 parent a175cc5 commit ed815fa
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
6 changes: 3 additions & 3 deletions pymochi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""A tool to fit mechanistic models to deep mutational scanning data"""
# Import classes users will interact with

# Add imports here
from .pymochi import *
from pymochi.project import MochiProject

# Handle versioneer
from ._version import get_versions
from pymochi._version import get_versions
versions = get_versions()
__version__ = versions['version']
__git_revision__ = versions['full-revisionid']
Expand Down
8 changes: 4 additions & 4 deletions pymochi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pymochi.transformation import get_transformation
import itertools
import shutil
from functools import reduce
import functools

class ConstrainedLinear(torch.nn.Linear):
"""
Expand All @@ -26,7 +26,7 @@ def forward(
# return F.linear(input, self.weight.clamp(min=0, max=1000), self.bias)
return F.linear(input, self.weight.abs(), self.bias)

class WeightedL1Loss(torch.nn.L1Loss):
class MochiWeightedL1Loss(torch.nn.L1Loss):
"""
A weighted version of L1Loss with no reduction.
"""
Expand Down Expand Up @@ -826,7 +826,7 @@ def get_additive_trait_weights(
#Remove weights not reported on by a single corresponding phenotype
at_list[-1][-1] = at_list[-1][-1].loc[mask!=0,:]
#Merge weight data frames corresponding to different folds
at_list[-1] = reduce(lambda x, y: pd.merge(x, y, how='outer', on = ['id', 'id_ref', 'Pos', 'Pos_ref']), at_list[-1])
at_list[-1] = functools.reduce(lambda x, y: pd.merge(x, y, how='outer', on = ['id', 'id_ref', 'Pos', 'Pos_ref']), at_list[-1])
fold_cols = [i for i in list(at_list[-1].columns) if not i in ['id', 'id_ref', 'Pos', 'Pos_ref']]
at_list[-1]['n'] = at_list[-1].loc[:,fold_cols].notnull().sum(axis=1)
at_list[-1]['mean'] = at_list[-1].loc[:,fold_cols].mean(axis=1)
Expand Down Expand Up @@ -1326,7 +1326,7 @@ def fit(

#Construct loss function and Optimizer
if loss_function_name == 'WeightedL1':
loss_function = WeightedL1Loss()
loss_function = MochiWeightedL1Loss()
elif loss_function_name == 'GaussianNLL':
loss_function = MochiGaussianNLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
Expand Down
23 changes: 22 additions & 1 deletion pymochi/tests/test_pymochi.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,24 @@ def test_MochiData_invalid_features_argument_features(capsys):
print(captured.out)
assert captured.out.split("\n")[-2] == "Error: Invalid feature names." and e_info

def test_MochiData_invalid_features_argument_missingWT(capsys):
"""Test MochiData initialization with invalid features argument missing WT"""
model_design = pd.read_csv(Path(__file__).parent.parent / "data/model_design.txt", sep = "\t", index_col = False)
model_design['file'] = [
str(Path(__file__).parent.parent / "data/fitness_abundance.txt"),
str(Path(__file__).parent.parent / "data/fitness_binding.txt")]
#Create a problematic features dict
features = {
'Folding': ["WT"],
'Binding': ["Hello"]}
with pytest.raises(ValueError) as e_info:
MochiData(
model_design = model_design,
features = features)
captured = capsys.readouterr()
print(captured.out)
assert captured.out.split("\n")[-2] == "Error: 'WT' missing for one or more traits in 'features' argument." and e_info

def test_MochiData_features_argument_Nonekey(capsys):
"""Test MochiData initialization with features argument None key"""
model_design = pd.read_csv(Path(__file__).parent.parent / "data/model_design.txt", sep = "\t", index_col = False)
Expand Down Expand Up @@ -186,7 +204,10 @@ def test_fit_best_exploded_grid_search_models(capsys):
training_resample = True,
early_stopping = True,
scheduler_gamma = mochi_task.scheduler_gamma,
scheduler_epochs = 10)
scheduler_epochs = 10,
loss_function_name = 'WeightedL1',
sos_architecture = [20],
sos_outputlinear = False)
model.training_history['val_loss'] = [1.0, 1.0, np.nan]
#Fit best model
with pytest.raises(ValueError) as e_info:
Expand Down

0 comments on commit ed815fa

Please sign in to comment.