diff --git a/src/sourmash/sig/__main__.py b/src/sourmash/sig/__main__.py index e342f14470..eab898b39b 100644 --- a/src/sourmash/sig/__main__.py +++ b/src/sourmash/sig/__main__.py @@ -553,9 +553,9 @@ def extract(args): notify(f"loaded {len(picklist.pickset)} distinct values into picklist.") if n_empty_val: - notify(f"WARNING: {n_empty_val} empty values in column '{picklist.column_name}' in CSV file") + notify(f"WARNING: {n_empty_val} empty values in column '{picklist.column_name}' in picklist file") if dup_vals: - notify(f"WARNING: {len(dup_vals)} values in column '{picklist.column_name}' were not distinct") + notify(f"WARNING: {len(dup_vals)} values in picklist column '{picklist.column_name}' were not distinct") picklist_filter_fn = picklist.filter else: def picklist_filter_fn(it): diff --git a/src/sourmash/sig/picklist.py b/src/sourmash/sig/picklist.py index 6c2594c100..4d07f309dc 100644 --- a/src/sourmash/sig/picklist.py +++ b/src/sourmash/sig/picklist.py @@ -85,12 +85,11 @@ def init(self, values=[]): if self.pickset is not None: raise ValueError("already initialized?") self.pickset = set(values) + return self.pickset def load(self, pickfile, column_name): "load pickset, return num empty vals, and set of duplicate vals." - pickset = self.pickset - if pickset is None: - pickset = set() + pickset = self.init() n_empty_val = 0 dup_vals = set() @@ -98,7 +97,7 @@ def load(self, pickfile, column_name): r = csv.DictReader(csvfile) if column_name not in r.fieldnames: - raise ValueError("column '{column_name}' not in pickfile '{pickfile}'") + raise ValueError(f"column '{column_name}' not in pickfile '{pickfile}'") for row in r: # pick out values from column @@ -113,9 +112,8 @@ def load(self, pickfile, column_name): if col in pickset: dup_vals.add(col) else: - pickset.add(col) + self.add(col) - self.pickset = pickset return n_empty_val, dup_vals def add(self, value): diff --git a/tests/test_cmd_signature.py b/tests/test_cmd_signature.py index c15bd4d724..e243234de0 100644 --- a/tests/test_cmd_signature.py +++ b/tests/test_cmd_signature.py @@ -1462,6 +1462,104 @@ def test_sig_extract_9_picklist_md5_ksize_hp_select(runtmp): assert actual_extract_sig.minhash.moltype == 'hp' +def test_sig_extract_10_picklist_md5_dups_and_empty(runtmp): + # test empty picklist values, and duplicate picklist values + sigdir = utils.get_test_data('prot/') + + # make picklist + picklist_csv = runtmp.output('pick.csv') + with open(picklist_csv, 'w', newline='') as csvfp: + w = csv.DictWriter(csvfp, fieldnames=['md5']) + w.writeheader() + w.writerow(dict(md5='ea2a1ad233c2908529d124a330bcb672')) + w.writerow(dict(md5='ea2a1ad233c2908529d124a330bcb672')) + w.writerow(dict(md5='')) + + picklist_arg = f"{picklist_csv}:md5:md5" + + runtmp.sourmash('sig', 'extract', sigdir, '--picklist', + picklist_arg, '-k', '19', '--hp') + + # stdout should be new signature + out = runtmp.last_result.out + actual_extract_sig = sourmash.load_one_signature(out) + + assert actual_extract_sig.minhash.ksize == 19 + assert actual_extract_sig.minhash.moltype == 'hp' + + err = runtmp.last_result.err + print(err) + + assert "WARNING: 1 empty values in column 'md5' in picklist file" in err + assert "WARNING: 1 values in picklist column 'md5' were not distinct" in err + + +def test_sig_extract_11_picklist_bad_coltype(runtmp): + # test with invalid picklist coltype + sigdir = utils.get_test_data('prot/') + + # make picklist + picklist_csv = runtmp.output('pick.csv') + with open(picklist_csv, 'w', newline='') as csvfp: + w = csv.DictWriter(csvfp, fieldnames=['md5']) + w.writeheader() + w.writerow(dict(md5='ea2a1ad233c2908529d124a330bcb672')) + + picklist_arg = f"{picklist_csv}:md5:BADCOLTYPE" + + with pytest.raises(ValueError): + runtmp.sourmash('sig', 'extract', sigdir, '--picklist', + picklist_arg, '-k', '19', '--hp') + + err = runtmp.last_result.err + print(err) + assert "ValueError: invalid picklist column type 'BADCOLTYPE'" in err + + +def test_sig_extract_12_picklist_bad_argstr(runtmp): + # test with invalid argument format to --picklist + sigdir = utils.get_test_data('prot/') + + # make picklist + picklist_csv = runtmp.output('pick.csv') + with open(picklist_csv, 'w', newline='') as csvfp: + w = csv.DictWriter(csvfp, fieldnames=['md5']) + w.writeheader() + w.writerow(dict(md5='ea2a1ad233c2908529d124a330bcb672')) + + picklist_arg = f"{picklist_csv}" + + with pytest.raises(ValueError): + runtmp.sourmash('sig', 'extract', sigdir, '--picklist', + picklist_arg, '-k', '19', '--hp') + + err = runtmp.last_result.err + print(err) + assert "invalid picklist argument" in err + + +def test_sig_extract_12_picklist_bad_colname(runtmp): + # test with invalid picklist colname + sigdir = utils.get_test_data('prot/') + + # make picklist + picklist_csv = runtmp.output('pick.csv') + with open(picklist_csv, 'w', newline='') as csvfp: + w = csv.DictWriter(csvfp, fieldnames=['md5']) + w.writeheader() + w.writerow(dict(md5='ea2a1ad233c2908529d124a330bcb672')) + + picklist_arg = f"{picklist_csv}:BADCOLNAME:md5" + + with pytest.raises(ValueError): + runtmp.sourmash('sig', 'extract', sigdir, '--picklist', + picklist_arg, '-k', '19', '--hp') + + err = runtmp.last_result.err + print(err) + assert "ValueError: column 'BADCOLNAME' not in pickfile" in err + + @utils.in_tempdir def test_sig_flatten_1(c): # extract matches to several names from among several signatures & flatten