diff --git a/src/sourmash/index.py b/src/sourmash/index.py index 6574d998e7..07e9c21b6a 100644 --- a/src/sourmash/index.py +++ b/src/sourmash/index.py @@ -126,6 +126,7 @@ def select(self, ksize=None, moltype=None): "" class LinearIndex(Index): + "An Index for a collection of signatures. Can load from a .sig file." def __init__(self, _signatures=None, filename=None): self._signatures = [] if _signatures: @@ -155,11 +156,97 @@ def load(cls, location): return lidx def select(self, ksize=None, moltype=None): - def select_sigs(siglist, ksize, moltype): - for ss in siglist: - if (ksize is None or ss.minhash.ksize == ksize) and \ - (moltype is None or ss.minhash.moltype == moltype): - yield ss + def select_sigs(ss, ksize=ksize, moltype=moltype): + if (ksize is None or ss.minhash.ksize == ksize) and \ + (moltype is None or ss.minhash.moltype == moltype): + return True + + return self.filter(select_sigs) + + def filter(self, filter_fn): + siglist = [] + for ss in self._signatures: + if filter_fn(ss): + siglist.append(ss) - siglist=select_sigs(self._signatures, ksize, moltype) return LinearIndex(siglist, self.filename) + + +class MultiIndex(Index): + """An Index class that wraps other Index classes. + + The MultiIndex constructor takes two arguments: a list of Index + objects, and a matching list of sources (filenames, etc.) If the + source is not None, then it will be used to override the 'filename' + in the triple that is returned by search and gather. + + One specific use for this is when loading signatures from a directory; + MultiIndex will properly record which files provided which signatures. + """ + def __init__(self, index_list, source_list): + self.index_list = list(index_list) + self.source_list = list(source_list) + assert len(index_list) == len(source_list) + + def signatures(self): + for idx in self.index_list: + for ss in idx.signatures(): + yield ss + + def __len__(self): + return sum([ len(idx) for idx in self.index_list ]) + + def insert(self, *args): + raise NotImplementedError + + @classmethod + def load(self, *args): + raise NotImplementedError + + def save(self, *args): + raise NotImplementedError + + def select(self, ksize=None, moltype=None): + new_idx_list = [] + new_src_list = [] + for idx, src in zip(self.index_list, self.source_list): + idx = idx.select(ksize=ksize, moltype=moltype) + new_idx_list.append(idx) + new_src_list.append(src) + + return MultiIndex(new_idx_list, new_src_list) + + def filter(self, filter_fn): + new_idx_list = [] + new_src_list = [] + for idx, src in zip(self.index_list, self.source_list): + idx = idx.filter(filter_fn) + new_idx_list.append(idx) + new_src_list.append(src) + + return MultiIndex(new_idx_list, new_src_list) + + def search(self, query, *args, **kwargs): + # do the actual search: + matches = [] + for idx, src in zip(self.index_list, self.source_list): + for (score, ss, filename) in idx.search(query, *args, **kwargs): + best_src = src or filename # override if src provided + matches.append((score, ss, best_src)) + + # sort! + matches.sort(key=lambda x: -x[0]) + return matches + + def gather(self, query, *args, **kwargs): + "Return the match with the best Jaccard containment in the Index." + # actually do search! + results = [] + for idx, src in zip(self.index_list, self.source_list): + for (score, ss, filename) in idx.gather(query, *args, **kwargs): + best_src = src or filename # override if src provided + results.append((score, ss, best_src)) + + results.sort(reverse=True, key=lambda x: (x[0], x[1].md5sum())) + + return results diff --git a/src/sourmash/sourmash_args.py b/src/sourmash/sourmash_args.py index 7fe25cb10b..5a3358783c 100644 --- a/src/sourmash/sourmash_args.py +++ b/src/sourmash/sourmash_args.py @@ -16,7 +16,7 @@ from . import signature from .logging import notify, error -from .index import LinearIndex +from .index import LinearIndex, MultiIndex from . import signature as sig from .sbt import SBT from .sbtmh import SigLeaf @@ -181,16 +181,7 @@ def traverse_find_sigs(filenames, yield_all_files=False): yield fullname -def filter_compatible_signatures(query, siglist, force=False): - for ss in siglist: - if check_signatures_are_compatible(query, ss): - yield ss - else: - if not force: - raise ValueError("incompatible signature") - - -def check_signatures_are_compatible(query, subject): +def _check_signatures_are_compatible(query, subject): # is one scaled, and the other not? cannot do search if query.minhash.scaled and not subject.minhash.scaled or \ not query.minhash.scaled and subject.minhash.scaled: @@ -275,20 +266,7 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None) sys.exit(-1) # are we collecting signatures from a directory/path? - # NOTE: error messages about loading will now be attributed to - # directory, not individual file. - if os.path.isdir(filename): - assert dbtype == DatabaseType.SIGLIST - - siglist = _select_sigs(db, moltype=query_moltype, ksize=query_ksize) - siglist = filter_compatible_signatures(query, siglist, True) - linear = LinearIndex(siglist, filename=filename) - databases.append(linear) - - n_signatures += len(linear) - - # SBT - elif dbtype == DatabaseType.SBT: + if dbtype == DatabaseType.SBT: if not check_tree_is_compatible(filename, db, query, is_similarity_query): sys.exit(-1) @@ -301,7 +279,6 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None) elif dbtype == DatabaseType.LCA: if not check_lca_db_is_compatible(filename, db, query): sys.exit(-1) - query_scaled = query.minhash.scaled notify('loaded LCA {}', filename, end='\r') n_databases += 1 @@ -310,26 +287,19 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None) # signature file elif dbtype == DatabaseType.SIGLIST: - siglist = _select_sigs(db, moltype=query_moltype, ksize=query_ksize) - try: - # CTB: it's not clear to me that filter_compatible_signatures - # should fail here, on incompatible signatures; but that's - # what we have it doing currently. Revisit. - siglist = filter_compatible_signatures(query, siglist, False) - siglist = list(siglist) - except ValueError: - siglist = [] - - if not siglist: - notify("no compatible signatures found in '{}'", filename) + db = db.select(moltype=query_moltype, ksize=query_ksize) + siglist = db.signatures() + filter_fn = lambda s: _check_signatures_are_compatible(query, s) + db = db.filter(filter_fn) + + if not db: + notify(f"no compatible signatures found in '{filename}'") sys.exit(-1) - linear = LinearIndex(siglist, filename=filename) - databases.append(linear) + databases.append(db) - notify('loaded {} signatures from {}', len(linear), - filename, end='\r') - n_signatures += len(linear) + notify(f'loaded {len(db)} signatures from {filename}', end='\r') + n_signatures += len(db) # unknown!? else: @@ -374,56 +344,58 @@ def _load_database(filename, traverse_yield_all, *, cache_size=None): # special case stdin if not loaded and filename == '-': - db = signature.load_signatures(sys.stdin, do_raise=True) - db = list(db) - loaded = True + db = LinearIndex.load(sys.stdin) dbtype = DatabaseType.SIGLIST + loaded = True - # load signatures from directory + # load signatures from directory, using MultiIndex to preserve source. if not loaded and os.path.isdir(filename): - all_sigs = [] + index_list = [] + source_list = [] for thisfile in traverse_find_sigs([filename], traverse_yield_all): try: - with open(thisfile, 'rt') as fp: - x = signature.load_signatures(fp, do_raise=True) - siglist = list(x) - all_sigs.extend(siglist) + idx = LinearIndex.load(thisfile) + index_list.append(idx) + source_list.append(thisfile) except (IOError, sourmash.exceptions.SourmashError): if traverse_yield_all: continue else: raise - loaded=True - db = all_sigs - dbtype = DatabaseType.SIGLIST - - # load signatures from single file - try: - # CTB: could make this a generator, with some trickery; but for - # now, just force into list. - with open(filename, 'rt') as fp: - db = signature.load_signatures(fp, do_raise=True) - db = list(db) + if index_list: + loaded=True + db = MultiIndex(index_list, source_list) + dbtype = DatabaseType.SIGLIST - loaded = True - dbtype = DatabaseType.SIGLIST - except Exception as exc: - pass + # load signatures from single signature file + if not loaded: + try: + with open(filename, 'rt') as fp: + db = LinearIndex.load(filename) + dbtype = DatabaseType.SIGLIST + loaded = True + except Exception as exc: + pass # try load signatures from single file (list of signature paths) + # use MultiIndex to preserve source filenames. if not loaded: try: - db = [] - with open(filename, 'rt') as fp: - for line in fp: - line = line.strip() - if line: - sigs = load_file_as_signatures(line) - db += list(sigs) + idx_list = [] + src_list = [] - loaded = True + file_list = load_file_list_of_signatures(filename) + for fname in file_list: + idx = load_file_as_index(fname) + src = fname + + idx_list.append(idx) + src_list.append(src) + + db = MultiIndex(idx_list, src_list) dbtype = DatabaseType.SIGLIST + loaded = True except Exception as exc: pass @@ -461,19 +433,11 @@ def _load_database(filename, traverse_yield_all, *, cache_size=None): raise OSError("Error while reading signatures from '{}' - got sequences instead! Is this a FASTA/FASTQ file?".format(filename)) if not loaded: - raise OSError("Error while reading signatures from '{}'.".format(filename)) + raise OSError(f"Error while reading signatures from '{filename}'.") return db, dbtype -# note: dup from index.py internal function. -def _select_sigs(siglist, ksize, moltype): - for ss in siglist: - if (ksize is None or ss.minhash.ksize == ksize) and \ - (moltype is None or ss.minhash.moltype == moltype): - yield ss - - def load_file_as_index(filename, yield_all_files=False): """Load 'filename' as a database; generic database loader. @@ -488,14 +452,7 @@ def load_file_as_index(filename, yield_all_files=False): attempt to load all files. """ db, dbtype = _load_database(filename, yield_all_files) - if dbtype in (DatabaseType.LCA, DatabaseType.SBT): - return db # already an index! - elif dbtype == DatabaseType.SIGLIST: - # turn siglist into a LinearIndex - idx = LinearIndex(db, filename) - return idx - else: - assert 0 # unknown enum!? + return db def load_file_as_signatures(filename, select_moltype=None, ksize=None, @@ -519,21 +476,15 @@ def load_file_as_signatures(filename, select_moltype=None, ksize=None, progress.notify(filename) db, dbtype = _load_database(filename, yield_all_files) - - loader = None - if dbtype in (DatabaseType.LCA, DatabaseType.SBT): - db = db.select(moltype=select_moltype, ksize=ksize) - loader = db.signatures() - elif dbtype == DatabaseType.SIGLIST: - loader = _select_sigs(db, moltype=select_moltype, ksize=ksize) - else: - assert 0 # unknown enum!? + db = db.select(moltype=select_moltype, ksize=ksize) + loader = db.signatures() if progress: return progress.start_file(filename, loader) else: return loader + def load_file_list_of_signatures(filename): "Load a list-of-files text file." try: diff --git a/tests/test_index.py b/tests/test_index.py index 1313162a2e..1b9ae93402 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -4,7 +4,7 @@ import sourmash from sourmash import load_one_signature, SourmashSignature -from sourmash.index import LinearIndex +from sourmash.index import LinearIndex, MultiIndex from sourmash.sbt import SBT, GraphFactory, Leaf import sourmash_tst_utils as utils @@ -393,3 +393,112 @@ def test_index_same_md5sum_zipstorage(c): # should have 3 files, 1 internal and two sigs. We check for 4 because the # directory also shows in namelist() assert len([f for f in zout.namelist() if f.startswith(".sbt.zzz/")]) == 4 + + +def test_multi_index_search(): + sig2 = utils.get_test_data('2.fa.sig') + sig47 = utils.get_test_data('47.fa.sig') + sig63 = utils.get_test_data('63.fa.sig') + + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss47 = sourmash.load_one_signature(sig47) + ss63 = sourmash.load_one_signature(sig63) + + lidx1 = LinearIndex.load(sig2) + lidx2 = LinearIndex.load(sig47) + lidx3 = LinearIndex.load(sig63) + + # create MultiIindex with source location override + lidx = MultiIndex([lidx1, lidx2, lidx3], ['A', None, 'C']) + lidx = lidx.select(ksize=31) + + # now, search for sig2 + sr = lidx.search(ss2, threshold=1.0) + print([s[1].name for s in sr]) + assert len(sr) == 1 + assert sr[0][1] == ss2 + assert sr[0][2] == 'A' # source override + + # search for sig47 with lower threshold; search order not guaranteed. + sr = lidx.search(ss47, threshold=0.1) + print([s[1].name for s in sr]) + assert len(sr) == 2 + sr.sort(key=lambda x: -x[0]) + assert sr[0][1] == ss47 + assert sr[0][2] == sig47 # source was set to None, so no override + assert sr[1][1] == ss63 + assert sr[1][2] == 'C' # source override + + # search for sig63 with lower threshold; search order not guaranteed. + sr = lidx.search(ss63, threshold=0.1) + print([s[1].name for s in sr]) + assert len(sr) == 2 + sr.sort(key=lambda x: -x[0]) + assert sr[0][1] == ss63 + assert sr[0][2] == 'C' # source override + assert sr[1][1] == ss47 + assert sr[1][2] == sig47 # source was set to None, so no override + + # search for sig63 with high threshold => 1 match + sr = lidx.search(ss63, threshold=0.8) + print([s[1].name for s in sr]) + assert len(sr) == 1 + sr.sort(key=lambda x: -x[0]) + assert sr[0][1] == ss63 + assert sr[0][2] == 'C' # source override + + +def test_multi_index_gather(): + sig2 = utils.get_test_data('2.fa.sig') + sig47 = utils.get_test_data('47.fa.sig') + sig63 = utils.get_test_data('63.fa.sig') + + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss47 = sourmash.load_one_signature(sig47) + ss63 = sourmash.load_one_signature(sig63) + + lidx1 = LinearIndex.load(sig2) + lidx2 = LinearIndex.load(sig47) + lidx3 = LinearIndex.load(sig63) + + # create MultiIindex with source location override + lidx = MultiIndex([lidx1, lidx2, lidx3], ['A', None, 'C']) + lidx = lidx.select(ksize=31) + + matches = lidx.gather(ss2) + assert len(matches) == 1 + assert matches[0][0] == 1.0 + assert matches[0][2] == 'A' + + matches = lidx.gather(ss47) + assert len(matches) == 2 + assert matches[0][0] == 1.0 + assert matches[0][1] == ss47 + assert matches[0][2] == sig47 # no source override + assert round(matches[1][0], 2) == 0.49 + assert matches[1][1] == ss63 + assert matches[1][2] == 'C' # source override + + +def test_multi_index_signatures(): + sig2 = utils.get_test_data('2.fa.sig') + sig47 = utils.get_test_data('47.fa.sig') + sig63 = utils.get_test_data('63.fa.sig') + + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss47 = sourmash.load_one_signature(sig47) + ss63 = sourmash.load_one_signature(sig63) + + lidx1 = LinearIndex.load(sig2) + lidx2 = LinearIndex.load(sig47) + lidx3 = LinearIndex.load(sig63) + + # create MultiIindex with source location override + lidx = MultiIndex([lidx1, lidx2, lidx3], ['A', None, 'C']) + lidx = lidx.select(ksize=31) + + siglist = list(lidx.signatures()) + assert len(siglist) == 3 + assert ss2 in siglist + assert ss47 in siglist + assert ss63 in siglist diff --git a/tests/test_sourmash.py b/tests/test_sourmash.py index 477d00cf12..a1f2b55f6e 100644 --- a/tests/test_sourmash.py +++ b/tests/test_sourmash.py @@ -1811,6 +1811,36 @@ def test_search_metagenome_traverse(): assert '13 matches; showing first 3:' in out +def test_search_metagenome_traverse_check_csv(): + # this test confirms that the CSV 'filename' output for signatures loaded + # via directory traversal properly contains the actual path to the + # signature file from which the signature was loaded. + with utils.TempDirectory() as location: + testdata_dir = utils.get_test_data('gather') + + query_sig = utils.get_test_data('gather/combined.sig') + out_csv = os.path.join(location, 'out.csv') + + cmd = f'search {query_sig} {testdata_dir} -k 21 -o {out_csv}' + status, out, err = utils.runscript('sourmash', cmd.split(' '), + in_directory=location) + + print(out) + print(err) + + with open(out_csv, 'rt') as fp: + prefix_len = len(testdata_dir) + r = csv.DictReader(fp) + for row in r: + filename = row['filename'] + assert filename.startswith(testdata_dir) + # should have full path to file sig was loaded from + assert len(filename) > prefix_len + + assert ' 33.2% NC_003198.1 Salmonella enterica subsp. enterica serovar T...' in out + assert '13 matches; showing first 3:' in out + + @utils.in_thisdir def test_search_incompatible(c): num_sig = utils.get_test_data('num/47.fa.sig') @@ -3518,6 +3548,49 @@ def test_gather_metagenome_traverse(): 'NC_011294.1 Salmonella enterica subsp...' in out)) +def test_gather_metagenome_traverse_check_csv(): + # this test confirms that the CSV 'filename' output for signatures loaded + # via directory traversal properly contains the actual path to the + # signature file from which the signature was loaded. + with utils.TempDirectory() as location: + # set up a directory $location/gather that contains + # everything in the 'tests/test-data/gather' directory + # *except* the query sequence, which is 'combined.sig'. + testdata_dir = utils.get_test_data('gather') + copy_testdata = os.path.join(location, 'somesigs') + shutil.copytree(testdata_dir, copy_testdata) + os.unlink(os.path.join(copy_testdata, 'combined.sig')) + + query_sig = utils.get_test_data('gather/combined.sig') + out_csv = os.path.join(location, 'out.csv') + + # now, feed in the new directory -- + cmd = f'gather {query_sig} {copy_testdata} -k 21 --threshold-bp=0' + cmd += f' -o {out_csv}' + status, out, err = utils.runscript('sourmash', cmd.split(' '), + in_directory=location) + + print(cmd) + print(out) + print(err) + + with open(out_csv, 'rt') as fp: + prefix_len = len(copy_testdata) + r = csv.DictReader(fp) + for row in r: + filename = row['filename'] + assert filename.startswith(copy_testdata) + # should have full path to file sig was loaded from + assert len(filename) > prefix_len + + assert 'found 12 matches total' in out + assert 'the recovered matches hit 100.0% of the query' in out + assert all(('4.9 Mbp 33.2% 100.0%' in out, + 'NC_003198.1 Salmonella enterica subsp...' in out)) + assert all(('4.7 Mbp 0.5% 1.5%' in out, + 'NC_011294.1 Salmonella enterica subsp...' in out)) + + @utils.in_tempdir def test_gather_traverse_incompatible(c): searchdir = c.output('searchme') @@ -3736,7 +3809,7 @@ def test_gather_error_no_sigs_traverse(c): err = c.last_result.err print(err) - assert '** ERROR: no signatures or databases loaded?' in err + assert f"Error while reading signatures from '{emptydir}'" in err assert not 'found 0 matches total;' in err