Skip to content

Commit

Permalink
add delete_sites table function
Browse files Browse the repository at this point in the history
Fixes #363
  • Loading branch information
hyanwong committed Sep 11, 2019
1 parent 245333b commit 6be1475
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 9 deletions.
70 changes: 70 additions & 0 deletions python/tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2026,6 +2026,76 @@ def test_multichar_metadata(self):
self.assertEqual(site.metadata, site.id * b"A")


class TestDeleteSites(unittest.TestCase):
    """
    Tests for the TableCollection.delete_sites method.
    """
    def tc_with_4_sites(self):
        """
        Return a table collection with 4 sites (at positions 0, 0.25, 0.5 and
        0.75, all with ancestral state 'G') and 3 mutations: two at site 1, the
        second having the first as its parent, and one at site 2.
        """
        ts = msprime.simulate(8, random_seed=3)
        tables = ts.dump_tables()
        tables.sites.set_columns(np.arange(0, 1, 0.25), *tskit.pack_strings(['G'] * 4))
        tables.mutations.add_row(site=1, node=ts.first().parent(0), derived_state='C')
        tables.mutations.add_row(site=1, node=0, derived_state='T', parent=0)
        tables.mutations.add_row(site=2, node=1, derived_state='A')
        return tables

    def test_remove_by_bool(self):
        tables = self.tc_with_4_sites()
        # An all-False mask deletes nothing
        tables.delete_sites(np.array([0, 0, 0, 0], dtype=bool))
        ts = tables.tree_sequence()
        self.assertEqual(ts.num_sites, 4)
        self.assertEqual(ts.num_mutations, 3)
        # Deleting sites 1 and 3 also removes the two mutations at site 1
        tables.delete_sites(np.array([0, 1, 0, 1], dtype=bool))
        ts = tables.tree_sequence()
        self.assertEqual(ts.num_sites, 2)
        self.assertEqual(ts.num_mutations, 1)

    def test_remove_by_index(self):
        tables = self.tc_with_4_sites()
        # An empty list of IDs deletes nothing
        tables.delete_sites([])
        ts = tables.tree_sequence()
        self.assertEqual(ts.num_sites, 4)
        self.assertEqual(ts.num_mutations, 3)
        # A single bare integer is accepted as well as a list
        tables.delete_sites(2)
        ts = tables.tree_sequence()
        self.assertEqual(ts.num_sites, 3)
        self.assertEqual(ts.num_mutations, 2)
        tables.delete_sites([1, 2])
        ts = tables.tree_sequence()
        self.assertEqual(ts.num_sites, 1)
        self.assertEqual(ts.num_mutations, 0)

    def test_remove_by_negative_index(self):
        tables = self.tc_with_4_sites()
        # With 4 sites, -2 refers to site 2
        tables.delete_sites([-2])
        ts = tables.tree_sequence()
        self.assertEqual(ts.num_sites, 3)
        self.assertEqual(ts.num_mutations, 2)
        # With 3 sites left, -2 refers to site 1 and 2 to the last site
        tables.delete_sites(np.array([-2, 2]))
        ts = tables.tree_sequence()
        self.assertEqual(ts.num_sites, 1)
        self.assertEqual(ts.num_mutations, 0)

    def test_remove_all(self):
        tables = self.tc_with_4_sites()
        tables.delete_sites(range(4))
        ts = tables.tree_sequence()
        self.assertEqual(ts.num_sites, 0)
        self.assertEqual(ts.num_mutations, 0)
        # should be OK to run on a siteless table collection if no sites specified
        tables.delete_sites([])

    def test_remove_bad(self):
        tables = self.tc_with_4_sites()
        # Out-of-range IDs, in either direction
        self.assertRaises(IndexError, tables.delete_sites, 4)
        self.assertRaises(IndexError, tables.delete_sites, -5)
        # Boolean masks that are too long or too short
        self.assertRaises(IndexError, tables.delete_sites, np.zeros((5, ), dtype=bool))
        self.assertRaises(IndexError, tables.delete_sites, np.zeros((3, ), dtype=bool))
        # Once every site is gone, any ID or non-empty mask is invalid
        tables.delete_sites(np.ones((4, ), dtype=bool))
        self.assertRaises(IndexError, tables.delete_sites, 0)
        self.assertRaises(IndexError, tables.delete_sites, np.zeros((1, ), dtype=bool))


class TestBaseTable(unittest.TestCase):
"""
Tests of the table superclass.
Expand Down
75 changes: 66 additions & 9 deletions python/tskit/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@
["timestamp", "record"])


def keep_with_offset(keep, data, offset):
    """
    Filter a ragged column, stored as a flat ``data`` array plus an ``offset``
    array delimiting each row, down to the rows selected by ``keep``.

    :param keep: Selector for the rows to retain; the callers in this module
        pass a boolean array with one entry per row.
    :param data: The flattened column data.
    :param offset: The offset column: row ``j`` spans
        ``data[offset[j]:offset[j + 1]]``.
    :return: A tuple ``(new_data, new_offset)`` for the retained rows, where
        ``new_offset`` has the same dtype as ``offset``.
    """
    # We need the astype here for 32 bit machines
    row_lengths = np.diff(offset).astype(np.int32)
    # Expand the per-row selector into a per-element selector over ``data``
    element_mask = np.repeat(keep, row_lengths)
    kept_data = data[element_mask]
    # New offsets are a leading zero followed by the running total of the
    # lengths of the retained rows
    leading_zero = np.array([0], dtype=offset.dtype)
    row_ends = np.cumsum(row_lengths[keep], dtype=offset.dtype)
    kept_offset = np.concatenate([leading_zero, row_ends])
    return kept_data, kept_offset


class BaseTable(object):
"""
Superclass of high-level tables. Not intended for direct instantiation.
Expand Down Expand Up @@ -1647,6 +1659,60 @@ def deduplicate_sites(self):
self.ll_tables.deduplicate_sites()
# TODO add provenance

def delete_sites(self, to_delete, record_provenance=True):
    """
    Remove the specified sites entirely from the sites and mutations tables in this
    collection. The tables are modified in place: every mutation at a deleted site
    is removed as well, and the ``site`` and ``parent`` columns of the remaining
    mutations are renumbered so that they refer to the new row IDs.

    :param to_delete: A site ID, a list of site IDs, or a numpy boolean vector of
        length num_sites in which True marks the sites to remove. Negative IDs
        count back from the last site.
    :param bool record_provenance: If True, add a record of this call to
        ``delete_sites`` to the provenance table of this collection.
        (Default: True).
    :raises IndexError: If a boolean vector does not have exactly one entry per
        site, or if a site ID is out of bounds.
    """
    if getattr(to_delete, 'dtype', None) == bool:
        if to_delete.shape != (self.sites.num_rows, ):
            raise IndexError("boolean index did not match indexed array length")
        keep_sites = np.logical_not(to_delete)
    else:
        keep_sites = np.ones(self.sites.num_rows, dtype=bool)
        keep_sites[util.safe_np_int_cast(to_delete, np.int32)] = 0
    new_as, new_as_offset = keep_with_offset(
        keep_sites, self.sites.ancestral_state,
        self.sites.ancestral_state_offset)
    new_md, new_md_offset = keep_with_offset(
        keep_sites, self.sites.metadata, self.sites.metadata_offset)
    self.sites.set_columns(
        position=self.sites.position[keep_sites],
        ancestral_state=new_as,
        ancestral_state_offset=new_as_offset,
        metadata=new_md,
        metadata_offset=new_md_offset)
    # We also need to adjust the mutations table, as it references into sites
    keep_mutations = keep_sites[self.mutations.site]
    new_ds, new_ds_offset = keep_with_offset(
        keep_mutations, self.mutations.derived_state,
        self.mutations.derived_state_offset)
    new_md, new_md_offset = keep_with_offset(
        keep_mutations, self.mutations.metadata, self.mutations.metadata_offset)
    # Old site ID -> new site ID (only looked up for sites that are kept)
    site_map = np.cumsum(keep_sites, dtype=self.mutations.site.dtype) - 1
    # Deleting mutations shifts the IDs of the rows that follow them, so the
    # parent column has to be renumbered too. The map carries one extra trailing
    # entry so that a null parent (-1) indexes it and stays -1; entries for
    # deleted mutations are also left at -1.
    old_parent = self.mutations.parent
    mutation_map = np.full(
        keep_mutations.shape[0] + 1, -1, dtype=old_parent.dtype)
    mutation_map[:-1][keep_mutations] = np.arange(
        np.count_nonzero(keep_mutations))
    self.mutations.set_columns(
        site=site_map[self.mutations.site[keep_mutations]],
        node=self.mutations.node[keep_mutations],
        derived_state=new_ds,
        derived_state_offset=new_ds_offset,
        parent=mutation_map[old_parent[keep_mutations]],
        metadata=new_md,
        metadata_offset=new_md_offset)
    if record_provenance:
        # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243
        parameters = {
            "command": "delete_sites",
            "TODO": "add parameters"
        }
        self.provenances.add_row(record=json.dumps(
            provenance.get_provenance_dict(parameters)))

def delete_intervals(self, intervals, simplify=True, record_provenance=True):
"""
Returns a copy of this set of tables for which information in the
Expand Down Expand Up @@ -1701,15 +1767,6 @@ def keep_intervals(self, intervals, simplify=True, record_provenance=True):
(Default: True).
:rtype: tskit.TableCollection
"""

def keep_with_offset(keep, data, offset):
# We need the astype here for 32 bit machines
lens = np.diff(offset).astype(np.int32)
return (data[np.repeat(keep, lens)],
np.concatenate([
np.array([0], dtype=offset.dtype),
np.cumsum(lens[keep], dtype=offset.dtype)]))

intervals = util.intervals_to_np_array(intervals, 0, self.sequence_length)
if len(self.migrations) > 0:
raise ValueError("Migrations not supported by keep_intervals")
Expand Down

0 comments on commit 6be1475

Please sign in to comment.