Skip to content

Commit

Permalink
[MRG] Fix sourmash prefetch to work when db scaled is larger than query scaled (#1870)
Browse files Browse the repository at this point in the history

* add test that breaks prefetch b/c of multiple ksizes

* fix prefetch

* prefetch fails when db scaled > query scaled

* successively downsample ident/noident too

* cleanup

* update threshold calc

* add test for scaled vals in output sigs

* Apply suggestions from code review

Co-authored-by: Tessa Pierce Ward <[email protected]>

* fix from code review

Co-authored-by: Tessa Pierce Ward <[email protected]>
  • Loading branch information
ctb and bluegenes authored Mar 7, 2022
1 parent 234df70 commit cccd06c
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 13 deletions.
25 changes: 18 additions & 7 deletions src/sourmash/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,11 +711,8 @@ def gather(args):
# optionally calculate and save prefetch csv
if prefetch_csvout_fp:
assert scaled
# calculate expected threshold
threshold = args.threshold_bp / scaled

# calculate intersection stats and info
prefetch_result = calculate_prefetch_info(prefetch_query, found_sig, scaled, threshold)
prefetch_result = calculate_prefetch_info(prefetch_query, found_sig, scaled, args.threshold_bp)
# remove match and query signatures; write result to prefetch csv
d = dict(prefetch_result._asdict())
del d['match']
Expand Down Expand Up @@ -1168,7 +1165,9 @@ def prefetch(args):
if args.scaled:
notify(f'downsampling query from scaled={query_mh.scaled} to {int(args.scaled)}')
query_mh = query_mh.downsample(scaled=args.scaled)

notify(f"all sketches will be downsampled to scaled={query_mh.scaled}")
common_scaled = query_mh.scaled

# empty?
if not len(query_mh):
Expand Down Expand Up @@ -1223,10 +1222,21 @@ def prefetch(args):
for result in prefetch_database(query, db, args.threshold_bp):
match = result.match

# ensure we're all on the same page wrt scaled resolution:
common_scaled = max(match.minhash.scaled, query.minhash.scaled,
common_scaled)

query_mh = query.minhash.downsample(scaled=common_scaled)
match_mh = match.minhash.downsample(scaled=common_scaled)

if ident_mh.scaled != common_scaled:
ident_mh = ident_mh.downsample(scaled=common_scaled)
if noident_mh.scaled != common_scaled:
noident_mh = noident_mh.downsample(scaled=common_scaled)

# track found & "untouched" hashes.
match_mh = match.minhash.downsample(scaled=query.minhash.scaled)
ident_mh += query.minhash & match_mh.flatten()
noident_mh.remove_many(match.minhash)
ident_mh += query_mh & match_mh.flatten()
noident_mh.remove_many(match_mh)

# output match info as we go
if csvout_fp:
Expand Down Expand Up @@ -1265,6 +1275,7 @@ def prefetch(args):
assert len(query_mh) == len(ident_mh) + len(noident_mh)
notify(f"of {len(query_mh)} distinct query hashes, {len(ident_mh)} were found in matches above threshold.")
notify(f"a total of {len(noident_mh)} query hashes remain unmatched.")
notify(f"final scaled value (max across query and all matches) is {common_scaled}")

if args.save_matching_hashes:
filename = args.save_matching_hashes
Expand Down
12 changes: 6 additions & 6 deletions src/sourmash/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,16 +469,20 @@ def __next__(self):
'intersect_bp, jaccard, max_containment, f_query_match, f_match_query, match, match_filename, match_name, match_md5, match_bp, query, query_filename, query_name, query_md5, query_bp')


def calculate_prefetch_info(query, match, scaled, threshold):
def calculate_prefetch_info(query, match, scaled, threshold_bp):
"""
For a single query and match, calculate all search info and return a PrefetchResult.
"""
# base intersections on downsampled minhashes
query_mh = query.minhash

scaled = max(scaled, match.minhash.scaled)
query_mh = query_mh.downsample(scaled=scaled)
db_mh = match.minhash.flatten().downsample(scaled=scaled)

# calculate db match intersection with query hashes:
intersect_mh = query_mh & db_mh
threshold = threshold_bp / scaled
assert len(intersect_mh) >= threshold

f_query_match = db_mh.contained_by(query_mh)
Expand Down Expand Up @@ -515,12 +519,8 @@ def prefetch_database(query, database, threshold_bp):
scaled = query_mh.scaled
assert scaled

# for testing/double-checking purposes, calculate expected threshold -
threshold = threshold_bp / scaled

# iterate over all signatures in database, find matches

for result in database.prefetch(query, threshold_bp):
match = result.signature
result = calculate_prefetch_info(query, match, scaled, threshold)
result = calculate_prefetch_info(query, match, scaled, threshold_bp)
yield result
61 changes: 61 additions & 0 deletions tests/test_prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,67 @@ def test_prefetch_select_query_ksize(runtmp, linear_gather):
assert 'of 4476 distinct query hashes, 4476 were found in matches above threshold.' in c.last_result.err


def test_prefetch_subject_scaled_is_larger(runtmp, linear_gather):
    # prefetch should succeed even when a subject sketch has a coarser
    # (larger) scaled value than the query sketch.

    # build a query sketch at the default scaled=1000
    fa = utils.get_test_data('genome-s10.fa.gz')
    runtmp.run_sourmash('sketch', 'dna', fa, '-o', 'query.sig')
    assert os.path.exists(runtmp.output('query.sig'))

    # subjects built at scaled=10000, from the same genome:
    db_sig = utils.get_test_data('scaled/genome-s10.fa.gz.sig')
    db_sbt = utils.get_test_data('scaled/all.sbt.zip')
    db_lca = utils.get_test_data('scaled/all.lca.json')

    # search the coarse-scaled subjects first, then the query against itself
    runtmp.run_sourmash('prefetch', 'query.sig', db_sig, db_sbt, db_lca,
                        'query.sig', linear_gather)
    print(runtmp.last_result.status)
    print(runtmp.last_result.out)
    print(runtmp.last_result.err)

    err = runtmp.last_result.err
    assert runtmp.last_result.status == 0
    assert 'total of 8 matching signatures.' in err
    assert 'of 48 distinct query hashes, 48 were found in matches above threshold.' in err
    assert 'final scaled value (max across query and all matches) is 10000' in err


def test_prefetch_subject_scaled_is_larger_outsigs(runtmp, linear_gather):
    # prefetch with a coarser-scaled subject, checking that saved match
    # signatures keep their original (non-downsampled) scaled values.

    # build a query sketch at the default scaled=1000
    fa = utils.get_test_data('genome-s10.fa.gz')
    runtmp.run_sourmash('sketch', 'dna', fa, '-o', 'query.sig')
    assert os.path.exists(runtmp.output('query.sig'))

    # subjects built at scaled=10000, from the same genome:
    db_sig = utils.get_test_data('scaled/genome-s10.fa.gz.sig')
    db_sbt = utils.get_test_data('scaled/all.sbt.zip')
    db_lca = utils.get_test_data('scaled/all.lca.json')

    # search the coarse-scaled subjects first, then the query against itself,
    # saving all matches to a file
    runtmp.run_sourmash('prefetch', 'query.sig', db_sig, db_sbt, db_lca,
                        'query.sig', linear_gather, '--save-matches', 'matches.sig')
    print(runtmp.last_result.status)
    print(runtmp.last_result.out)
    print(runtmp.last_result.err)

    err = runtmp.last_result.err
    assert runtmp.last_result.status == 0
    assert 'total of 8 matching signatures.' in err
    assert 'of 48 distinct query hashes, 48 were found in matches above threshold.' in err
    assert 'final scaled value (max across query and all matches) is 10000' in err

    # the saved matches must retain their original scaled values —
    # both the scaled=1000 query self-match and the scaled=10000 subjects.
    saved = sourmash.load_file_as_signatures(runtmp.output('matches.sig'))
    scaled_vals = {ss.minhash.scaled for ss in saved}
    assert 1000 in scaled_vals
    assert 10000 in scaled_vals
    assert len(scaled_vals) == 2


def test_prefetch_query_abund(runtmp, linear_gather):
c = runtmp

Expand Down

0 comments on commit cccd06c

Please sign in to comment.