diff --git a/src/scripts/borzoi_bench_gtex_folds.py b/src/scripts/borzoi_bench_gtex_folds.py new file mode 100644 index 0000000..5b27a79 --- /dev/null +++ b/src/scripts/borzoi_bench_gtex_folds.py @@ -0,0 +1,675 @@ +#!/usr/bin/env python +# Copyright 2023 Calico LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from optparse import OptionParser, OptionGroup +import glob +import json +import pickle +import pdb +import os +import shutil +import sys + +import h5py +import numpy as np +import pandas as pd + +import slurm + +""" +borzoi_bench_gtex_folds.py + +Benchmark Basenji model replicates on GTEx eQTL coefficient task. +""" + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + + # sed options + sed_options = OptionGroup(parser, 'borzoi_sed.py options') + sed_options.add_option( + '-b', + dest='bedgraph', + default=False, + action='store_true', + help='Write ref/alt predictions as bedgraph [Default: %default]', + ) + sed_options.add_option( + '-f', + dest='genome_fasta', + default='%s/data/hg38.fa' % os.environ['BASENJIDIR'], + help='Genome FASTA for sequences [Default: %default]', + ) + sed_options.add_option( + '-g', + dest='genes_gtf', + default='%s/genes/gencode41/gencode41_basic_nort.gtf' % os.environ['HG38'], + help='GTF for gene definition [Default %default]', + ) + sed_options.add_option( + '-o', + dest='out_dir', + default='sed', + help='Output directory for tables and plots [Default: %default]', + ) + sed_options.add_option( + '--rc', + dest='rc', + default=False, + action='store_true', + help='Average forward and reverse complement predictions [Default: %default]', + ) + sed_options.add_option( + '--shifts', + dest='shifts', + default='0', + type='str', + help='Ensemble prediction shifts [Default: %default]', + ) + sed_options.add_option( + '--span', + dest='span', + default=False, + action='store_true', + help='Aggregate entire gene span [Default: %default]', + ) + sed_options.add_option( + '--stats', + dest='sed_stats', + default='SED', + help='Comma-separated list of stats to save. 
[Default: %default]',
+    )
+    sed_options.add_option(
+        '-t',
+        dest='targets_file',
+        default=None,
+        type='str',
+        help='File specifying target indexes and labels in table format',
+    )
+    sed_options.add_option(
+        '-u',
+        dest='untransform_old',
+        default=False,
+        action='store_true',
+        help='Undo scale, clip_soft and sqrt transforms (old) [Default: %default]',
+    )
+    sed_options.add_option(
+        '--no_untransform',
+        dest='no_untransform',
+        default=False,
+        action='store_true',
+        help='Do not untransform predictions [Default: %default]',
+    )
+    parser.add_option_group(sed_options)
+
+    # cross-fold
+    fold_options = OptionGroup(parser, 'cross-fold options')
+    fold_options.add_option(
+        '-c',
+        dest='crosses',
+        default=1,
+        type='int',
+        help='Number of cross-fold rounds [Default: %default]',
+    )
+    fold_options.add_option(
+        '-d',
+        dest='data_head',
+        default=None,
+        type='int',
+        help='Index for dataset/head [Default: %default]',
+    )
+    fold_options.add_option(
+        '-e',
+        dest='conda_env',
+        default='tf210',
+        help='Anaconda environment [Default: %default]',
+    )
+    fold_options.add_option(
+        '--gtex',
+        dest='gtex_vcf_dir',
+        default='/home/drk/seqnn/data/gtex_fine/susie_pip90',
+        help='Directory of merged GTEx eQTL VCFs [Default: %default]',
+    )
+    fold_options.add_option(
+        '--name',
+        dest='name',
+        default='gtex',
+        help='SLURM name prefix [Default: %default]',
+    )
+    fold_options.add_option(
+        '--max_proc',
+        dest='max_proc',
+        default=None,
+        type='int',
+        help='Maximum concurrent processes [Default: %default]',
+    )
+    fold_options.add_option(
+        '-p',
+        dest='processes',
+        default=None,
+        type='int',
+        help='Number of processes to parallelize each VCF across [Default: %default]',
+    )
+    fold_options.add_option(
+        '-q',
+        dest='queue',
+        default='geforce',
+        help='SLURM queue on which to run the jobs [Default: %default]',
+    )
+    parser.add_option_group(fold_options)
+
+    (options, args) = parser.parse_args()
+
+    if len(args) != 2:
+        parser.error('Must provide parameters file and cross-fold directory')
+    else:
+        params_file = args[0]
+        exp_dir = args[1]
+
+    #######################################################
+    # prep work
+
+    # count folds
+    num_folds = 0
+    fold0_dir = '%s/f%dc0' % (exp_dir, num_folds)
+    model_file = '%s/train/model_best.h5' % fold0_dir
+    if options.data_head is not None:
+        model_file = '%s/train/model%d_best.h5' % (fold0_dir, options.data_head)
+    while os.path.isfile(model_file):
+        num_folds += 1
+        fold0_dir = '%s/f%dc0' % (exp_dir, num_folds)
+        model_file = '%s/train/model_best.h5' % fold0_dir
+        if options.data_head is not None:
+            model_file = '%s/train/model%d_best.h5' % (fold0_dir, options.data_head)
+    print('Found %d folds' % num_folds)
+    if num_folds == 0:
+        print('No model folds found in %s' % exp_dir, file=sys.stderr)
+        exit(1)
+
+    # extract output subdirectory name
+    gtex_out_dir = options.out_dir
+
+    # split SNP stats
+    sed_stats = options.sed_stats.split(',')
+
+    # merge study/tissue variants
+    mpos_vcf_file = '%s/pos_merge.vcf' % options.gtex_vcf_dir
+    mneg_vcf_file = '%s/neg_merge.vcf' % options.gtex_vcf_dir
+
+    ################################################################
+    # SED
+
+    # SED command base
+    cmd_base = '. 
/home/drk/anaconda3/etc/profile.d/conda.sh;' + cmd_base += ' conda activate %s;' % options.conda_env + cmd_base += ' echo $HOSTNAME;' + + jobs = [] + + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (exp_dir, fi, ci) + name = '%s-f%dc%d' % (options.name, fi, ci) + + # update output directory + it_out_dir = '%s/%s' % (it_dir, gtex_out_dir) + os.makedirs(it_out_dir, exist_ok=True) + + # choose model + model_file = '%s/train/model_best.h5' % it_dir + if options.data_head is not None: + model_file = '%s/train/model%d_best.h5' % (it_dir, options.data_head) + + ######################################## + # negative jobs + + # pickle options + options.out_dir = '%s/merge_neg' % it_out_dir + os.makedirs(options.out_dir, exist_ok=True) + options_pkl_file = '%s/options.pkl' % options.out_dir + options_pkl = open(options_pkl_file, 'wb') + pickle.dump(options, options_pkl) + options_pkl.close() + + # create base fold command + cmd_fold = '%s time borzoi_sed.py %s %s %s' % ( + cmd_base, options_pkl_file, params_file, model_file) + + for pi in range(options.processes): + sed_file = '%s/job%d/sed.h5' % (options.out_dir, pi) + if not nonzero_h5(sed_file, sed_stats): + cmd_job = '%s %s %d' % (cmd_fold, mneg_vcf_file, pi) + j = slurm.Job(cmd_job, '%s_neg%d' % (name,pi), + '%s/job%d.out' % (options.out_dir,pi), + '%s/job%d.err' % (options.out_dir,pi), + '%s/job%d.sb' % (options.out_dir,pi), + queue=options.queue, gpu=1, cpu=2, + mem=48000, time='7-0:0:0') + jobs.append(j) + + ######################################## + # positive jobs + + # pickle options + options.out_dir = '%s/merge_pos' % it_out_dir + os.makedirs(options.out_dir, exist_ok=True) + options_pkl_file = '%s/options.pkl' % options.out_dir + options_pkl = open(options_pkl_file, 'wb') + pickle.dump(options, options_pkl) + options_pkl.close() + + # create base fold command + cmd_fold = '%s time borzoi_sed.py %s %s %s' % ( + cmd_base, options_pkl_file, params_file, model_file) + + for pi in range(options.processes): + sed_file = '%s/job%d/sed.h5' % (options.out_dir, pi) + if not nonzero_h5(sed_file, sed_stats): + cmd_job = '%s %s %d' % (cmd_fold, mpos_vcf_file, pi) + j = slurm.Job(cmd_job, '%s_pos%d' % (name,pi), + '%s/job%d.out' % (options.out_dir,pi), + '%s/job%d.err' % (options.out_dir,pi), + '%s/job%d.sb' % (options.out_dir,pi), + queue=options.queue, gpu=1, cpu=2, + mem=48000, time='7-0:0:0') + jobs.append(j) + + slurm.multi_run(jobs, max_proc=options.max_proc, verbose=True, + launch_sleep=10, update_sleep=60) + + ####################################################### + # collect output + + for ci in range(options.crosses): + for fi in range(num_folds): + it_out_dir = '%s/f%dc%d/%s' % (exp_dir, fi, ci, gtex_out_dir) + + # collect negatives + neg_out_dir = '%s/merge_neg' % it_out_dir + if not os.path.isfile('%s/sed.h5' % neg_out_dir): + collect_scores(neg_out_dir, options.processes, 'sed.h5') + + # collect positives + pos_out_dir = '%s/merge_pos' % it_out_dir + if not os.path.isfile('%s/sed.h5' % pos_out_dir): + collect_scores(pos_out_dir, options.processes, 'sed.h5') + + + ################################################################ + # split study/tissue variants + + for ci in range(options.crosses): + for fi in range(num_folds): + it_out_dir = '%s/f%dc%d/%s' % (exp_dir, fi, ci, gtex_out_dir) + print(it_out_dir) + + # split positives + split_scores(it_out_dir, 'pos', options.gtex_vcf_dir, sed_stats) + + # split negatives + split_scores(it_out_dir, 'neg', options.gtex_vcf_dir, sed_stats) + + 
################################################################ + # ensemble + + ensemble_dir = '%s/ensemble' % exp_dir + if not os.path.isdir(ensemble_dir): + os.mkdir(ensemble_dir) + + gtex_dir = '%s/%s' % (ensemble_dir, gtex_out_dir) + if not os.path.isdir(gtex_dir): + os.mkdir(gtex_dir) + + for gtex_pos_vcf in glob.glob('%s/*_pos.vcf' % options.gtex_vcf_dir): + gtex_neg_vcf = gtex_pos_vcf.replace('_pos.','_neg.') + pos_base = os.path.splitext(os.path.split(gtex_pos_vcf)[1])[0] + neg_base = os.path.splitext(os.path.split(gtex_neg_vcf)[1])[0] + + # collect SED files + sed_pos_files = [] + sed_neg_files = [] + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (exp_dir, fi, ci) + it_out_dir = '%s/%s' % (it_dir, gtex_out_dir) + + sed_pos_file = '%s/%s/sed.h5' % (it_out_dir, pos_base) + sed_pos_files.append(sed_pos_file) + + sed_neg_file = '%s/%s/sed.h5' % (it_out_dir, neg_base) + sed_neg_files.append(sed_neg_file) + + # ensemble + ens_pos_dir = '%s/%s' % (gtex_dir, pos_base) + os.makedirs(ens_pos_dir, exist_ok=True) + ens_pos_file = '%s/sed.h5' % (ens_pos_dir) + if not os.path.isfile(ens_pos_file): + ensemble_h5(ens_pos_file, sed_pos_files, sed_stats) + + ens_neg_dir = '%s/%s' % (gtex_dir, neg_base) + os.makedirs(ens_neg_dir, exist_ok=True) + ens_neg_file = '%s/sed.h5' % (ens_neg_dir) + if not os.path.isfile(ens_neg_file): + ensemble_h5(ens_neg_file, sed_neg_files, sed_stats) + + + ################################################################ + # coefficient analysis + + cmd_base = 'borzoi_gtex_coef.py -g %s' % options.gtex_vcf_dir + + jobs = [] + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (exp_dir, fi, ci) + it_out_dir = '%s/%s' % (it_dir, gtex_out_dir) + + for sed_stat in sed_stats: + coef_out_dir = f'{it_out_dir}/coef-{sed_stat}' + + if not os.path.isfile('%s/metrics.tsv' % coef_out_dir): + cmd_coef = f'{cmd_base} -o {coef_out_dir} -s {sed_stat} {it_out_dir}' + j = slurm.Job(cmd_coef, 'coef', + f'{coef_out_dir}.out', f'{coef_out_dir}.err', + queue='standard', cpu=2, + mem=30000, time='12:0:0') + jobs.append(j) + + # ensemble + it_out_dir = f'{exp_dir}/ensemble/{gtex_out_dir}' + for sed_stat in sed_stats: + coef_out_dir = f'{it_out_dir}/coef-{sed_stat}' + + if not os.path.isfile('%s/metrics.tsv' % coef_out_dir): + cmd_coef = f'{cmd_base} -o {coef_out_dir} -s {sed_stat} {it_out_dir}' + j = slurm.Job(cmd_coef, 'coef', + f'{coef_out_dir}.out', f'{coef_out_dir}.err', + queue='standard', cpu=2, + mem=30000, time='12:0:0') + jobs.append(j) + + slurm.multi_run(jobs, verbose=True) + + +def collect_scores(out_dir: str, num_jobs: int, h5f_name: str='sad.h5'): + """Collect parallel SAD jobs' output into one HDF5. + + Args: + out_dir (str): Output directory. + num_jobs (int): Number of jobs to combine results from. 
+ """ + # count variants + num_variants = 0 + num_rows = 0 + for pi in range(num_jobs): + # open job + job_h5_file = '%s/job%d/%s' % (out_dir, pi, h5f_name) + job_h5_open = h5py.File(job_h5_file, 'r') + num_variants += len(job_h5_open['snp']) + num_rows += len(job_h5_open['si']) + job_h5_open.close() + + # initialize final h5 + final_h5_file = '%s/%s' % (out_dir, h5f_name) + final_h5_open = h5py.File(final_h5_file, 'w') + + # SNP stats + snp_stats = {} + + job0_h5_file = '%s/job0/%s' % (out_dir, h5f_name) + job0_h5_open = h5py.File(job0_h5_file, 'r') + for key in job0_h5_open.keys(): + if key in ['target_ids', 'target_labels']: + # copy + final_h5_open.create_dataset(key, data=job0_h5_open[key]) + + elif key in ['snp', 'chr', 'pos', 'ref_allele', 'alt_allele', 'gene']: + snp_stats[key] = [] + + elif job0_h5_open[key].ndim == 1: + final_h5_open.create_dataset(key, shape=(num_rows,), dtype=job0_h5_open[key].dtype) + + else: + num_targets = job0_h5_open[key].shape[1] + final_h5_open.create_dataset(key, shape=(num_rows, num_targets), dtype=job0_h5_open[key].dtype) + + job0_h5_open.close() + + # set values + vgi = 0 + vi = 0 + for pi in range(num_jobs): + # open job + job_h5_file = '%s/job%d/%s' % (out_dir, pi, h5f_name) + with h5py.File(job_h5_file, 'r') as job_h5_open: + job_snps = len(job_h5_open['snp']) + job_rows = job_h5_open['si'].shape[0] + + # append to final + for key in job_h5_open.keys(): + try: + if key in ['target_ids', 'target_labels']: + # once is enough + pass + + elif key in ['snp', 'chr', 'pos', 'ref_allele', 'alt_allele', 'gene']: + snp_stats[key] += list(job_h5_open[key]) + + elif key == 'si': + # re-index SNPs + final_h5_open[key][vgi:vgi+job_rows] = job_h5_open[key][:] + vi + + else: + final_h5_open[key][vgi:vgi+job_rows] = job_h5_open[key] + + except TypeError as e: + print(e) + print(f'{job_h5_file} {key} has the wrong shape. Remove this file and rerun') + exit() + + vgi += job_rows + vi += job_snps + + # create final SNP stat datasets + for key in snp_stats: + if key == 'pos': + final_h5_open.create_dataset(key, + data=np.array(snp_stats[key])) + else: + final_h5_open.create_dataset(key, + data=np.array(snp_stats[key], dtype='S')) + + final_h5_open.close() + + +def ensemble_h5(ensemble_h5_file: str, scores_files: list, sed_stats: list): + """Ensemble scores from multiple files into a single file. + + Args: + ensemble_h5_file (str): ensemble score HDF5. + scores_files ([str]): list of replicate score HDFs. + sed_stats ([str]): SED stats to average over folds. 
+ """ + # open ensemble + ensemble_h5 = h5py.File(ensemble_h5_file, 'w') + + # transfer non-SED keys + sed_shapes = {} + scores0_h5 = h5py.File(scores_files[0], 'r') + for key in scores0_h5.keys(): + if key not in sed_stats: + ensemble_h5.create_dataset(key, data=scores0_h5[key]) + else: + sed_shapes[key] = scores0_h5[key].shape + scores0_h5.close() + + # average stats + num_folds = len(scores_files) + for sed_stat in sed_stats: + # initialize ensemble array + sed_values = np.zeros(shape=sed_shapes[sed_stat], dtype='float32') + + # read and add folds + for scores_file in scores_files: + with h5py.File(scores_file, 'r') as scores_h5: + sed_values += scores_h5[sed_stat][:].astype('float32') + + # normalize and downcast + sed_values /= num_folds + sed_values = sed_values.astype('float16') + + # save + ensemble_h5.create_dataset(sed_stat, data=sed_values) + + ensemble_h5.close() + + +def split_scores(it_out_dir: str, posneg: str, vcf_dir: str, sed_stats): + """Split merged VCF predictions in HDF5 into tissue-specific + predictions in HDF5. + + Args: + it_out_dir (str): output directory for iteration. + posneg (str): 'pos' or 'neg'. + vcf_dir (str): directory containing tissue-specific VCFs. + sed_stats ([str]]): list of SED stats. + """ + merge_h5_file = '%s/merge_%s/sed.h5' % (it_out_dir, posneg) + merge_h5 = h5py.File(merge_h5_file, 'r') + + # read merged data + merge_si = merge_h5['si'][:] + merge_snps = [snp.decode('UTF-8') for snp in merge_h5['snp']] + merge_gene = [gene.decode('UTF-8') for gene in merge_h5['gene']] + merge_scores = {} + for ss in sed_stats: + merge_scores[ss] = merge_h5[ss][:] + + # hash snps to row indexes + snp_ri = {} + for ri, si in enumerate(merge_si): + snp_ri.setdefault(merge_snps[si],[]).append(ri) + + # for each tissue VCF + vcf_glob = '%s/*_%s.vcf' % (vcf_dir, posneg) + for tissue_vcf_file in glob.glob(vcf_glob): + tissue_label = tissue_vcf_file.split('/')[-1] + tissue_label = tissue_label.replace('_pos.vcf','') + tissue_label = tissue_label.replace('_neg.vcf','') + + # initialize HDF5 arrays + sed_snp = [] + sed_chr = [] + sed_pos = [] + sed_ref = [] + sed_alt = [] + sed_gene = [] + sed_snpi = [] + sed_scores = {} + for ss in sed_stats: + sed_scores[ss] = [] + + # fill HDF5 arrays with ordered SNPs + si = 0 + for line in open(tissue_vcf_file): + if not line.startswith('#'): + a = line.split() + chrm, pos, snp, ref, alt = a[:5] + + # SNPs w/o genes disappear + if snp in snp_ri: + sed_snp.append(snp) + sed_chr.append(chrm) + sed_pos.append(int(pos)) + sed_ref.append(ref) + sed_alt.append(alt) + + for ri in snp_ri[snp]: + sed_snpi.append(si) + sed_gene.append(merge_gene[ri]) + for ss in sed_stats: + sed_scores[ss].append(merge_scores[ss][ri]) + + si += 1 + + # write tissue HDF5 + tissue_dir = '%s/%s_%s' % (it_out_dir, tissue_label, posneg) + os.makedirs(tissue_dir, exist_ok=True) + with h5py.File('%s/sed.h5' % tissue_dir, 'w') as tissue_h5: + # write SNPs + tissue_h5.create_dataset('snp', + data=np.array(sed_snp, 'S')) + + # write chr + tissue_h5.create_dataset('chr', + data=np.array(sed_chr, 'S')) + + # write SNP pos + tissue_h5.create_dataset('pos', + data=np.array(sed_pos, dtype='uint32')) + + # write ref allele + tissue_h5.create_dataset('ref_allele', + data=np.array(sed_ref, dtype='S')) + + # write alt allele + tissue_h5.create_dataset('alt_allele', + data=np.array(sed_alt, dtype='S')) + + # write SNP i + tissue_h5.create_dataset('si', + data=np.array(sed_snpi)) + + # write gene + tissue_h5.create_dataset('gene', + data=np.array(sed_gene, 'S')) + + # write 
targets + tissue_h5.create_dataset('target_ids', data=merge_h5['target_ids']) + tissue_h5.create_dataset('target_labels', data=merge_h5['target_labels']) + + # write sed stats + for ss in sed_stats: + tissue_h5.create_dataset(ss, + data=np.array(sed_scores[ss], dtype='float16')) + + merge_h5.close() + +def nonzero_h5(h5_file: str, stat_keys): + """Verify the HDF5 exists, and there are nonzero values + for each stat key given. + + Args: + h5_file (str): HDF5 file name. + stat_keys ([str]): List of SNP stat keys. + """ + if os.path.isfile(h5_file): + try: + with h5py.File(h5_file, 'r') as h5_open: + for sk in stat_keys: + sad = h5_open[sk][:] + if (sad != 0).sum() > 0: + return True + return False + except: + return False + else: + return False + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + main() diff --git a/src/scripts/borzoi_bench_sqtl_folds.py b/src/scripts/borzoi_bench_sqtl_folds.py index bb7751e..c71e724 100755 --- a/src/scripts/borzoi_bench_sqtl_folds.py +++ b/src/scripts/borzoi_bench_sqtl_folds.py @@ -98,6 +98,18 @@ def main(): type="str", help="File specifying target indexes and labels in table format", ) + sed_options.add_option( + '-u', + dest='untransform_old', + default=False, + action='store_true' + ) + sed_options.add_option( + '--no_untransform', + dest='no_untransform', + default=False, + action='store_true' + ) parser.add_option_group(sed_options) # classify diff --git a/src/scripts/borzoi_gtex_coef.py b/src/scripts/borzoi_gtex_coef.py new file mode 100644 index 0000000..680d60b --- /dev/null +++ b/src/scripts/borzoi_gtex_coef.py @@ -0,0 +1,405 @@ +#!/usr/bin/env python +from optparse import OptionParser +import os +import pdb +import re +import sys + +import h5py +import numpy as np +import pandas as pd +from scipy.stats import spearmanr, pearsonr +from sklearn.metrics import roc_auc_score + +import matplotlib.pyplot as plt +import seaborn as sns + +''' +borzoi_gtex_coef.py + +Evaluate concordance of variant effect prediction sign classifcation +and coefficient correlations. +''' + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + + parser.add_option( + '-o', + dest='out_dir', + default='coef_out', + help='Output directory for tissue metrics', + ) + parser.add_option( + '-g', + dest='gtex_vcf_dir', + default='/home/drk/seqnn/data/gtex_fine/susie_pip90', + help='GTEx VCF directory', + ) + parser.add_option( + '-m', + dest='min_variants', + type=int, + default=32, + help='Minimum number of variants for tissue to be included', + ) + parser.add_option( + '-p', + dest='plot', + default=False, + action='store_true', + help='Generate tissue prediction plots', + ) + parser.add_option( + '-s', + dest='snp_stat', + default='logSAD', + help='SNP statistic. 
[Default: %(default)s]', + ) + parser.add_option( + '-v', + dest='verbose', + default=False, + action='store_true', + ) + + (options, args) = parser.parse_args() + + if len(args) != 1: + parser.error('Must provide gtex output directory') + else: + gtex_dir = args[0] + + os.makedirs(options.out_dir, exist_ok=True) + + tissue_keywords = { + 'Adipose_Subcutaneous': 'adipose', + 'Adipose_Visceral_Omentum': 'adipose', + 'Adrenal_Gland': 'adrenal_gland', + 'Artery_Aorta': 'blood_vessel', + 'Artery_Coronary': 'blood_vessel', + 'Artery_Tibial': 'blood_vessel', + 'Brain_Amygdala' : 'brain', + 'Brain_Anterior_cingulate_cortex_BA24' : 'brain', + 'Brain_Caudate_basal_ganglia' : 'brain', + 'Brain_Cerebellar_Hemisphere' : 'brain', + 'Brain_Cerebellum': 'brain', + 'Brain_Cortex': 'brain', + 'Brain_Frontal_Cortex_BA9' : 'brain', + 'Brain_Hippocampus' : 'brain', + 'Brain_Hypothalamus' : 'brain', + 'Brain_Nucleus_accumbens_basal_ganglia' : 'brain', + 'Brain_Putamen_basal_ganglia' : 'brain', + 'Brain_Spinal_cord_cervical_c-1' : 'brain', + 'Brain_Substantia_nigra' : 'brain', + 'Breast_Mammary_Tissue': 'breast', + 'Cells_Cultured_fibroblasts' : 'skin', + 'Cells_EBV-transformed_lymphocytes' : 'blood', + 'Colon_Sigmoid': 'colon', + 'Colon_Transverse': 'colon', + 'Esophagus_Gastroesophageal_Junction' : 'esophagus', + 'Esophagus_Mucosa': 'esophagus', + 'Esophagus_Muscularis': 'esophagus', + 'Heart_Atrial_Appendage' : 'heart', + 'Heart_Left_Ventricle' : 'heart', + 'Kidney_Cortex' : 'kidney', + 'Liver': 'liver', + 'Lung': 'lung', + 'Minor_Salivary_Gland' : 'salivary_gland', + 'Muscle_Skeletal': 'muscle', + 'Nerve_Tibial': 'nerve', + 'Ovary': 'ovary', + 'Pancreas': 'pancreas', + 'Pituitary': 'pituitary', + 'Prostate': 'prostate', + 'Skin_Not_Sun_Exposed_Suprapubic': 'skin', + 'Skin_Sun_Exposed_Lower_leg' : 'skin', + 'Small_Intestine_Terminal_Ileum' : 'small_intestine', + 'Spleen': 'spleen', + 'Stomach': 'stomach', + 'Testis': 'testis', + 'Thyroid': 'thyroid', + 'Uterus' : 'uterus', + 'Vagina' : 'vagina', + 'Whole_Blood': 'blood', + } + + metrics_tissue = [] + metrics_sauroc = [] + metrics_cauroc = [] + metrics_rs = [] + metrics_rp = [] + metrics_n = [] + for tissue, keyword in tissue_keywords.items(): + if options.verbose: print(tissue) + + # read causal variants + eqtl_df = read_eqtl(tissue, options.gtex_vcf_dir) + + if eqtl_df is not None and eqtl_df.shape[0] > options.min_variants: + # read model predictions + gtex_scores_file = f'{gtex_dir}/{tissue}_pos/sed.h5' + eqtl_df = add_scores(gtex_scores_file, keyword, eqtl_df, + options.snp_stat, verbose=options.verbose) + + eqtl_df = eqtl_df.loc[~eqtl_df['score'].isnull()].copy() + + # compute AUROCs + sign_auroc = roc_auc_score(eqtl_df.coef > 0, eqtl_df.score) + + # compute SpearmanR + coef_r = spearmanr(eqtl_df.coef, eqtl_df.score)[0] + + # compute PearsonR + coef_rp = pearsonr(eqtl_df.coef, eqtl_df.score)[0] + + coef_n = len(eqtl_df) + + # classification AUROC + class_auroc = classify_auroc(gtex_scores_file, keyword, eqtl_df, + options.snp_stat) + + if options.plot: + eqtl_df.to_csv(f'{options.out_dir}/{tissue}.tsv', + index=False, sep='\t') + + # scatterplot + plt.figure(figsize=(6,6)) + sns.scatterplot(x=eqtl_df.coef, y=eqtl_df.score, + alpha=0.5, s=20) + plt.gca().set_xlabel('eQTL coefficient') + plt.gca().set_ylabel('Variant effect prediction') + plt.savefig(f'{options.out_dir}/{tissue}.png', dpi=300) + + # save + metrics_tissue.append(tissue) + metrics_sauroc.append(sign_auroc) + metrics_cauroc.append(class_auroc) + metrics_rs.append(coef_r) + 
metrics_rp.append(coef_rp) + metrics_n.append(coef_n) + + if options.verbose: print('') + + # save metrics + metrics_df = pd.DataFrame({ + 'tissue': metrics_tissue, + 'auroc_sign': metrics_sauroc, + 'spearmanr': metrics_rs, + 'pearsonr': metrics_rp, + 'n': metrics_n, + 'auroc_class': metrics_cauroc + }) + metrics_df.to_csv(f'{options.out_dir}/metrics.tsv', + sep='\t', index=False, float_format='%.4f') + + # summarize + print('Sign AUROC: %.4f' % np.mean(metrics_df.auroc_sign)) + print('SpearmanR: %.4f' % np.mean(metrics_df.spearmanr)) + print('Class AUROC: %.4f' % np.mean(metrics_df.auroc_class)) + + +def read_eqtl(tissue: str, gtex_vcf_dir: str, pip_t: float=0.9): + """Reads eQTLs from SUSIE output. + + Args: + tissue (str): Tissue name. + gtex_vcf_dir (str): GTEx VCF directory. + pip_t (float): PIP threshold. + + Returns: + eqtl_df (pd.DataFrame): eQTL dataframe, or None if tissue skipped. + """ + susie_dir = '/home/drk/seqnn/data/gtex_fine/tissues_susie' + + # read causal variants + eqtl_file = f'{susie_dir}/{tissue}.tsv' + df_eqtl = pd.read_csv(eqtl_file, sep='\t', index_col=0) + + # pip filter + pip_match = re.search(r"_pip(\d+).?$", gtex_vcf_dir).group(1) + pip_t = float(pip_match) / 100 + assert(pip_t > 0 and pip_t <= 1) + df_causal = df_eqtl[df_eqtl.pip > pip_t] + + # make table + tissue_vcf_file = f'{gtex_vcf_dir}/{tissue}_pos.vcf' + if not os.path.isfile(tissue_vcf_file): + eqtl_df = None + else: + # create dataframe + eqtl_df = pd.DataFrame({ + 'variant': df_causal.variant, + 'gene': [trim_dot(gene_id) for gene_id in df_causal.gene], + 'coef': df_causal.beta_posterior, + 'allele1': df_causal.allele1 + }) + return eqtl_df + + +def add_scores(gtex_scores_file: str, + keyword: str, + eqtl_df: pd.DataFrame, + score_key: str='SED', + verbose: bool=False): + """Read eQTL RNA predictions for the given tissue. + + Args: + gtex_scores_file (str): Variant scores HDF5. + tissue_keyword (str): tissue keyword, for matching GTEx targets + eqtl_df (pd.DataFrame): eQTL dataframe + score_key (str): score key in HDF5 file + verbose (bool): Print matching targets. 
+ + Returns: + eqtl_df (pd.DataFrame): eQTL dataframe, with added scores + """ + with h5py.File(gtex_scores_file, 'r') as gtex_scores_h5: + # read data + snp_i = gtex_scores_h5['si'][:] + snps = np.array([snp.decode('UTF-8') for snp in gtex_scores_h5['snp']]) + ref_allele = np.array([ref.decode('UTF-8') for ref in gtex_scores_h5['ref_allele']]) + genes = np.array([snp.decode('UTF-8') for snp in gtex_scores_h5['gene']]) + target_ids = np.array([ref.decode('UTF-8') for ref in gtex_scores_h5['target_ids']]) + target_labels = np.array([ref.decode('UTF-8') for ref in gtex_scores_h5['target_labels']]) + + # determine matching GTEx targets + match_tis = [] + for ti in range(len(target_ids)): + if target_ids[ti].find('GTEX') != -1 and target_labels[ti].find(keyword) != -1: + if not keyword == 'blood' or target_labels[ti].find('vessel') == -1: + if verbose: + print(ti, target_ids[ti], target_labels[ti]) + match_tis.append(ti) + match_tis = np.array(match_tis) + + # read scores and take mean across targets + score_table = gtex_scores_h5[score_key][...,match_tis].mean(axis=-1, dtype='float32') + score_table = np.arcsinh(score_table) + + # hash scores to (snp,gene) + snpgene_scores = {} + for sgi in range(score_table.shape[0]): + snp = snps[snp_i[sgi]] + gene = trim_dot(genes[sgi]) + snpgene_scores[(snp,gene)] = score_table[sgi] + + # add scores to eQTL table + # flipping when allele1 doesn't match + snp_ref = dict(zip(snps, ref_allele)) + eqtl_df_scores = [] + for sgi, eqtl in eqtl_df.iterrows(): + sgs = snpgene_scores.get((eqtl.variant,eqtl.gene), 0) + if not np.isnan(sgs) and sgs != 0 and snp_ref[eqtl.variant] != eqtl.allele1: + sgs *= -1 + eqtl_df_scores.append(sgs) + eqtl_df['score'] = eqtl_df_scores + + return eqtl_df + + +def classify_auroc(gtex_scores_file: str, + keyword: str, + eqtl_df: pd.DataFrame, + score_key: str='SED', + agg_mode: str='max'): + """Read eQTL RNA predictions for negatives from the given tissue. + + Args: + gtex_scores_file (str): Variant scores HDF5. + tissue_keyword (str): tissue keyword, for matching GTEx targets + eqtl_df (pd.DataFrame): eQTL dataframe + score_key (str): score key in HDF5 file + verbose (bool): Print matching targets. + + Returns: + class_auroc (float): Classification AUROC. 
+ """ + + # read positive scores + with h5py.File(gtex_scores_file, 'r') as gtex_scores_h5: + # read data + snp_i = gtex_scores_h5['si'][:] + snps = np.array([snp.decode('UTF-8') for snp in gtex_scores_h5['snp']]) + genes = np.array([snp.decode('UTF-8') for snp in gtex_scores_h5['gene']]) + target_ids = np.array([ref.decode('UTF-8') for ref in gtex_scores_h5['target_ids']]) + target_labels = np.array([ref.decode('UTF-8') for ref in gtex_scores_h5['target_labels']]) + + # determine matching GTEx targets + match_tis = [] + for ti in range(len(target_ids)): + if target_ids[ti].find('GTEX') != -1 and target_labels[ti].find(keyword) != -1: + if not keyword == 'blood' or target_labels[ti].find('vessel') == -1: + match_tis.append(ti) + match_tis = np.array(match_tis) + + # read scores and take mean across targets + score_table = gtex_scores_h5[score_key][...,match_tis].mean(axis=-1, dtype='float32') + score_table = np.arcsinh(score_table) + + # aggregate across genes (sum abs or max abs); positives + psnp_scores = {} + for sgi in range(score_table.shape[0]): + snp = snps[snp_i[sgi]] + if agg_mode == 'sum' : + psnp_scores[snp] = psnp_scores.get(snp,0) + np.abs(score_table[sgi]) + elif agg_mode == 'max' : + psnp_scores[snp] = max(psnp_scores.get(snp,0), np.abs(score_table[sgi])) + + # read negative scores + gtex_nscores_file = gtex_scores_file.replace('_pos','_neg') + with h5py.File(gtex_nscores_file, 'r') as gtex_scores_h5: + # read data + snp_i = gtex_scores_h5['si'][:] + snps = np.array([snp.decode('UTF-8') for snp in gtex_scores_h5['snp']]) + genes = np.array([snp.decode('UTF-8') for snp in gtex_scores_h5['gene']]) + target_ids = np.array([ref.decode('UTF-8') for ref in gtex_scores_h5['target_ids']]) + target_labels = np.array([ref.decode('UTF-8') for ref in gtex_scores_h5['target_labels']]) + + # determine matching GTEx targets + match_tis = [] + for ti in range(len(target_ids)): + if target_ids[ti].find('GTEX') != -1 and target_labels[ti].find(keyword) != -1: + if not keyword == 'blood' or target_labels[ti].find('vessel') == -1: + match_tis.append(ti) + match_tis = np.array(match_tis) + + # read scores and take mean across targets + score_table = gtex_scores_h5[score_key][...,match_tis].mean(axis=-1, dtype='float32') + score_table = np.arcsinh(score_table) + + # aggregate across genes (sum abs or max abs); negatives + nsnp_scores = {} + for sgi in range(score_table.shape[0]): + snp = snps[snp_i[sgi]] + if agg_mode == 'sum' : + nsnp_scores[snp] = nsnp_scores.get(snp,0) + np.abs(score_table[sgi]) + elif agg_mode == 'max' : + nsnp_scores[snp] = max(nsnp_scores.get(snp,0), np.abs(score_table[sgi])) + + # compute AUROC + Xp = list(psnp_scores.values()) + Xn = list(nsnp_scores.values()) + X = Xp + Xn + y = [1]*len(Xp) + [0]*len(Xn) + + return roc_auc_score(y, X) + + +def trim_dot(gene_id): + """Trim dot off GENCODE id's.""" + dot_i = gene_id.rfind('.') + if dot_i != -1: + gene_id = gene_id[:dot_i] + return gene_id + + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + main() diff --git a/src/scripts/borzoi_sed.py b/src/scripts/borzoi_sed.py index 3471a11..dc0a736 100755 --- a/src/scripts/borzoi_sed.py +++ b/src/scripts/borzoi_sed.py @@ -111,7 +111,18 @@ def main(): type="str", help="File specifying target indexes and labels in table format", ) - parser.add_option("-u", dest="untransform_old", default=False, action="store_true") + 
parser.add_option(
+        "-u",
+        dest="untransform_old",
+        default=False,
+        action="store_true",
+        help="Undo scale, clip_soft and sqrt transforms (old) [Default: %default]",
+    )
+    parser.add_option(
+        "--no_untransform",
+        dest="no_untransform",
+        default=False,
+        action="store_true",
+        help="Do not untransform predictions [Default: %default]",
+    )
 
     (options, args) = parser.parse_args()
 
     if len(args) == 3:
@@ -271,12 +282,13 @@ def main():
 
         # untransform predictions
         if options.targets_file is not None:
-            if options.untransform_old:
-                ref_preds = dataset.untransform_preds1(ref_preds, targets_df)
-                alt_preds = dataset.untransform_preds1(alt_preds, targets_df)
-            else:
-                ref_preds = dataset.untransform_preds(ref_preds, targets_df)
-                alt_preds = dataset.untransform_preds(alt_preds, targets_df)
+            if not options.no_untransform:
+                if options.untransform_old:
+                    ref_preds = dataset.untransform_preds1(ref_preds, targets_df)
+                    alt_preds = dataset.untransform_preds1(alt_preds, targets_df)
+                else:
+                    ref_preds = dataset.untransform_preds(ref_preds, targets_df)
+                    alt_preds = dataset.untransform_preds(alt_preds, targets_df)
 
         if options.bedgraph:
             write_bedgraph_snp(
@@ -599,29 +611,41 @@ def write_snp(ref_preds, alt_preds, sed_out, xi: int, sed_stats, pseudocounts):
     # ref/alt_preds is L x T
     seq_len, num_targets = ref_preds.shape
 
-    # sum across bins
+    # log2 transform
+    ref_preds_log = np.log2(ref_preds+1)
+    alt_preds_log = np.log2(alt_preds+1)
+
+    # sum across length
     ref_preds_sum = ref_preds.sum(axis=0)
     alt_preds_sum = alt_preds.sum(axis=0)
 
     # difference of sums
-    if "SED" in sed_stats:
+    if 'SED' in sed_stats:
         sed = alt_preds_sum - ref_preds_sum
-        sed_out["SED"][xi] = clip_float(sed).astype("float16")
-    if "logSED" in sed_stats:
+        sed_out['SED'][xi] = clip_float(sed).astype('float16')
+    if 'logSED' in sed_stats:
         log_sed = np.log2(alt_preds_sum + 1) - np.log2(ref_preds_sum + 1)
-        sed_out["logSED"][xi] = log_sed.astype("float16")
+        sed_out['logSED'][xi] = log_sed.astype('float16')
 
     # difference L1 norm
     if "D1" in sed_stats:
         diff_abs = np.abs(ref_preds - alt_preds)
         diff_norm1 = diff_abs.sum(axis=0)
         sed_out["D1"][xi] = clip_float(diff_norm1).astype("float16")
+    if 'logD1' in sed_stats:
+        diff1_log = np.abs(ref_preds_log - alt_preds_log)
+        diff_log_norm1 = diff1_log.sum(axis=0)
+        sed_out['logD1'][xi] = clip_float(diff_log_norm1).astype('float16')
 
     # difference L2 norm
-    if "D2" in sed_stats:
+    if 'D2' in sed_stats:
         diff2 = np.power(ref_preds - alt_preds, 2)
         diff_norm2 = np.sqrt(diff2.sum(axis=0))
-        sed_out["D2"][xi] = clip_float(diff_norm2).astype("float16")
+        sed_out['D2'][xi] = clip_float(diff_norm2).astype('float16')
+    if 'logD2' in sed_stats:
+        diff2_log = np.power(ref_preds_log - alt_preds_log, 2)
+        diff_log_norm2 = np.sqrt(diff2_log.sum(axis=0))
+        sed_out['logD2'][xi] = clip_float(diff_log_norm2).astype('float16')
 
     # normalized scores
     ref_preds_norm = ref_preds + pseudocounts
diff --git a/src/scripts/borzoi_sed_folds.py b/src/scripts/borzoi_sed_folds.py
new file mode 100644
index 0000000..eb0176b
--- /dev/null
+++ b/src/scripts/borzoi_sed_folds.py
@@ -0,0 +1,286 @@
+#!/usr/bin/env python
+# Copyright 2019 Calico LLC
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# https://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from optparse import OptionParser, OptionGroup +import glob +import h5py +import json +import pdb +import os +import sys + +import numpy as np +import pandas as pd + +import slurm + +""" +borzoi_sed_folds.py + +Compute SED scores across model folds. +""" + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + + # sed + sed_options = OptionGroup(parser, 'borzoi_sed_folds.py options') + sed_options.add_option( + '-f', + dest='genome_fasta', + default='%s/data/hg38.fa' % os.environ['BASENJIDIR'], + help='Genome FASTA for sequences [Default: %default]', + ) + sed_options.add_option( + '-g', + dest='genes_gtf', + default='%s/genes/gencode41/gencode41_basic_nort.gtf' % os.environ['HG38'], + help='GTF for gene definition [Default %default]', + ) + sed_options.add_option( + '-o', + dest='out_dir', + default='sed', + help='Output directory for tables and plots [Default: %default]', + ) + sed_options.add_option( + '-p', + dest='processes', + default=None, + type='int', + help='Number of processes, passed by multi script', + ) + sed_options.add_option( + '--rc', + dest='rc', + default=False, + action='store_true', + help='Average forward and reverse complement predictions [Default: %default]', + ) + sed_options.add_option( + '--shifts', + dest='shifts', + default='0', + type='str', + help='Ensemble prediction shifts [Default: %default]', + ) + sed_options.add_option( + '--span', + dest='span', + default=False, + action='store_true', + help='Aggregate entire gene span [Default: %default]', + ) + sed_options.add_option( + '-u', + dest='untransform_old', + default=False, + action='store_true', + help='Undo scale, clip_soft and sqrt transforms (old) [Default: %default]', + ) + sed_options.add_option( + '--no_untransform', + dest='no_untransform', + default=False, + action='store_true', + ) + sed_options.add_option( + '--stats', + dest='sed_stats', + default='D2', + help='Comma-separated list of stats to save. 
[Default: %default]',
+    )
+    sed_options.add_option(
+        '-t',
+        dest='targets_file',
+        default=None,
+        type='str',
+        help='File specifying target indexes and labels in table format',
+    )
+    parser.add_option_group(sed_options)
+
+    # cross-fold
+    fold_options = OptionGroup(parser, 'cross-fold options')
+    fold_options.add_option(
+        '-c',
+        dest='crosses',
+        default=1,
+        type='int',
+        help='Number of cross-fold rounds [Default: %default]',
+    )
+    fold_options.add_option(
+        '-d',
+        dest='data_head',
+        default=None,
+        type='int',
+        help='Index for dataset/head [Default: %default]',
+    )
+    fold_options.add_option(
+        '-e',
+        dest='conda_env',
+        default='tf210',
+        help='Anaconda environment [Default: %default]',
+    )
+    fold_options.add_option(
+        '--name',
+        dest='name',
+        default='sed',
+        help='SLURM name prefix [Default: %default]',
+    )
+    fold_options.add_option(
+        '--max_proc',
+        dest='max_proc',
+        default=None,
+        type='int',
+        help='Maximum concurrent processes [Default: %default]',
+    )
+    fold_options.add_option(
+        '-q',
+        dest='queue',
+        default='geforce',
+        help='SLURM queue on which to run the jobs [Default: %default]',
+    )
+    fold_options.add_option(
+        '-r',
+        dest='restart',
+        default=False,
+        action='store_true',
+        help='Restart a partially completed job [Default: %default]',
+    )
+    fold_options.add_option(
+        '--vcf',
+        dest='vcf_file',
+        default='/home/drk/seqnn/data/gtex_fine/susie_pip90/pos_merge.vcf',
+        help='VCF of variants to score [Default: %default]',
+    )
+    parser.add_option_group(fold_options)
+
+    (options, args) = parser.parse_args()
+
+    if len(args) != 2:
+        parser.error('Must provide parameters file and cross-fold directory')
+    else:
+        params_file = args[0]
+        exp_dir = args[1]
+
+    #######################################################
+    # prep work
+
+    # count folds
+    num_folds = 0
+    fold0_dir = '%s/f%dc0' % (exp_dir, num_folds)
+    model_file = '%s/train/model_best.h5' % fold0_dir
+    if options.data_head is not None:
+        model_file = '%s/train/model%d_best.h5' % (fold0_dir, options.data_head)
+    while os.path.isfile(model_file):
+        num_folds += 1
+        fold0_dir = '%s/f%dc0' % (exp_dir, num_folds)
+        model_file = '%s/train/model_best.h5' % fold0_dir
+        if options.data_head is not None:
+            model_file = '%s/train/model%d_best.h5' % (fold0_dir, options.data_head)
+    print('Found %d folds' % num_folds)
+    if num_folds == 0:
+        print('No model folds found in %s' % exp_dir, file=sys.stderr)
+        exit(1)
+
+    ################################################################
+    # SNP scores
+
+    # command base
+    cmd_base = '. /home/drk/anaconda3/etc/profile.d/conda.sh;'
+    cmd_base += ' conda activate %s;' % options.conda_env
+    cmd_base += ' echo $HOSTNAME;'
+
+    jobs = []
+
+    for ci in range(options.crosses):
+        for fi in range(num_folds):
+            it_dir = '%s/f%dc%d' % (exp_dir, fi, ci)
+            name = '%s-f%dc%d' % (options.name, fi, ci)
+
+            # update output directory
+            it_out_dir = '%s/%s' % (it_dir, options.out_dir)
+            os.makedirs(it_out_dir, exist_ok=True)
+
+            model_file = '%s/train/model_best.h5' % it_dir
+            if options.data_head is not None:
+                model_file = '%s/train/model%d_best.h5' % (it_dir, options.data_head)
+
+            cmd_fold = '%s time borzoi_sed.py %s %s' % (cmd_base, params_file, model_file)
+
+            # variant scoring job
+            job_out_dir = it_out_dir
+            if not options.restart or not os.path.isfile('%s/sed.h5'%job_out_dir):
+                cmd_job = '%s %s' % (cmd_fold, options.vcf_file)
+                cmd_job += ' %s' % options_string(options, sed_options, job_out_dir)
+                j = slurm.Job(cmd_job, '%s' % name,
+                    '%s.out'%job_out_dir, '%s.err'%job_out_dir, '%s.sb'%job_out_dir,
+                    queue=options.queue, gpu=1,
+                    mem=60000, time='30-0:0:0')
+                jobs.append(j)
+
+    slurm.multi_run(jobs, max_proc=options.max_proc, verbose=True,
+                    launch_sleep=10, update_sleep=60)
+
+def options_string(options, group_options, rep_dir):
+    options_str = ''
+
+    for opt in group_options.option_list:
+        opt_str = opt.get_opt_string()
+        opt_value = options.__dict__[opt.dest]
+
+        # wrap asterisks in ""
+        if type(opt_value) == str and opt_value.find('*') != -1:
+            opt_value = '"%s"' % opt_value
+
+        # no value for bools
+        elif type(opt_value) == bool:
+            if not opt_value:
+                opt_str = ''
+            opt_value = ''
+
+        # skip Nones
+        elif opt_value is None:
+            opt_str = ''
+            opt_value = ''
+
+        # modify
+        elif opt.dest == 'out_dir':
+            opt_value = rep_dir
+
+        options_str += ' %s %s' % (opt_str, opt_value)
+
+    return options_str
+
+################################################################################
+# __main__
+################################################################################
+if __name__ == '__main__':
+    main()
diff --git a/src/scripts/borzoi_test_genes.py b/src/scripts/borzoi_test_genes.py
index 43f51dc..e36fa54 100755
--- a/src/scripts/borzoi_test_genes.py
+++ b/src/scripts/borzoi_test_genes.py
@@ -90,6 +90,20 @@ def main():
         default="test",
         help="Dataset split label for eg TFR pattern [Default: %default]",
     )
+    parser.add_option(
+        '--no_unclip',
+        dest='no_unclip',
+        default=False,
+        action='store_true',
+        help='Turn off unclip transform [Default: %default]',
+    )
+    parser.add_option(
+        '--pseudo_qtl',
+        dest='pseudo_qtl',
+        default=None,
+        type='float',
+        help='Quantile of coverage to add as pseudo counts to genes [Default: %default]',
+    )
     parser.add_option(
         "--tfr",
         dest="tfr_pattern",
@@ -293,11 +307,11 @@ def main():
 
         # untransform
         if options.untransform_old:
-            gene_preds_gi = dataset.untransform_preds1(gene_preds_gi, targets_strand_df)
-            gene_targets_gi = dataset.untransform_preds1(gene_targets_gi, targets_strand_df)
+            gene_preds_gi = dataset.untransform_preds1(gene_preds_gi, targets_strand_df, unclip=not options.no_unclip)
+            gene_targets_gi = dataset.untransform_preds1(gene_targets_gi, targets_strand_df, unclip=not options.no_unclip)
         else:
-            gene_preds_gi = dataset.untransform_preds(gene_preds_gi, targets_strand_df)
-            gene_targets_gi = dataset.untransform_preds(gene_targets_gi, targets_strand_df)
+            gene_preds_gi = dataset.untransform_preds(gene_preds_gi, targets_strand_df, unclip=not options.no_unclip)
+            gene_targets_gi = dataset.untransform_preds(gene_targets_gi, targets_strand_df, unclip=not options.no_unclip)
 
         # compute within gene correlation before dropping length axis
         gene_corr_gi = np.zeros(num_targets_strand)
@@ -321,8 +335,8 @@ def main():
             #     np.save('%s/gene_within/%s_targets.npy' % (options.out_dir, gene_id), gene_targets_gi.astype('float16'))
 
             # mean coverage
-            gene_preds_gi = gene_preds_gi.mean(axis=0)
-            gene_targets_gi = gene_targets_gi.mean(axis=0)
+            gene_preds_gi = gene_preds_gi.mean(axis=0) / float(pool_width)
+            gene_targets_gi = gene_targets_gi.mean(axis=0) / float(pool_width)
 
             # scale by gene length
             gene_preds_gi *= gene_lengths[gene_id]
@@ -336,6 +350,17 @@ def main():
     gene_within = np.array(gene_within)
     gene_wvar = np.array(gene_wvar)
 
+    # add pseudo coverage
+    if options.pseudo_qtl is not None:
+        for ti in range(num_targets_strand):
+            nonzero_index = np.nonzero(gene_targets[:, ti] != 0.)[0]
+
+            pseudo_t = np.quantile(gene_targets[:, ti][nonzero_index], q=options.pseudo_qtl)
+            pseudo_p = np.quantile(gene_preds[:, ti][nonzero_index], q=options.pseudo_qtl)
+
+            gene_targets[:, ti] += pseudo_t
+            gene_preds[:, ti] += pseudo_p
+
     # log2 transform
     gene_targets = np.log2(gene_targets + 1)
     gene_preds = np.log2(gene_preds + 1)
diff --git a/src/scripts/borzoi_test_genes_folds.py b/src/scripts/borzoi_test_genes_folds.py
index f889fda..84a334c 100755
--- a/src/scripts/borzoi_test_genes_folds.py
+++ b/src/scripts/borzoi_test_genes_folds.py
@@ -69,6 +69,19 @@ def main():
         type="str",
         help="File specifying target indexes and labels in table format",
     )
+    parser.add_option(
+        '--no_unclip',
+        dest='no_unclip',
+        default=False,
+        action='store_true',
+        help='Turn off unclip transform [Default: %default]',
+    )
+    parser.add_option(
+        '--pseudo_qtl',
+        dest='pseudo_qtl',
+        default=None, type='float',
+        help='Quantile of coverage to add as pseudo counts to genes [Default: %default]',
+    )
     parser.add_option(
         "-u",
         dest="untransform_old",
@@ -245,6 +258,12 @@ def main():
             cmd += " --rc"
         if options.shifts:
             cmd += " --shifts %s" % options.shifts
+        if options.no_unclip:
+            cmd += ' --no_unclip'
+        if options.pseudo_qtl is not None:
+            cmd += ' --pseudo_qtl %.2f' % options.pseudo_qtl
+        if options.untransform_old:
+            cmd += ' -u'
         if options.span:
             cmd += " --span"
         job_mem = 240000
diff --git a/src/scripts/borzoi_tfmodisco.py b/src/scripts/borzoi_tfmodisco.py
new file mode 100644
index 0000000..4caa384
--- /dev/null
+++ b/src/scripts/borzoi_tfmodisco.py
@@ -0,0 +1,492 @@
+#!/usr/bin/env python
+from optparse import OptionParser
+from collections import Counter
+import os
+import pdb
+import subprocess
+import sys
+import time
+
+import h5py
+import numpy as np
+import pandas as pd
+import pybedtools
+import pyranges
+from scipy.ndimage import gaussian_filter1d
+from tqdm import tqdm
+
+from matplotlib import pyplot as plt
+import seaborn as sns
+
+from baskerville import dna_io
+import pygene
+import modisco
+
+'''
+borzoi_tfmodisco.py
+
+Run TF Modisco on borzoi input saliency scores.
+''' + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ' + parser = OptionParser(usage) + parser.add_option( + '-c', + dest='center_bp', + default=None, + type='int', + help='Extract only center bp [Default: %default]', + ) + parser.add_option( + '-d', + dest='meme_db', + default='/homde/drk/code/meme-5.4.1/motif_databases/CIS-BP_2.00/Homo_sapiens.meme', + help='Meme database [Default: %default]', + ) + parser.add_option( + '-g', + dest='genes_gtf_file', + default='/home/drk/common/data/genomes/hg38/genes/gencode38/gencode38_basic_protein.gtf', + help='Gencode GTF [Default: %default]', + ) + parser.add_option( + '--gc', + dest='gc_content', + default=0.41, + type='float', + help='Genome GC content [Default: %default]', + ) + parser.add_option( + '--fwd', + dest='force_fwd', + default=0, + type='int', + help='Do not use rev-comp in modisco [Default: %default]', + ) + parser.add_option( + '--modisco_window_size', + dest='modisco_window_size', + default=24, + type='int', + help='Modisco window size [Default: %default]', + ) + parser.add_option( + '--modisco_flank', + dest='modisco_flank', + default=8, + type='int', + help='Modisco flanks to add [Default: %default]', + ) + parser.add_option( + '--modisco_sliding_window_size', + dest='modisco_sliding_window_size', + default=18, + type='int', + help='Modisco sliding window size [Default: %default]', + ) + parser.add_option( + '--modisco_sliding_window_flank', + dest='modisco_sliding_window_flank', + default=8, + type='int', + help='Modisco sliding window flanks [Default: %default]', + ) + parser.add_option( + '--modisco_max_seqlets', + dest='modisco_max_seqlets', + default=20000, + type='int', + help='Modisco sliding window flanks [Default: %default]', + ) + parser.add_option( + '-i', + dest='ic_t', + default=0.1, + type='float', + help='Information content threshold [Default: %default]', + ) + parser.add_option( + '-n', + dest='norm_type', + default='max', + ) + parser.add_option( + '-o', + dest='out_dir', + default='tfm_out', + help='Output directory [Default: %default]', + ) + parser.add_option( + '-r', + dest='region', + default=None, + help='Limit to specific gene region [Default: %default', + ) + parser.add_option( + '-t', + dest='targets_file', + default=None, + type='str', + help='File specifying target indexes and labels in table format', + ) + parser.add_option( + '--kmer_len', + dest='kmer_len', + default=None, + type='int', + help='Extract only center bp [Default: %default]', + ) + parser.add_option( + '--num_gaps', + dest='num_gaps', + default=None, + type='int', + help='Extract only center bp [Default: %default]', + ) + parser.add_option( + '--num_mismatches', + dest='num_mismatches', + default=None, + type='int', + help='Extract only center bp [Default: %default]', + ) + parser.add_option( + '--clip_perc', + dest='clip_perc', + default=25, + type='int', + help='Percentile of max deviations to clip by [Default: %default]', + ) + parser.add_option( + '--tissue', + dest='tissue', + default=None, + type='str', + help='Main tissue name.', + ) + parser.add_option( + '--gene_file', + dest='gene_file', + default=None, + type='str', + help='Csv-file of gene metadata.', + ) + + (options,args) = parser.parse_args() + + if len(args) != 1: + parser.error('Must provide Basenji nucleotide scores.') + else: + scores_h5_file = args[0] + + # setup output dir + 
+    os.makedirs(options.out_dir, exist_ok=True)
+
+    # load gene dataframe and select tissue
+    tissue_genes = None
+    if options.gene_file is not None and options.tissue is not None:
+        gene_df = pd.read_csv(options.gene_file, sep='\t')
+        gene_df = gene_df.query("tissue == '" + str(options.tissue) + "'").copy().reset_index(drop=True)
+        gene_df = gene_df.drop(columns=['Unnamed: 0'])
+
+        print("len(gene_df) = " + str(len(gene_df)))
+
+        # get list of genes for tissue
+        tissue_genes = gene_df['gene_base'].values.tolist()
+
+        print("len(tissue_genes) = " + str(len(tissue_genes)))
+
+    # read nucleotide scores
+    t0 = time.time()
+    print('Reading scores...', flush=True, end='')
+    with h5py.File(scores_h5_file, 'r') as scores_h5:
+        seq_len = scores_h5['grads'].shape[1]
+        # default to the full sequence when -c is not given
+        center_bp = seq_len if options.center_bp is None else options.center_bp
+        pos_start = seq_len//2 - center_bp//2
+        pos_end = pos_start + center_bp
+        hyp_scores = scores_h5['grads'][:,pos_start:pos_end]
+        seqs_1hot = scores_h5['seqs'][:,pos_start:pos_end]
+        seq_chrs = [chrm.decode('UTF-8') for chrm in scores_h5['chr']]
+        seq_genes = [gene.decode('UTF-8') for gene in scores_h5['gene']]
+        seq_starts = scores_h5['start'][:] + pos_start
+        seq_ends = scores_h5['end'][:] - (seq_len - pos_end)
+
+    if tissue_genes is not None:
+        gene_dict = {gene.split(".")[0] : gene_i for gene_i, gene in enumerate(seq_genes)}
+
+        # get index of rows to keep
+        keep_index = []
+        for tissue_gene in tissue_genes:
+            keep_index.append(gene_dict[tissue_gene])
+
+        # filter/sub-select data
+        hyp_scores = hyp_scores[keep_index, ...]
+        seqs_1hot = seqs_1hot[keep_index, ...]
+        seq_chrs = [seq_chrs[k_ix] for k_ix in keep_index]
+        seq_genes = [seq_genes[k_ix] for k_ix in keep_index]
+        seq_starts = seq_starts[keep_index, ...]
+        seq_ends = seq_ends[keep_index, ...]
+
+        print("Filtered genes = " + str(hyp_scores.shape[0]))
+
+    num_seqs, seq_len, _ = seqs_1hot.shape
+    print('DONE in %ds.' % (time.time()-t0))
+
+    # average across targets
+    hyp_scores = hyp_scores.mean(axis=-1, dtype='float32')
+
+    # normalize scores by sequence
+    t0 = time.time()
+    print('Normalizing scores...', flush=True, end='')
+    if options.norm_type == 'max':
+        scores_max = hyp_scores.std(axis=-1).max(axis=-1)
+        max_clip = np.percentile(scores_max, options.clip_perc)
+        scores_max = np.clip(scores_max, max_clip, np.inf)
+        hyp_scores /= np.reshape(scores_max, (num_seqs,1,1))
+    elif options.norm_type == 'gaussian':
+        scores_std = hyp_scores.std(axis=-1)
+        scores_std_wide = gaussian_filter1d(scores_std, sigma=1280, truncate=2)
+        wide_clip = np.percentile(scores_std_wide, options.clip_perc)
+        scores_std_wide = np.clip(scores_std_wide, wide_clip, np.inf)
+        hyp_scores /= np.expand_dims(scores_std_wide, axis=-1)
+    else:
+        print('Unrecognized normalization %s' % options.norm_type)
+    print('DONE in %ds.' % (time.time()-t0))
+
+    ################################################
+    # region filter
+
+    if options.region is not None:
+        hyp_scores, seqs_1hot = filter_region(hyp_scores, seqs_1hot,
+            seq_genes, seq_starts, options.genes_gtf_file, options.region)
+
+    # save to visualize individual examples
+    with h5py.File('%s/scores.h5'%options.out_dir, 'w') as scores_h5:
+        scores_h5.create_dataset('scores', data=hyp_scores, compression='gzip')
+        scores_h5.create_dataset('seqs', data=seqs_1hot, compression='gzip')
+
+    ################################################
+    # tfmodisco
+
+    if isinstance(hyp_scores, list):
+        num_seqs = len(seqs_1hot)
+        contrib_scores = [np.multiply(hyp_scores[si], seqs_1hot[si]) for si in range(num_seqs)]
+    else:
+        num_seqs = seqs_1hot.shape[0]
+        contrib_scores = np.multiply(hyp_scores, seqs_1hot)
+
+    # make seqlets to patterns factory
+    if options.kmer_len is not None and options.num_gaps is not None and options.num_mismatches is not None:
+        tfm_seqlets = modisco.tfmodisco_workflow.seqlets_to_patterns.TfModiscoSeqletsToPatternsFactory(
+            trim_to_window_size=options.modisco_window_size,
+            initial_flank_to_add=options.modisco_flank,
+            kmer_len=options.kmer_len, num_gaps=options.num_gaps, num_mismatches=options.num_mismatches,
+            final_min_cluster_size=20)
+    else:
+        tfm_seqlets = modisco.tfmodisco_workflow.seqlets_to_patterns.TfModiscoSeqletsToPatternsFactory(
+            trim_to_window_size=options.modisco_window_size,
+            initial_flank_to_add=options.modisco_flank,
+            final_min_cluster_size=20)
+
+    # make modisco workflow
+    tfm_workflow = modisco.tfmodisco_workflow.workflow.TfModiscoWorkflow(
+        sliding_window_size=options.modisco_sliding_window_size,
+        flank_size=options.modisco_sliding_window_flank,
+        max_seqlets_per_metacluster=options.modisco_max_seqlets,
+        seqlets_to_patterns_factory=tfm_seqlets)
+
+    # run modisco workflow
+    task_label = 'out0'
+    tfm_results = tfm_workflow(
+        task_names=[task_label],
+        contrib_scores={task_label: contrib_scores},
+        hypothetical_contribs={task_label: hyp_scores},
+        revcomp=False if options.force_fwd == 1 else True,
+        one_hot=seqs_1hot)
+
+    # save results
+    tfm_h5_file = '%s/tfm.h5' % options.out_dir
+    with h5py.File(tfm_h5_file, 'w') as tfm_h5:
+        tfm_results.save_hdf5(tfm_h5)
+
+    ################################################
+    # extract motif PWMs
+
+    at_pct = (1-options.gc_content)/2
+    gc_pct = options.gc_content/2
+    background = np.array([at_pct, gc_pct, gc_pct, at_pct])
+
+    tfm_pwms = {}
+
+    with h5py.File(tfm_h5_file, 'r') as tfm_h5:
+        metacluster_names = [mcr.decode("utf-8") for mcr in list(tfm_h5["metaclustering_results"]["all_metacluster_names"][:])]
+        for metacluster_name in metacluster_names:
+            metacluster_grp = tfm_h5["metacluster_idx_to_submetacluster_results"][metacluster_name]
+            all_patterns = metacluster_grp["seqlets_to_patterns_result"]["patterns"]["all_pattern_names"][:]
+            all_pattern_names = [x.decode("utf-8") for x in list(all_patterns)]
+            for pattern_name in all_pattern_names:
+                pattern_id = (metacluster_name+'_'+pattern_name)
+                pattern = metacluster_grp["seqlets_to_patterns_result"]["patterns"][pattern_name]
+                fwd = np.array(pattern["sequence"]["fwd"])
+                clip_pwm = ic_clip(fwd, options.ic_t, background)
+                if clip_pwm is None:
+                    print('pattern_id: %s is skipped because no position passes the IC threshold.' % pattern_id)
+                else:
+                    tfm_pwms[pattern_id] = clip_pwm
+                    print('pattern_id: %s is converted to pwm.' 
+
+    ################################################
+    # extract motif PWMs
+
+    at_pct = (1-options.gc_content)/2
+    gc_pct = options.gc_content/2
+    background = np.array([at_pct, gc_pct, gc_pct, at_pct])
+
+    tfm_pwms = {}
+
+    with h5py.File(tfm_h5_file, 'r') as tfm_h5:
+        metacluster_names = [mcr.decode("utf-8") for mcr in list(tfm_h5["metaclustering_results"]["all_metacluster_names"][:])]
+        for metacluster_name in metacluster_names:
+            metacluster_grp = tfm_h5["metacluster_idx_to_submetacluster_results"][metacluster_name]
+            all_patterns = metacluster_grp["seqlets_to_patterns_result"]["patterns"]["all_pattern_names"][:]
+            all_pattern_names = [x.decode("utf-8") for x in list(all_patterns)]
+            for pattern_name in all_pattern_names:
+                pattern_id = (metacluster_name + '_' + pattern_name)
+                pattern = metacluster_grp["seqlets_to_patterns_result"]["patterns"][pattern_name]
+                fwd = np.array(pattern["sequence"]["fwd"])
+                clip_pwm = ic_clip(fwd, options.ic_t, background)
+                if clip_pwm is None:
+                    print('pattern_id: %s skipped; no position passes the IC threshold.' % pattern_id)
+                else:
+                    tfm_pwms[pattern_id] = clip_pwm
+                    print('pattern_id: %s converted to PWM.' % pattern_id)
+
+    ################################################
+    # tomtom
+
+    # initialize MEME
+    modisco_meme_file = options.out_dir + '/modisco_' + options.out_dir.replace("/", "_") + '.meme'
+    modisco_meme_open = open(modisco_meme_file, 'w')
+
+    # header
+    modisco_meme_open.write('MEME version 4\n\n')
+    modisco_meme_open.write('ALPHABET= ACGT\n\n')
+    modisco_meme_open.write('strands: + -\n\n')
+    modisco_meme_open.write('Background letter frequencies\n')
+    modisco_meme_open.write('A %f C %f G %f T %f\n\n' % tuple(background))
+
+    # PWMs
+    for key in tfm_pwms.keys():
+        modisco_meme_open.write('MOTIF ' + key + '\n')
+        modisco_meme_open.write('letter-probability matrix: alength= 4 w= ' + str(tfm_pwms[key].shape[0]) + '\n')
+        np.savetxt(modisco_meme_open, tfm_pwms[key])
+        modisco_meme_open.write('\n')
+
+    modisco_meme_open.close()
+
+    # run tomtom
+    tomtom_cmd = '/home/drk/bin/tomtom -dist pearson -thresh 0.1 -oc %s %s %s' % \
+        (options.out_dir, modisco_meme_file, options.meme_db)
+    subprocess.call(tomtom_cmd, shell=True)
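+
+# For reference, each motif entry in the .meme file written above uses the
+# minimal MEME motif format, e.g. (illustrative values only):
+#
+#   MOTIF metacluster_0_pattern_0
+#   letter-probability matrix: alength= 4 w= 3
+#   0.050000 0.050000 0.850000 0.050000
+#   0.900000 0.033333 0.033333 0.033333
+#   0.050000 0.050000 0.050000 0.850000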
+
+
+def filter_region(scores, seqs_1hot, seq_genes, seq_starts, genes_gtf_file, region, min_size=64, ss_window=192, utr_window=192):
+    """Filter scores and sequences for a specific gene region."""
+    num_seqs, seq_len, _ = seqs_1hot.shape
+
+    # parse GTF
+    genes_gtf = pygene.GTF(genes_gtf_file)
+
+    # collect regions
+    scores_region = []
+    seqs_1hot_region = []
+
+    # for each gene sequence
+    print('Extracting region %s...' % region)
+    for gi in tqdm(range(num_seqs)):
+        gene_id = seq_genes[gi]
+        gene = genes_gtf.genes[gene_id]
+
+        # collect regions
+        region_starts = []
+        region_ends = []
+        for _, tx in gene.transcripts.items():
+            tx.define_utrs()
+            if region in ['3utr', 'utr3']:
+                for utr in tx.utrs3:
+                    region_starts.append(utr.start)
+                    region_ends.append(utr.end)
+
+            elif region.find('ss') != -1:
+                # choose the exon side flanking the requested splice site
+                if region in ['ss5', '5ss']:
+                    if tx.strand == '+':
+                        exon_side = 'end'
+                    else:
+                        exon_side = 'start'
+                else:
+                    if tx.strand == '+':
+                        exon_side = 'start'
+                    else:
+                        exon_side = 'end'
+
+                if exon_side == 'start':
+                    for ei in range(1, len(tx.exons)):
+                        ss_start = tx.exons[ei].start - ss_window//2
+                        ss_end = ss_start + ss_window
+                        region_starts.append(ss_start)
+                        region_ends.append(ss_end)
+                else:
+                    for ei in range(len(tx.exons)-1):
+                        ss_start = tx.exons[ei].end - ss_window//2
+                        ss_end = ss_start + ss_window
+                        region_starts.append(ss_start)
+                        region_ends.append(ss_end)
+            else:
+                print('Unrecognized region %s' % region, file=sys.stderr)
+
+        num_regions = len(region_starts)
+        if num_regions > 0:
+            # merge overlapping regions
+            region_ranges = pyranges.PyRanges(chromosomes=[tx.chrom]*num_regions,
+                starts=region_starts, ends=region_ends)
+            region_ranges = region_ranges.merge()
+
+            # for each region
+            for _, rr in region_ranges.df.iterrows():
+                # collect scores
+                scores_start = max(0, rr.Start - seq_starts[gi])
+                scores_end = min(seq_len, rr.End - seq_starts[gi])
+
+                skip_region = False
+
+                # check splice site length match
+                if region.find('ss') != -1:
+                    if scores_end - scores_start != ss_window:
+                        skip_region = True
+                else:
+                    # sample variable length window around the score peak
+                    if scores_end - scores_start < utr_window:
+                        skip_region = True
+                    else:
+                        scores_std = scores[gi, scores_start:scores_end].std(axis=-1)
+                        scores_len = len(scores_std)
+                        scores_peak = np.argmax(scores_std)
+                        scores_peak = max(utr_window//2, scores_peak)
+                        scores_peak = min(scores_len - utr_window//2, scores_peak)
+                        scores_start += scores_peak - utr_window//2
+                        scores_end = scores_start + utr_window
+
+                if not skip_region:
+                    scores_region_ri = scores[gi, scores_start:scores_end]
+                    seqs_1hot_ri = seqs_1hot[gi, scores_start:scores_end]
+                    if gene.strand == '-':
+                        scores_region_ri = dna_io.hot1_rc(scores_region_ri)
+                        seqs_1hot_ri = dna_io.hot1_rc(seqs_1hot_ri)
+                    scores_region.append(scores_region_ri)
+                    seqs_1hot_region.append(seqs_1hot_ri)
+
+    scores_region = np.array(scores_region)
+    seqs_1hot_region = np.array(seqs_1hot_region)
+
+    return scores_region, seqs_1hot_region
+
+
+def ic_clip(pwm, threshold, background=[0.25]*4):
+    """Clip PWM sides with an information content threshold."""
+    pc = 0.001
+    background = np.asarray(background)  # allow the list default
+
+    # per-position information content relative to background
+    ic = (np.log((pwm+pc)/(1+4*pc)) / np.log(2))*pwm
+    ic -= (np.log(background)*background/np.log(2))[None, :]
+    ic_total = np.sum(ic, axis=1)[:, None]
+
+    # no position passes threshold
+    if not np.any(ic_total.flatten() > threshold):
+        return None
+    else:
+        left = np.where(ic_total > threshold)[0][0]
+        right = np.where(ic_total > threshold)[0][-1]
+        return pwm[left:(right+1)]
+
+################################################################################
+# __main__
+################################################################################
+if __name__ == '__main__':
+    main()
diff --git a/src/scripts/borzoi_tfmodisco_diff.py b/src/scripts/borzoi_tfmodisco_diff.py
new file mode 100644
index 0000000..a0119f9
--- /dev/null
+++ b/src/scripts/borzoi_tfmodisco_diff.py
@@ -0,0 +1,532 @@
+#!/usr/bin/env python
+from optparse import OptionParser
+from collections import Counter
+import os
+import pdb
+import subprocess
+import sys
+import time
+
+import h5py
+import numpy as np
+import pandas as pd
+import pybedtools
+import pyranges
+from scipy.ndimage import gaussian_filter1d
+from tqdm import tqdm
+
+from matplotlib import pyplot as plt
+import seaborn as sns
+
+from baskerville import dna_io
+import pygene
+import modisco
+
+import gc
+
+'''
+borzoi_tfmodisco_diff.py
+
+Run TF Modisco on the difference between Borzoi input saliency scores for
+multiple tissues.
+'''
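+
+# Differential scoring sketch (implemented in main below): with per-tissue
+# hypothetical score tensors stacked as scores[t] for t in 0..T-1, the main
+# tissue m is contrasted against the mean of the remaining tissues:
+#
+#   diff = scores[m] - scores[[t for t in range(T) if t != m]].mean(axis=0)
+#
+# so motifs whose saliency is specific to tissue m receive positive scores.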
+
+################################################################################
+# main
+################################################################################
+def main():
+    usage = 'usage: %prog [options]'
+    parser = OptionParser(usage)
+    parser.add_option(
+        '-c',
+        dest='center_bp',
+        default=None,
+        type='int',
+        help='Extract only center bp (required) [Default: %default]',
+    )
+    parser.add_option(
+        '-d',
+        dest='meme_db',
+        default='/home/drk/code/meme-5.4.1/motif_databases/CIS-BP_2.00/Homo_sapiens.meme',
+        help='MEME database [Default: %default]',
+    )
+    parser.add_option(
+        '-g',
+        dest='genes_gtf_file',
+        default='/home/drk/common/data/genomes/hg38/genes/gencode38/gencode38_basic_protein.gtf',
+        help='Gencode GTF [Default: %default]',
+    )
+    parser.add_option(
+        '--gc',
+        dest='gc_content',
+        default=0.41,
+        type='float',
+        help='Genome GC content [Default: %default]',
+    )
+    parser.add_option(
+        '--fwd',
+        dest='force_fwd',
+        default=0,
+        type='int',
+        help='Do not use rev-comp in modisco [Default: %default]',
+    )
+    parser.add_option(
+        '--modisco_window_size',
+        dest='modisco_window_size',
+        default=24,
+        type='int',
+        help='Modisco window size [Default: %default]',
+    )
+    parser.add_option(
+        '--modisco_flank',
+        dest='modisco_flank',
+        default=8,
+        type='int',
+        help='Modisco flanks to add [Default: %default]',
+    )
+    parser.add_option(
+        '--modisco_sliding_window_size',
+        dest='modisco_sliding_window_size',
+        default=18,
+        type='int',
+        help='Modisco sliding window size [Default: %default]',
+    )
+    parser.add_option(
+        '--modisco_sliding_window_flank',
+        dest='modisco_sliding_window_flank',
+        default=8,
+        type='int',
+        help='Modisco sliding window flanks [Default: %default]',
+    )
+    parser.add_option(
+        '--modisco_max_seqlets',
+        dest='modisco_max_seqlets',
+        default=20000,
+        type='int',
+        help='Maximum seqlets per metacluster [Default: %default]',
+    )
+    parser.add_option(
+        '-i',
+        dest='ic_t',
+        default=0.1,
+        type='float',
+        help='Information content threshold [Default: %default]',
+    )
+    parser.add_option(
+        '-n',
+        dest='norm_type',
+        default='max',
+        help='Score normalization: max or gaussian [Default: %default]',
+    )
+    parser.add_option(
+        '-o',
+        dest='out_dir',
+        default='tfm_out',
+        help='Output directory [Default: %default]',
+    )
+    parser.add_option(
+        '-r',
+        dest='region',
+        default=None,
+        help='Limit to specific gene region [Default: %default]',
+    )
+    parser.add_option(
+        '-t',
+        dest='targets_file',
+        default=None,
+        type='str',
+        help='File specifying target indexes and labels in table format',
+    )
+    parser.add_option(
+        '--kmer_len',
+        dest='kmer_len',
+        default=None,
+        type='int',
+        help='Modisco kmer length for seqlet embedding [Default: %default]',
+    )
+    parser.add_option(
+        '--num_gaps',
+        dest='num_gaps',
+        default=None,
+        type='int',
+        help='Modisco number of gaps for seqlet embedding [Default: %default]',
+    )
+    parser.add_option(
+        '--num_mismatches',
+        dest='num_mismatches',
+        default=None,
+        type='int',
+        help='Modisco number of mismatches for seqlet embedding [Default: %default]',
+    )
+    parser.add_option(
+        '--tissue_files',
+        dest='tissue_files',
+        default=None,
+        type='str',
+        help='Comma-separated list of files containing saliency scores (h5 format).',
+    )
+    parser.add_option(
+        '--tissue',
+        dest='tissue',
+        default=None,
+        type='str',
+        help='Main tissue name.',
+    )
+    parser.add_option(
+        '--main_tissue_ix',
+        dest='main_tissue_ix',
+        default=None,
+        type='int',
+        help='Main tissue index in list of tissue files.',
+    )
+    parser.add_option(
+        '--gene_file',
+        dest='gene_file',
+        default=None,
+        type='str',
+        help='Tab-delimited file of gene metadata.',
+    )
+
+    (options, args) = parser.parse_args()
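+
+    # Example invocation (file names are hypothetical):
+    #   python borzoi_tfmodisco_diff.py -c 8192 -o tfm_liver --tissue liver \
+    #       --tissue_files liver_scores.h5,muscle_scores.h5,brain_scores.h5 \
+    #       --gene_file gene_metadata.tsv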
+
+    if len(args) != 0:
+        parser.error('You must not supply any arguments.')
+
+    options.tissue_files = options.tissue_files.split(',')
+
+    # setup output dir
+    os.makedirs(options.out_dir, exist_ok=True)
+
+    # load gene dataframe and select tissue
+    gene_df = pd.read_csv(options.gene_file, sep='\t')
+    gene_df = gene_df.query("tissue == '" + str(options.tissue) + "'").copy().reset_index(drop=True)
+    gene_df = gene_df.drop(columns=['Unnamed: 0'])
+
+    print("len(gene_df) = " + str(len(gene_df)))
+
+    # get list of genes for tissue
+    tissue_genes = gene_df['gene_base'].values.tolist()
+
+    print("len(tissue_genes) = " + str(len(tissue_genes)))
+
+    # read nucleotide scores
+    t0 = time.time()
+    print('Reading scores...', flush=True, end='')
+
+    all_hyp_scores = []
+    seqs_1hot = None
+    seq_chrs = None
+    seq_genes = None
+    seq_starts = None
+    seq_ends = None
+
+    for scores_h5_file in options.tissue_files:
+        with h5py.File(scores_h5_file, 'r') as scores_h5:
+            print("Reading '" + scores_h5_file + "'")
+
+            seq_len = scores_h5['grads'].shape[1]
+            pos_start = seq_len//2 - options.center_bp//2
+            pos_end = pos_start + options.center_bp
+            hyp_scores = scores_h5['grads'][:, pos_start:pos_end]
+            seqs_1hot = scores_h5['seqs'][:, pos_start:pos_end]
+            seq_chrs = [chrm.decode('UTF-8') for chrm in scores_h5['chr']]
+            seq_genes = [gene.decode('UTF-8') for gene in scores_h5['gene']]
+            seq_starts = scores_h5['start'][:] + pos_start
+            seq_ends = scores_h5['end'][:] - (seq_len - pos_end)
+
+            # map base gene ids to row indexes
+            gene_dict = {gene.split(".")[0]: gene_i for gene_i, gene in enumerate(seq_genes)}
+
+            # get indexes of rows to keep
+            keep_index = []
+            for tissue_gene in tissue_genes:
+                keep_index.append(gene_dict[tissue_gene])
+
+            # filter/sub-select data
+            hyp_scores = hyp_scores[keep_index, ...]
+            seqs_1hot = seqs_1hot[keep_index, ...]
+            seq_chrs = [seq_chrs[k_ix] for k_ix in keep_index]
+            seq_genes = [seq_genes[k_ix] for k_ix in keep_index]
+            seq_starts = seq_starts[keep_index, ...]
+            seq_ends = seq_ends[keep_index, ...]
+
+            all_hyp_scores.append(hyp_scores[None, ...])
+
+        # collect garbage between large reads
+        gc.collect()
+
+    hyp_scores = np.concatenate(all_hyp_scores, axis=0)
+
+    # calculate differential scores
+    if options.main_tissue_ix is None:
+        # infer main tissue index from file names (assumes a match exists)
+        tissue_ix = -1
+        for h5_ix, scores_h5_file in enumerate(options.tissue_files):
+            if options.tissue.split("_")[0] in scores_h5_file:
+                tissue_ix = h5_ix
+                break
+    else:
+        tissue_ix = options.main_tissue_ix
+
+    print('tissue_ix = ' + str(tissue_ix))
+
+    score_2 = hyp_scores[tissue_ix, ...]
+    score_1 = np.mean(hyp_scores[np.arange(len(options.tissue_files)) != tissue_ix, ...], axis=0)
+
+    hyp_scores = score_2 - score_1
+
+    num_seqs, seq_len, _ = seqs_1hot.shape
+    print('DONE in %ds.' % (time.time()-t0))
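+
+    # E.g., with three tissue files and tissue_ix = 1, the boolean mask
+    # np.arange(3) != 1 -> [True, False, True] selects the two background
+    # tissues, whose mean is subtracted from the main tissue's scores.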
+
+    # average across targets
+    hyp_scores = hyp_scores.mean(axis=-1, dtype='float32')
+
+    # normalize scores by sequence
+    t0 = time.time()
+    print('Normalizing scores...', flush=True, end='')
+    if options.norm_type == 'max':
+        scores_max = hyp_scores.std(axis=-1).max(axis=-1)
+        # clip at a fixed 25th percentile floor
+        max_clip = np.percentile(scores_max, 25)
+        scores_max = np.clip(scores_max, max_clip, np.inf)
+        hyp_scores /= np.reshape(scores_max, (num_seqs, 1, 1))
+    elif options.norm_type == 'gaussian':
+        scores_std = hyp_scores.std(axis=-1)
+        scores_std_wide = gaussian_filter1d(scores_std, sigma=1280, truncate=2)
+        wide_clip = np.percentile(scores_std_wide, 25)
+        scores_std_wide = np.clip(scores_std_wide, wide_clip, np.inf)
+        hyp_scores /= np.expand_dims(scores_std_wide, axis=-1)
+    else:
+        print('Unrecognized normalization %s' % options.norm_type)
+    print('DONE in %ds.' % (time.time()-t0))
+
+    ################################################
+    # region filter
+
+    if options.region is not None:
+        hyp_scores, seqs_1hot = filter_region(hyp_scores, seqs_1hot,
+            seq_genes, seq_starts, options.genes_gtf_file, options.region)
+
+    # save to visualize individual examples
+    with h5py.File('%s/scores.h5' % options.out_dir, 'w') as scores_h5:
+        scores_h5.create_dataset('scores', data=hyp_scores, compression='gzip')
+        scores_h5.create_dataset('seqs', data=seqs_1hot, compression='gzip')
+
+    ################################################
+    # tfmodisco
+
+    if isinstance(hyp_scores, list):
+        num_seqs = len(seqs_1hot)
+        contrib_scores = [np.multiply(hyp_scores[si], seqs_1hot[si]) for si in range(num_seqs)]
+    else:
+        num_seqs = seqs_1hot.shape[0]
+        contrib_scores = np.multiply(hyp_scores, seqs_1hot)
+
+    # make seqlets to patterns factory
+    if options.kmer_len is not None and options.num_gaps is not None and options.num_mismatches is not None:
+        tfm_seqlets = modisco.tfmodisco_workflow.seqlets_to_patterns.TfModiscoSeqletsToPatternsFactory(
+            trim_to_window_size=options.modisco_window_size,
+            initial_flank_to_add=options.modisco_flank,
+            kmer_len=options.kmer_len, num_gaps=options.num_gaps, num_mismatches=options.num_mismatches,
+            final_min_cluster_size=20)
+    else:
+        tfm_seqlets = modisco.tfmodisco_workflow.seqlets_to_patterns.TfModiscoSeqletsToPatternsFactory(
+            trim_to_window_size=options.modisco_window_size,
+            initial_flank_to_add=options.modisco_flank,
+            final_min_cluster_size=20)
+
+    # make modisco workflow
+    tfm_workflow = modisco.tfmodisco_workflow.workflow.TfModiscoWorkflow(
+        sliding_window_size=options.modisco_sliding_window_size,
+        flank_size=options.modisco_sliding_window_flank,
+        max_seqlets_per_metacluster=options.modisco_max_seqlets,
+        seqlets_to_patterns_factory=tfm_seqlets)
+
+    # run modisco workflow
+    task_label = 'out0'
+    tfm_results = tfm_workflow(
+        task_names=[task_label],
+        contrib_scores={task_label: contrib_scores},
+        hypothetical_contribs={task_label: hyp_scores},
+        revcomp=False if options.force_fwd == 1 else True,
+        one_hot=seqs_1hot)
+
+    # save results
+    tfm_h5_file = '%s/tfm.h5' % options.out_dir
+    with h5py.File(tfm_h5_file, 'w') as tfm_h5:
+        tfm_results.save_hdf5(tfm_h5)
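+
+    # IC clipping below (ic_clip): with a uniform background, a strongly
+    # conserved column such as [0.97, 0.01, 0.01, 0.01] carries ~1.8 bits,
+    # while a uniform column [0.25]*4 carries ~0 bits, so the default
+    # threshold of 0.1 trims near-uniform flanks from each pattern.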
+
+    ################################################
+    # extract motif PWMs
+
+    at_pct = (1-options.gc_content)/2
+    gc_pct = options.gc_content/2
+    background = np.array([at_pct, gc_pct, gc_pct, at_pct])
+
+    tfm_pwms = {}
+
+    with h5py.File(tfm_h5_file, 'r') as tfm_h5:
+        metacluster_names = [mcr.decode("utf-8") for mcr in list(tfm_h5["metaclustering_results"]["all_metacluster_names"][:])]
+        for metacluster_name in metacluster_names:
+            metacluster_grp = tfm_h5["metacluster_idx_to_submetacluster_results"][metacluster_name]
+            all_patterns = metacluster_grp["seqlets_to_patterns_result"]["patterns"]["all_pattern_names"][:]
+            all_pattern_names = [x.decode("utf-8") for x in list(all_patterns)]
+            for pattern_name in all_pattern_names:
+                pattern_id = (metacluster_name + '_' + pattern_name)
+                pattern = metacluster_grp["seqlets_to_patterns_result"]["patterns"][pattern_name]
+                fwd = np.array(pattern["sequence"]["fwd"])
+                clip_pwm = ic_clip(fwd, options.ic_t, background)
+                if clip_pwm is None:
+                    print('pattern_id: %s skipped; no position passes the IC threshold.' % pattern_id)
+                else:
+                    tfm_pwms[pattern_id] = clip_pwm
+                    print('pattern_id: %s converted to PWM.' % pattern_id)
+
+    ################################################
+    # tomtom
+
+    # initialize MEME
+    modisco_meme_file = options.out_dir + '/modisco_' + options.out_dir.replace("/", "_") + '.meme'
+    modisco_meme_open = open(modisco_meme_file, 'w')
+
+    # header
+    modisco_meme_open.write('MEME version 4\n\n')
+    modisco_meme_open.write('ALPHABET= ACGT\n\n')
+    modisco_meme_open.write('strands: + -\n\n')
+    modisco_meme_open.write('Background letter frequencies\n')
+    modisco_meme_open.write('A %f C %f G %f T %f\n\n' % tuple(background))
+
+    # PWMs
+    for key in tfm_pwms.keys():
+        modisco_meme_open.write('MOTIF ' + key + '\n')
+        modisco_meme_open.write('letter-probability matrix: alength= 4 w= ' + str(tfm_pwms[key].shape[0]) + '\n')
+        np.savetxt(modisco_meme_open, tfm_pwms[key])
+        modisco_meme_open.write('\n')
+
+    modisco_meme_open.close()
+
+    # run tomtom
+    tomtom_cmd = '/home/drk/bin/tomtom -dist pearson -thresh 0.1 -oc %s %s %s' % \
+        (options.out_dir, modisco_meme_file, options.meme_db)
+    subprocess.call(tomtom_cmd, shell=True)
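+
+# Tomtom compares the de novo PWMs against the reference motif database using
+# Pearson correlation with a 0.1 significance threshold; with -oc it writes
+# its report (tomtom.tsv/tomtom.html in recent MEME suite versions) into the
+# output directory alongside the .meme file.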
+
+
+def filter_region(scores, seqs_1hot, seq_genes, seq_starts, genes_gtf_file, region, min_size=64, ss_window=192, utr_window=192):
+    """Filter scores and sequences for a specific gene region."""
+    num_seqs, seq_len, _ = seqs_1hot.shape
+
+    # parse GTF
+    genes_gtf = pygene.GTF(genes_gtf_file)
+
+    # collect regions
+    scores_region = []
+    seqs_1hot_region = []
+
+    # for each gene sequence
+    print('Extracting region %s...' % region)
+    for gi in tqdm(range(num_seqs)):
+        gene_id = seq_genes[gi]
+        gene = genes_gtf.genes[gene_id]
+
+        # collect regions
+        region_starts = []
+        region_ends = []
+        for _, tx in gene.transcripts.items():
+            tx.define_utrs()
+            if region in ['3utr', 'utr3']:
+                for utr in tx.utrs3:
+                    region_starts.append(utr.start)
+                    region_ends.append(utr.end)
+
+            elif region.find('ss') != -1:
+                # choose the exon side flanking the requested splice site
+                if region in ['ss5', '5ss']:
+                    if tx.strand == '+':
+                        exon_side = 'end'
+                    else:
+                        exon_side = 'start'
+                else:
+                    if tx.strand == '+':
+                        exon_side = 'start'
+                    else:
+                        exon_side = 'end'
+
+                if exon_side == 'start':
+                    for ei in range(1, len(tx.exons)):
+                        ss_start = tx.exons[ei].start - ss_window//2
+                        ss_end = ss_start + ss_window
+                        region_starts.append(ss_start)
+                        region_ends.append(ss_end)
+                else:
+                    for ei in range(len(tx.exons)-1):
+                        ss_start = tx.exons[ei].end - ss_window//2
+                        ss_end = ss_start + ss_window
+                        region_starts.append(ss_start)
+                        region_ends.append(ss_end)
+            else:
+                print('Unrecognized region %s' % region, file=sys.stderr)
+
+        num_regions = len(region_starts)
+        if num_regions > 0:
+            # merge overlapping regions
+            region_ranges = pyranges.PyRanges(chromosomes=[tx.chrom]*num_regions,
+                starts=region_starts, ends=region_ends)
+            region_ranges = region_ranges.merge()
+
+            # for each region
+            for _, rr in region_ranges.df.iterrows():
+                # collect scores
+                scores_start = max(0, rr.Start - seq_starts[gi])
+                scores_end = min(seq_len, rr.End - seq_starts[gi])
+
+                skip_region = False
+
+                # check splice site length match
+                if region.find('ss') != -1:
+                    if scores_end - scores_start != ss_window:
+                        skip_region = True
+                else:
+                    # sample variable length window around the score peak
+                    if scores_end - scores_start < utr_window:
+                        skip_region = True
+                    else:
+                        scores_std = scores[gi, scores_start:scores_end].std(axis=-1)
+                        scores_len = len(scores_std)
+                        scores_peak = np.argmax(scores_std)
+                        scores_peak = max(utr_window//2, scores_peak)
+                        scores_peak = min(scores_len - utr_window//2, scores_peak)
+                        scores_start += scores_peak - utr_window//2
+                        scores_end = scores_start + utr_window
+
+                if not skip_region:
+                    scores_region_ri = scores[gi, scores_start:scores_end]
+                    seqs_1hot_ri = seqs_1hot[gi, scores_start:scores_end]
+                    if gene.strand == '-':
+                        scores_region_ri = dna_io.hot1_rc(scores_region_ri)
+                        seqs_1hot_ri = dna_io.hot1_rc(seqs_1hot_ri)
+                    scores_region.append(scores_region_ri)
+                    seqs_1hot_region.append(seqs_1hot_ri)
+
+    scores_region = np.array(scores_region)
+    seqs_1hot_region = np.array(seqs_1hot_region)
+
+    return scores_region, seqs_1hot_region
+
+
+def ic_clip(pwm, threshold, background=[0.25]*4):
+    """Clip PWM sides with an information content threshold."""
+    pc = 0.001
+    background = np.asarray(background)  # allow the list default
+
+    # per-position information content relative to background
+    ic = (np.log((pwm+pc)/(1+4*pc)) / np.log(2))*pwm
+    ic -= (np.log(background)*background/np.log(2))[None, :]
+    ic_total = np.sum(ic, axis=1)[:, None]
+
+    # no position passes threshold
+    if not np.any(ic_total.flatten() > threshold):
+        return None
+    else:
+        left = np.where(ic_total > threshold)[0][0]
+        right = np.where(ic_total > threshold)[0][-1]
+        return pwm[left:(right+1)]
+
+################################################################################
+# __main__
+################################################################################
+if __name__ == '__main__':
+    main()