genome_mince.py
if __name__ == '__main__' and __package__ is None:
    # allow running this script directly from inside the package directory
    import sys
    from os import path
    sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
from random import randint, choice, sample
from preprocessing.process_inputs import seq2kmers, ALPHABET
from preprocessing.genome_db import GenomeDB
from collections import defaultdict
from datasketch import MinHash
import numpy as np
import json
import logging
from tqdm import tqdm
import os.path
import argparse

# nucleotide alphabet in both cases, used for the non-alphabet cutoff checks
ALPHABET_all = ALPHABET.lower() + ALPHABET.upper()
def kmer_profile(seq, k=7):
    """counts the occurrences of each k-mer in the sequence"""
    profile = defaultdict(int)
    for i in range(len(seq) - k + 1):
        kmer = seq[i:i+k]
        profile[kmer] += 1
    return profile


def kmer_dist(kmers1, kmers2):
    """euclidean distance between two k-mer profiles"""
    sum_ = 0
    for kmer in set(kmers1).union(set(kmers2)):
        sum_ += (kmers1[kmer] - kmers2[kmer])**2
    return np.sqrt(sum_)


def kmer_dist_np(kmers1, kmers2):
    """euclidean distance between two k-mer profiles (numpy variant)"""
    kmers = list(set(kmers1).union(set(kmers2)))
    return np.linalg.norm(np.array([kmers1[kmer] for kmer in kmers])
                          - np.array([kmers2[kmer] for kmer in kmers]))
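
# Illustrative sketch (not part of the pipeline): both distance functions
# agree on toy profiles; the sequences below are made up.
# >>> p1 = kmer_profile('ACGTAC', k=3)   # {'ACG': 1, 'CGT': 1, 'GTA': 1, 'TAC': 1}
# >>> p2 = kmer_profile('ACGACG', k=3)   # {'ACG': 2, 'CGA': 1, 'GAC': 1}
# >>> bool(np.isclose(kmer_dist(p1, p2), kmer_dist_np(p1, p2)))
# True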
def minhash(seq, k, s):
    """computes a MinHash profile with s permutations over the k-mers of seq"""
    m = MinHash(num_perm=s)
    for word in seq2kmers(seq, k=k, stride=1, pad=False):
        m.update(word.encode())
    return tuple(m.digest())


def minhash_exists(mh, minhashes):
    """similarity function: True iff the exact MinHash profile was seen before"""
    return mh in minhashes
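
# Illustrative sketch: MinHash is deterministic for identical input, so two
# calls on the same sequence yield the same digest tuple and exact duplicates
# get filtered out; the sequence and parameter values here are hypothetical.
# >>> minhash('ACGTACGTACGT', k=6, s=6) == minhash('ACGTACGTACGT', k=6, s=6)
# True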
def get_fragment(seq, fragment_size=500, k=3):
    """returns a randomly chosen fragment of the sequence with the given size

    :param SeqRecord seq: SeqRecord from which to extract the fragment
    :param fragment_size: size of the fragment in k-mers
    :param k: number of nucleotides grouped into one k-mer
    :returns: fragment of the given size, or None if fragment_size is greater
              than the sequence size
    """
    if (len(seq) // k < fragment_size):
        return None
    start = randint(0, len(seq) - (fragment_size * k))
    end = start + fragment_size * k
    return seq[start:end]
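
# Illustrative sketch (assumes Biopython SeqRecords, which the GenomeDB
# lookups in this module also rely on):
# >>> from Bio.Seq import Seq
# >>> from Bio.SeqRecord import SeqRecord
# >>> rec = SeqRecord(Seq('ACGT' * 500), id='toy')
# >>> frag = get_fragment(rec, fragment_size=100, k=3)
# >>> len(frag)   # fragment_size * k nucleotides
# 300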
def pick_fragment(taxids, profiles, genome_db, profile_fun, similarity_fun,
                  max_its=100, nonalph_cutoff=None, **get_kwargs):
    """returns one randomly chosen fragment from the SeqRecords of the given taxids

    :param taxids: ids from which to choose, must have entries in genome_db
    :param profiles: collection of objects to which the result of profile_fun
                     is compared
    :param genome_db: database allowing access to the SeqRecord given a taxid
    :param profile_fun: function which creates a profile of a given sequence
    :param similarity_fun: function which compares the profile of the fragment
                           candidate to all other profiles; if it returns True,
                           the fragment will *not* be picked
    :param max_its: maximum number of attempts until giving up, can be None
    :param nonalph_cutoff: maximum allowed fraction of non-alphabet letters in
                           the fragment, can be None
    :param get_kwargs: kwargs to pass to get_fragment
    :returns: tuple of fragment, its profile and the picked taxid; or None if
              no fragment could be picked
    """
    its = 0
    fragment = None
    profile = None
    while fragment is None:
        its += 1
        if (max_its is not None and its > max_its):
            return None
        taxid = choice(taxids)
        record = genome_db[taxid]
        fragment = get_fragment(record, **get_kwargs)
        if fragment is None:
            continue
        if (nonalph_cutoff is not None):
            nonalph_perc = (
                len([letter for letter in fragment
                     if letter.upper() not in ALPHABET])
                / len(fragment))
            if (nonalph_perc > nonalph_cutoff):
                fragment = None
                profile = None
                continue
        profile = profile_fun(str(fragment.seq))
        if (similarity_fun(profile, profiles)):
            fragment = None
            profile = None
            continue
    return fragment, profile, taxid
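
# Illustrative sketch: picking a single unique fragment; `db` and the taxids
# are hypothetical placeholders for a loaded GenomeDB and its entries.
# >>> seen_profiles = []
# >>> picked = pick_fragment([562, 1280], seen_profiles, db,
# ...                        lambda s: minhash(s, 6, 6), minhash_exists,
# ...                        max_its=100, fragment_size=500, k=3)
# >>> if picked is not None:
# ...     fragment, profile, taxid = picked
# ...     seen_profiles.append(profile)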
def pick_fragment_nocomp(taxids, genome_db, max_its=100, nonalph_cutoff=None,
                         **get_kwargs):
    """like pick_fragment, but without profile comparison: returns a tuple of
    fragment and picked taxid, or None if no fragment could be picked"""
    its = 0
    fragment = None
    while fragment is None:
        its += 1
        if (max_its is not None and its > max_its):
            return None
        taxid = choice(taxids)
        record = genome_db[taxid]
        fragment = get_fragment(record, **get_kwargs)
        if fragment is None:
            continue
        if (nonalph_cutoff is not None):
            nonalph = 0
            for letter in fragment:
                if (letter not in ALPHABET_all):
                    nonalph += 1
            nonalph_perc = nonalph / len(fragment)
            if (nonalph_perc > nonalph_cutoff):
                fragment = None
                continue
    return fragment, taxid
def get_sk_fragments(nr_fragments, orders_dict, genome_db,
                     profile_fun, similarity_fun,
                     max_its=None, order_max_its=None,
                     nonalph_cutoff=None, thread_nr=None,
                     **get_fragment_kwargs):
    """obtains the specified number of fragments based on the given order taxids

    :param nr_fragments: number of fragments to obtain
    :param orders_dict: dictionary mapping order taxids to species taxids
    :param genome_db: GenomeDB containing genomes of the species taxids
    :param profile_fun: function which creates a profile of a given sequence
    :param similarity_fun: function which compares the profile of the fragment
                           candidate to all other profiles; if it returns True,
                           the fragment will *not* be picked
    :param max_its: maximum number of iterations until giving up
    :param order_max_its: maximum number of iterations per order
    :param nonalph_cutoff: maximum allowed fraction of non-alphabet letters
    :param thread_nr: position of the progress bar when run in parallel
    :param get_fragment_kwargs: kwargs to pass to get_fragment
    :returns: fragments, their profiles, and the picked species taxids
    """
    def finished():
        if (max_its is not None and iterations >= max_its):
            logging.warning("maximum iterations reached!")
            return True
        return (len(fragments) >= nr_fragments)
    profiles = []
    fragments = []
    speciess = []
    iterations = 0
    pbar = tqdm(total=nr_fragments, position=thread_nr,
                desc=genome_db.name)
    while not finished():
        for order in sample(list(orders_dict), len(orders_dict)):
            if (finished()):
                break
            iterations += 1
            avail_taxids = [taxid for taxid in orders_dict[order]
                            if taxid in genome_db]
            if (len(avail_taxids) == 0):
                continue
            picked = pick_fragment(avail_taxids, profiles, genome_db,
                                   profile_fun, similarity_fun,
                                   order_max_its, nonalph_cutoff,
                                   **get_fragment_kwargs)
            if (picked is None):
                continue
            fragment, profile, species = picked
            fragments.append(fragment)
            profiles.append(profile)
            speciess.append(species)
            pbar.update()
    pbar.close()
    print(f'{iterations} iterations on superkingdom level')
    return fragments, profiles, speciess
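
# Illustrative sketch: drawing 1000 fragments with MinHash-based duplicate
# filtering; `db` and `orders` are hypothetical placeholders.
# >>> fragments, profiles, species = get_sk_fragments(
# ...     1000, orders, db, lambda s: minhash(s, 6, 6), minhash_exists,
# ...     max_its=10_000_000, order_max_its=100, fragment_size=500, k=3)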
def get_sk_fragments_nocomp(nr_fragments, orders_dict, genome_db,
                            order_max_its=None, nonalph_cutoff=None,
                            **get_fragment_kwargs):
    """like get_sk_fragments, but without profile comparison: returns only the
    fragments and the picked species taxids"""
    def finished():
        return (len(fragments) >= nr_fragments)
    fragments = []
    speciess = []
    iterations = 0
    logging.info('pre-caching taxids with available genomes for each order...')
    avail_taxids = {}
    for order in tqdm(orders_dict):
        taxids = [taxid for taxid in orders_dict[order]
                  if taxid in genome_db]
        if len(taxids) > 0:
            avail_taxids[order] = taxids
    orders_keys = list(avail_taxids)
    logging.info('done.')
    pbar = tqdm(total=nr_fragments,
                desc=genome_db.name)
    while not finished():
        for order in sample(orders_keys, len(orders_keys)):
            if (finished()):
                break
            iterations += 1
            taxids = avail_taxids[order]
            picked = pick_fragment_nocomp(taxids, genome_db,
                                          order_max_its, nonalph_cutoff,
                                          **get_fragment_kwargs)
            if (picked is None):
                continue
            fragment, species = picked
            fragments.append(fragment)
            speciess.append(species)
            pbar.update()
    pbar.close()
    print(f'{iterations} iterations on superkingdom level')
    return fragments, speciess
def load_genomes(genome_dir, sk, thr=16e9):
    """loads the GenomeDB for the given superkingdom from genome_dir

    :param genome_dir: base directory containing the per-superkingdom genomes
    :param sk: superkingdom name (Archaea, Bacteria, Eukaryota or Viruses)
    :param thr: size threshold passed through to GenomeDB
    """
    def from_fastadir(sk_dir):
        fastas = [os.path.join(sk_dir, _.strip())
                  for _ in open(os.path.join(sk_dir, 'files.txt')).readlines()]
        return GenomeDB(fastas, os.path.join(sk_dir, 'mapping.tsv'), name=sk,
                        size_thr=thr)
    def from_fasta(fasta_file, mapping):
        return GenomeDB(fasta_file, mapping, name=sk, size_thr=thr)
    logging.info(f'loading {sk} genomes')
    if (sk == 'Archaea'):
        return from_fastadir(
            os.path.join(genome_dir, 'Archaea'))
    elif (sk == 'Bacteria'):
        return from_fasta(
            os.path.join(genome_dir, 'Bacteria/full_genome_bacteria.fna'),
            os.path.join(genome_dir, 'mapping_Bacteria.tsv'))
    elif (sk == 'Eukaryota'):
        return from_fastadir(
            os.path.join(genome_dir, 'Eukaryota'))
    elif (sk == 'Viruses'):
        return from_fasta(
            os.path.join(genome_dir, 'Viruses/all_viruses_db.fa'),
            os.path.join(genome_dir, 'mapping_Viruses.tsv'))
    else:
        raise ValueError(f'genomes of {sk} not available')
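
# Expected layout under --genome_dir (as implied by load_genomes above):
#   Archaea/files.txt, Archaea/mapping.tsv (plus the listed FASTA files)
#   Bacteria/full_genome_bacteria.fna, mapping_Bacteria.tsv
#   Eukaryota/files.txt, Eukaryota/mapping.tsv (plus the listed FASTA files)
#   Viruses/all_viruses_db.fa, mapping_Viruses.tsv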
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('sk')
    parser.add_argument('nr_fragments', type=int)
    parser.add_argument('--outdir', '-o', default='.')
    parser.add_argument('--thr', '-t', type=float, default=16e9)
    parser.add_argument('--nonalph_cutoff', type=float, default=None)
    parser.add_argument('--no_comp', action='store_true')
    parser.add_argument('--genome_dir', default='genomes/')
    parser.add_argument('--sk_order_dict', default='sk_order_dict.json')
    logging.getLogger().setLevel(logging.INFO)
    args = parser.parse_args()
    nr_seqs = args.nr_fragments
    sk = args.sk
    sk_order_dict = json.load(
        open(args.sk_order_dict))

    def minhash_defined(seq):
        return minhash(seq, 6, 6)

    def fake_profile_fun(seq):
        # debugging stub: returns a running counter instead of a real profile
        global counter
        counter += 1
        return counter

    def sk_fragments_plus_stats(sk, outdir='.', thread_nr=None):
        genome_db = load_genomes(
            args.genome_dir, sk, thr=args.thr)
        if (args.no_comp):
            global counter
            counter = 0
            # the no-comparison variant does not produce profiles
            fragments, speciess = get_sk_fragments_nocomp(
                nr_fragments=nr_seqs, orders_dict=sk_order_dict[sk],
                genome_db=genome_db, nonalph_cutoff=args.nonalph_cutoff)
            profiles = []
        else:
            fragments, profiles, speciess = get_sk_fragments(
                nr_seqs, sk_order_dict[sk], genome_db, minhash_defined,
                minhash_exists, 10_000_000, None, args.nonalph_cutoff,
                thread_nr)
        print(
            f'{len(fragments)} fragments generated, alongside {len(profiles)} '
            'unique profiles')
        json.dump([str(seq.seq) for seq in fragments],
                  open(os.path.join(outdir, f'{sk}_fragments.json'), 'w'))
        with open(os.path.join(outdir, f'{sk}_species_picked.txt'), 'w') as f:
            f.writelines(str(sp) + '\n' for sp in speciess)

    logging.info(f'generating fragments for {sk}')
    sk_fragments_plus_stats(sk, args.outdir)
    logging.info(f'[{sk}] done')
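
# Example invocation (illustrative; paths and counts depend on your setup):
#   python genome_mince.py Bacteria 100000 -o fragments/ \
#       --genome_dir genomes/ --sk_order_dict sk_order_dict.json --no_comp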