Skip to content

Commit

Permalink
support loading from a zipfile
Browse files Browse the repository at this point in the history
  • Loading branch information
luizirber committed Jul 7, 2019
1 parent e5157fb commit 863c255
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 7 deletions.
1 change: 0 additions & 1 deletion sourmash/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import random
import shutil
import tempfile
import zipfile

import screed
from .sourmash_args import SourmashArgumentParser
Expand Down
26 changes: 20 additions & 6 deletions sourmash/sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def search_transcript(node, seq, threshold):

from __future__ import print_function, unicode_literals, division

from collections import namedtuple, defaultdict
from collections import namedtuple
try:
from collections.abc import Mapping
except ImportError: # Python 2...
Expand Down Expand Up @@ -481,10 +481,21 @@ def load(cls, location, leaf_loader=None, storage=None, print_version_warning=Tr
SBT
the SBT tree built from the description.
"""
dirname = os.path.dirname(os.path.abspath(location))
sbt_name = os.path.basename(location)
if sbt_name.endswith('.sbt.json'):
sbt_name = sbt_name[:-9]
tempfile = None
if zipfile.is_zipfile(location):
tempfile = NamedTemporaryFile()
with zipfile.ZipFile(location, 'r') as zf:
tempfile.write(zf.read('tree.sbt.json'))
tempfile.flush()

dirname = os.path.dirname(tempfile.name)
sbt_name = os.path.basename(tempfile.name)
storage = ZipStorage(location)
else:
dirname = os.path.dirname(os.path.abspath(location))
sbt_name = os.path.basename(location)
if sbt_name.endswith('.sbt.json'):
sbt_name = sbt_name[:-9]

loaders = {
1: cls._load_v1,
Expand All @@ -506,11 +517,14 @@ def load(cls, location, leaf_loader=None, storage=None, print_version_warning=Tr
leaf_loader = Leaf.load

sbt_fn = os.path.join(dirname, sbt_name)
if not sbt_fn.endswith('.sbt.json'):
if not sbt_fn.endswith('.sbt.json') and tempfile is None:
sbt_fn += '.sbt.json'
with open(sbt_fn) as fp:
jnodes = json.load(fp)

if tempfile is not None:
tempfile.close()

version = 1
if isinstance(jnodes, Mapping):
version = jnodes['version']
Expand Down
18 changes: 18 additions & 0 deletions tests/test_sbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os

import pytest
import shutil

from sourmash import signature
from sourmash.sbt import SBT, GraphFactory, Leaf, Node
Expand Down Expand Up @@ -418,6 +419,23 @@ def test_sbt_redisstorage():
assert old_result == new_result


def test_load_zip(tmpdir):
testdata = utils.get_test_data("v5.zip")
testsbt = tmpdir.join("v5.zip")

shutil.copyfile(testdata, str(testsbt))

tree = SBT.load(str(testsbt), leaf_loader=SigLeaf.load)

to_search = signature.load_one_signature(utils.get_test_data(utils.SIG_FILES[0]))

print("*" * 60)
print("{}:".format(to_search))
new_result = {str(s) for s in tree.find(search_minhashes, to_search, 0.1)}
print(*new_result, sep="\n")
assert len(new_result) == 2


def test_tree_repair():
tree_repair = SBT.load(utils.get_test_data('leaves.sbt.json'),
leaf_loader=SigLeaf.load)
Expand Down

0 comments on commit 863c255

Please sign in to comment.