Skip to content

Commit

Permalink
add align_additive_traits, maximize_correlations_k_means, add_additiv…
Browse files Browse the repository at this point in the history
…e_trait_statistics
  • Loading branch information
andrefaure committed Aug 13, 2024
1 parent f40df12 commit f4bcaf5
Showing 1 changed file with 133 additions and 10 deletions.
143 changes: 133 additions & 10 deletions pymochi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import itertools
import shutil
import functools
from sklearn.cluster import KMeans

class ConstrainedLinear(torch.nn.Linear):
"""
Expand Down Expand Up @@ -746,6 +747,124 @@ def get_linear_weights(
for i in range(len(at_list)):
at_list[i].to_csv(os.path.join(directory, "linears_weights_"+self.data.phenotype_names[i]+".txt"), sep = "\t", index = False)

def add_additive_trait_statistics(
    self,
    input_list,
    RT = None):
    """
    Add statistics to additive trait weights (mean, std, ci95 etc.).
    Summary columns are computed across all columns whose name starts with "fold_".
    The data frames in input_list are modified in place and the same list is returned.

    :param input_list: list of data frames (one per additive trait).
    :param RT: R=gas constant (in kcal/K/mol) * T=Temperature (in K) (optional).
    :returns: A list of data frames (one per additive trait).
    """

    #Calculate summary metrics for each additive trait (mean, std, ci95 etc.)
    for trait_index, trait_df in enumerate(input_list):
        #Select fold columns once, before any summary columns are added
        fold_cols = [col for col in list(trait_df.columns) if col.startswith("fold_")]
        fold_values = trait_df.loc[:,fold_cols]
        #Number of folds with a (non-missing) weight
        trait_df['n'] = fold_values.notnull().sum(axis=1)
        trait_df['mean'] = fold_values.mean(axis=1)
        trait_df['std'] = fold_values.std(axis=1)
        #Full width of the interval mean +/- 1.96*std
        trait_df['ci95'] = trait_df['std']*1.96*2
        trait_df['trait_name'] = self.data.additive_trait_names[trait_index]
        #Identity check against None (an equality check misbehaves for array-like RT)
        if RT is not None:
            trait_df['mean_kcal/mol'] = trait_df['mean']*RT
            trait_df['std_kcal/mol'] = trait_df['std']*RT
            trait_df['ci95_kcal/mol'] = trait_df['std_kcal/mol']*1.96*2

    return input_list

def maximize_correlations_k_means(
    self,
    dfs):
    """
    Maximizes within-DataFrame correlations using K-Means clustering, preserving column order.
    Every column of every input DataFrame is treated as one sample and clustered into
    len(dfs) clusters; columns are then redistributed so that each output DataFrame holds
    the columns of one cluster. If the clustering does not yield one complete set of
    columns per DataFrame, a warning is printed and the input is returned unchanged.

    :param dfs: list of pandas DataFrames (all with the same column names as the first).
    :returns: A list of optimized DataFrames (one per additive trait).
    """

    #Nothing to align
    if len(dfs) == 0:
        return dfs

    #DataFrame parameters
    num_dfs = len(dfs)
    reference_columns = dfs[0].columns
    n = len(reference_columns)
    column_names = list(reference_columns)*num_dfs

    #Concatenate DataFrames along rows (one row per original column)
    data = pd.concat([df.T.reset_index(drop=True) for df in dfs], axis=0)

    #Perform K-Means clustering while dropping coefficients that have NAs
    features = np.array(data.dropna(axis=1, how='any'))
    if features.size == 0:
        #No complete coefficients left to cluster on
        print("Warning: Aligning additive traits using K-means clustering failed.")
        return dfs
    kmeans = KMeans(n_clusters=num_dfs, random_state=0).fit(features)
    labels = kmeans.labels_

    #Change labels to preserve df order (most frequent label within each input df maps to its position)
    trans_dict = {}
    for df_index, df_labels in enumerate(labels.reshape((num_dfs,n))):
        trans_dict[pd.Series(df_labels).value_counts().index[0]] = df_index
    try:
        labels = np.array([trans_dict[label] for label in labels])
    except KeyError:
        #Two input dfs share the same most frequent label, so some label has no target df
        print("Warning: Aligning additive traits using K-means clustering failed.")
        return dfs

    #Assign columns to DataFrames, preserving original order
    optimized_dfs = [pd.DataFrame() for _ in range(num_dfs)]
    for row_index, (column_name, label) in enumerate(zip(column_names, labels)):
        optimized_dfs[label][column_name] = data.iloc[row_index]

    #Check each DataFrame is complete
    if any(opt_df.shape[1]!=n for opt_df in optimized_dfs):
        print("Warning: Aligning additive traits using K-means clustering failed.")
        return dfs

    print("Aligning additive traits using K-means clustering succeeded.")

    #Sort columns by name (same column order as the first input DataFrame)
    return [opt_df[reference_columns] for opt_df in optimized_dfs]

def align_additive_traits(
    self,
    input_list):
    """
    Align additive trait fold columns for inferred multidimensional global epistasis ('SumOfSigmoids').
    Alignment is only attempted for groups of additive traits of which none is shared between
    multiple phenotypes. The data frames in input_list are not modified; aligned data is
    written to copies.

    :param input_list: list of data frames (one per additive trait).
    :returns: A list of data frames (one per additive trait), or an empty list if no multidimensional global epistasis was inferred.
    """

    model_design = self.data.model_design
    is_sos = model_design['transformation']=='SumOfSigmoids'

    #Check if global epistasis inferred
    if not is_sos.any():
        return []

    #Multidimensional global epistasis additive traits
    mge_at_list = [i for i in model_design.loc[is_sos,'trait'] if len(i)>1]

    #Check if multidimensional global epistasis inferred
    if not mge_at_list:
        return []

    #Aligned DataFrames object (copy each DataFrame so that the caller's data is left untouched)
    output_list = [df.copy() for df in input_list]

    #Frequency of each additive trait in model design
    at_freq_dict = pd.Series([item for sublist in model_design['trait'] for item in sublist]).value_counts().to_dict()

    #Align DataFrames if associated additive traits are not shared between multiple phenotypes
    for mge_at in mge_at_list:
        #Skip if any additive trait is shared between multiple phenotypes
        if any(at_freq_dict[i]>1 for i in mge_at):
            continue
        #Align DataFrames (additive trait indices in the model design are 1-based)
        opt_list = self.maximize_correlations_k_means(
            dfs = [input_list[i-1][[j for j in input_list[i-1].columns if j.startswith("fold_")]] for i in mge_at])
        #Replace columns with aligned data
        for at_index, opt_df in zip(mge_at, opt_list):
            target_df = output_list[at_index-1]
            target_df[[j for j in target_df.columns if j.startswith("fold_")]] = opt_df

    return output_list

def get_additive_trait_weights(
self,
folds = None,
Expand Down Expand Up @@ -827,16 +946,8 @@ def get_additive_trait_weights(
at_list[-1][-1] = at_list[-1][-1].loc[mask!=0,:]
#Merge weight data frames corresponding to different folds
at_list[-1] = functools.reduce(lambda x, y: pd.merge(x, y, how='outer', on = ['id', 'id_ref', 'Pos', 'Pos_ref']), at_list[-1])
fold_cols = [i for i in list(at_list[-1].columns) if not i in ['id', 'id_ref', 'Pos', 'Pos_ref']]
at_list[-1]['n'] = at_list[-1].loc[:,fold_cols].notnull().sum(axis=1)
at_list[-1]['mean'] = at_list[-1].loc[:,fold_cols].mean(axis=1)
at_list[-1]['std'] = at_list[-1].loc[:,fold_cols].std(axis=1)
at_list[-1]['ci95'] = at_list[-1]['std']*1.96*2
at_list[-1]['trait_name'] = self.data.additive_trait_names[j]
if RT!=None:
at_list[-1]['mean_kcal/mol'] = at_list[-1]['mean']*RT
at_list[-1]['std_kcal/mol'] = at_list[-1]['std']*RT
at_list[-1]['ci95_kcal/mol'] = at_list[-1]['std_kcal/mol']*1.96*2
#Calculate summary metrics for each additive trait (mean, std, ci95 etc.)
at_list = self.add_additive_trait_statistics(at_list, RT)

#Aggregate weights
if aggregate==True:
Expand Down Expand Up @@ -866,6 +977,18 @@ def get_additive_trait_weights(
for i in range(len(at_list)):
if save:
at_list[i].to_csv(os.path.join(directory, "weights_"+self.data.additive_trait_names[i]+".txt"), sep = "\t", index = False)

#Align additive trait fold columns for inferred multidimensional global epistasis ('SumOfSigmoids')
aat_list = self.align_additive_traits(at_list)
#Save if alignment successful
if aat_list != []:
#Calculate summary metrics for each additive trait (mean, std, ci95 etc.)
aat_list = self.add_additive_trait_statistics(aat_list, RT)
#Save model weights
for i in range(len(aat_list)):
if save:
aat_list[i].to_csv(os.path.join(directory, "weights_aligned_"+self.data.additive_trait_names[i]+".txt"), sep = "\t", index = False)

#Return
if aggregate:
return agg_list
Expand Down

0 comments on commit f4bcaf5

Please sign in to comment.