diff --git a/sourmash/commands.py b/sourmash/commands.py index 80ca775a59..61eabfffc0 100644 --- a/sourmash/commands.py +++ b/sourmash/commands.py @@ -10,7 +10,6 @@ import random import shutil import tempfile -import zipfile import screed from .sourmash_args import SourmashArgumentParser diff --git a/sourmash/sbt.py b/sourmash/sbt.py index 4c7837fd83..da44a6608c 100644 --- a/sourmash/sbt.py +++ b/sourmash/sbt.py @@ -43,7 +43,7 @@ def search_transcript(node, seq, threshold): from __future__ import print_function, unicode_literals, division -from collections import namedtuple, defaultdict +from collections import namedtuple try: from collections.abc import Mapping except ImportError: # Python 2... @@ -481,10 +481,21 @@ def load(cls, location, leaf_loader=None, storage=None, print_version_warning=Tr SBT the SBT tree built from the description. """ - dirname = os.path.dirname(os.path.abspath(location)) - sbt_name = os.path.basename(location) - if sbt_name.endswith('.sbt.json'): - sbt_name = sbt_name[:-9] + tempfile = None + if zipfile.is_zipfile(location): + tempfile = NamedTemporaryFile() + with zipfile.ZipFile(location, 'r') as zf: + tempfile.write(zf.read('tree.sbt.json')) + tempfile.flush() + + dirname = os.path.dirname(tempfile.name) + sbt_name = os.path.basename(tempfile.name) + storage = ZipStorage(location) + else: + dirname = os.path.dirname(os.path.abspath(location)) + sbt_name = os.path.basename(location) + if sbt_name.endswith('.sbt.json'): + sbt_name = sbt_name[:-9] loaders = { 1: cls._load_v1, @@ -506,11 +517,14 @@ def load(cls, location, leaf_loader=None, storage=None, print_version_warning=Tr leaf_loader = Leaf.load sbt_fn = os.path.join(dirname, sbt_name) - if not sbt_fn.endswith('.sbt.json'): + if not sbt_fn.endswith('.sbt.json') and tempfile is None: sbt_fn += '.sbt.json' with open(sbt_fn) as fp: jnodes = json.load(fp) + if tempfile is not None: + tempfile.close() + version = 1 if isinstance(jnodes, Mapping): version = jnodes['version'] diff --git a/tests/test_sbt.py b/tests/test_sbt.py index 5a8c2e00d6..54014a917d 100644 --- a/tests/test_sbt.py +++ b/tests/test_sbt.py @@ -3,6 +3,7 @@ import os import pytest +import shutil from sourmash import signature from sourmash.sbt import SBT, GraphFactory, Leaf, Node @@ -418,6 +419,23 @@ def test_sbt_redisstorage(): assert old_result == new_result +def test_load_zip(tmpdir): + testdata = utils.get_test_data("v5.zip") + testsbt = tmpdir.join("v5.zip") + + shutil.copyfile(testdata, str(testsbt)) + + tree = SBT.load(str(testsbt), leaf_loader=SigLeaf.load) + + to_search = signature.load_one_signature(utils.get_test_data(utils.SIG_FILES[0])) + + print("*" * 60) + print("{}:".format(to_search)) + new_result = {str(s) for s in tree.find(search_minhashes, to_search, 0.1)} + print(*new_result, sep="\n") + assert len(new_result) == 2 + + def test_tree_repair(): tree_repair = SBT.load(utils.get_test_data('leaves.sbt.json'), leaf_loader=SigLeaf.load)