Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Python API for sample deletion #759

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions apis/python/src/tiledbvcf/binding/libtiledbvcf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ PYBIND11_MODULE(libtiledbvcf, m) {
"ingest_samples",
&Writer::ingest_samples,
py::call_guard<py::gil_scoped_release>())
.def(
"delete_samples",
&Writer::delete_samples,
py::call_guard<py::gil_scoped_release>())
.def("get_schema_version", &Writer::get_schema_version)
.def("set_tiledb_config", &Writer::set_tiledb_config)
.def("set_sample_batch_size", &Writer::set_sample_batch_size)
Expand Down
12 changes: 12 additions & 0 deletions apis/python/src/tiledbvcf/binding/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,18 @@ void Writer::ingest_samples() {
check_error(writer, tiledb_vcf_writer_store(writer));
}

void Writer::delete_samples(std::vector<std::string> samples_to_delete) {
  // The C API takes an array of C strings. The pointers collected here
  // borrow from `samples_to_delete`, which stays alive for the whole call.
  std::vector<const char*> c_names;
  for (const auto& name : samples_to_delete) {
    c_names.push_back(name.c_str());
  }

  auto handle = ptr.get();
  auto rc =
      tiledb_vcf_writer_delete_samples(handle, c_names.data(), c_names.size());
  // check_error receives the writer handle together with the C API status.
  check_error(handle, rc);
}

/**
 * Frees a C API writer handle by passing the address of `w` to
 * `tiledb_vcf_writer_free`.
 * NOTE(review): presumably installed as the custom deleter of the smart
 * pointer `ptr` -- confirm against the class declaration.
 */
void Writer::deleter(tiledb_vcf_writer_t* w) {
tiledb_vcf_writer_free(&w);
}
Expand Down
2 changes: 2 additions & 0 deletions apis/python/src/tiledbvcf/binding/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ class Writer {

void ingest_samples();

/**
 * Deletes samples from the dataset through the C API writer.
 * @param samples Samples to delete; passed to
 * `tiledb_vcf_writer_delete_samples` as an array of C strings.
 */
void delete_samples(std::vector<std::string> samples);

/** Returns schema version number of the TileDB VCF dataset */
int32_t get_schema_version();

Expand Down
8 changes: 8 additions & 0 deletions apis/python/src/tiledbvcf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,14 @@ def ingest_samples(
self.writer.register_samples()
self.writer.ingest_samples()

def delete_samples(
    self,
    sample_uris: List[str] = None,
):
    """
    Delete samples from the dataset.

    The dataset must be open in write mode.

    :param sample_uris: Samples to delete. Despite the parameter name, the
        values are forwarded unchanged to the native writer, and callers pass
        sample names (e.g. ``["second", "fifth"]``). ``None`` is treated as
        an empty list.
    :raises Exception: if the dataset is not open in write mode.
    """
    if self.mode != "w":
        raise Exception("Dataset not open in write mode")
    # The native binding expects a list of strings; forwarding the default
    # ``None`` would fail inside the binding with an opaque conversion error.
    if sample_uris is None:
        sample_uris = []
    self.writer.delete_samples(sample_uris)

def tiledb_stats(self) -> str:
"""
Get TileDB stats as a string.
Expand Down
78 changes: 57 additions & 21 deletions apis/python/tests/test_tiledbvcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,15 +1197,8 @@ def test_ingest_mode_merged(tmp_path):
assert ds.count(regions=["chrX:9032893-9032893"]) == 0


# OK to skip if bcftools is missing in the Windows CI job
@pytest.mark.skipif(
os.environ.get("CI") == "true"
and platform.system() == "Windows"
and shutil.which("bcftools") is None,
reason="no bcftools",
)
def test_ingest_with_stats_v3(tmp_path):
# tiledbvcf.config_logging("debug")
@pytest.fixture
def test_stats_bgzipped_inputs(tmp_path):
tmp_path_contents = os.listdir(tmp_path)
if "stats" in tmp_path_contents:
shutil.rmtree(os.path.join(tmp_path, "stats"))
Expand All @@ -1221,23 +1214,46 @@ def test_ingest_with_stats_v3(tmp_path):
check=True,
)
bgzipped_inputs = glob.glob(os.path.join(tmp_path, "stats", "*.gz"))
# print(f"bgzipped inputs: {bgzipped_inputs}")
for vcf_file in bgzipped_inputs:
assert subprocess.run("bcftools index " + vcf_file, shell=True).returncode == 0
if "outputs" in tmp_path_contents:
shutil.rmtree(os.path.join(tmp_path, "outputs"))
if "stats_test" in tmp_path_contents:
shutil.rmtree(os.path.join(tmp_path, "stats_test"))
# tiledbvcf.config_logging("trace")
return bgzipped_inputs


@pytest.fixture
def test_stats_sample_names(test_stats_bgzipped_inputs):
    """
    Sample names derived from the bgzipped input files: for each path, the
    basename up to (not including) its first ".". Expects exactly 8 inputs.
    """
    assert len(test_stats_bgzipped_inputs) == 8
    sample_names = []
    for vcf_path in test_stats_bgzipped_inputs:
        file_name = os.path.basename(vcf_path)
        sample_names.append(file_name.split(".")[0])
    return sample_names


@pytest.fixture
def test_stats_v3_ingestion(tmp_path, test_stats_bgzipped_inputs):
assert len(test_stats_bgzipped_inputs) == 8
# print(f"bgzipped inputs: {test_stats_bgzipped_inputs}")
ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="w")
ds.create_dataset(
enable_variant_stats=True, enable_allele_count=True, variant_stats_version=3
)
ds.ingest_samples(bgzipped_inputs)
ds.ingest_samples(test_stats_bgzipped_inputs)
ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r")
sample_names = [os.path.basename(file).split(".")[0] for file in bgzipped_inputs]
data_frame = ds.read(
samples=sample_names,
return ds


# OK to skip if bcftools is missing in the Windows CI job
@pytest.mark.skipif(
os.environ.get("CI") == "true"
and platform.system() == "Windows"
and shutil.which("bcftools") is None,
reason="no bcftools",
)
def test_ingest_with_stats_v3(
tmp_path, test_stats_v3_ingestion, test_stats_sample_names
):
data_frame = test_stats_v3_ingestion.read(
samples=test_stats_sample_names,
attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"],
set_af_filter="<0.2",
)
Expand All @@ -1249,8 +1265,8 @@ def test_ingest_with_stats_v3(tmp_path):
data_frame[data_frame["sample_name"] == "second"]["info_TILEDB_IAF"].iloc[0][0]
== 0.9375
)
data_frame = ds.read(
samples=sample_names,
data_frame = test_stats_v3_ingestion.read(
samples=test_stats_sample_names,
attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"],
scan_all_samples=True,
)
Expand All @@ -1260,25 +1276,45 @@ def test_ingest_with_stats_v3(tmp_path):
]["info_TILEDB_IAF"].iloc[0][0]
== 0.9375
)
ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r")
df = ds.read_variant_stats("chr1:1-10000")
df = test_stats_v3_ingestion.read_variant_stats("chr1:1-10000")
assert df.shape == (13, 5)
df = tiledbvcf.allele_frequency.read_allele_frequency(
os.path.join(tmp_path, "stats_test"), "chr1:1-10000"
)
assert df.pos.is_monotonic_increasing
df["an_check"] = (df.ac / df.af).round(0).astype("int32")
assert df.an_check.equals(df.an)
df = ds.read_variant_stats("chr1:1-10000")
df = test_stats_v3_ingestion.read_variant_stats("chr1:1-10000")
assert df.shape == (13, 5)
df = df.to_pandas()
df = ds.read_allele_count("chr1:1-10000")
df = test_stats_v3_ingestion.read_allele_count("chr1:1-10000")
assert df.shape == (7, 6)
df = df.to_pandas()
assert sum(df["pos"] == (0, 1, 1, 2, 2, 2, 3)) == 7
assert sum(df["count"] == (8, 5, 3, 4, 2, 2, 1)) == 7


@pytest.mark.skipif(
    os.environ.get("CI") == "true"
    and platform.system() == "Windows"
    and shutil.which("bcftools") is None,
    reason="no bcftools",
)
def test_delete_samples(tmp_path, test_stats_v3_ingestion, test_stats_sample_names):
    """
    Delete two samples from the ingested stats dataset and check that they
    are gone from the sample list while an untouched sample is still there.
    """
    dataset_uri = os.path.join(tmp_path, "stats_test")
    deleted = ["second", "fifth"]
    kept = "third"

    # Preconditions: every sample the test refers to was part of the ingestion.
    for name in deleted:
        assert name in test_stats_sample_names
    assert kept in test_stats_sample_names

    # Deletion requires the dataset to be opened in write mode.
    dataset = tiledbvcf.Dataset(uri=dataset_uri, mode="w")
    dataset.delete_samples(deleted)

    # Reopen in read mode to inspect the result.
    dataset = tiledbvcf.Dataset(uri=dataset_uri, mode="r")
    remaining = dataset.samples()
    for name in deleted:
        assert name not in remaining
    assert kept in remaining


# OK to skip if bcftools is missing in the Windows CI job
@pytest.mark.skipif(
os.environ.get("CI") == "true"
Expand Down
15 changes: 15 additions & 0 deletions libtiledbvcf/src/c_api/tiledbvcf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1823,6 +1823,21 @@ int32_t tiledb_vcf_writer_set_variant_stats_version(
return TILEDB_VCF_OK;
}

/*
 * Deletes the named samples from the writer's dataset.
 *
 * Returns TILEDB_VCF_ERR if the writer fails its sanity check, if `samples`
 * is null while `nsamples` is non-zero, if any entry of `samples` is null,
 * or if the underlying delete reports an error; otherwise TILEDB_VCF_OK.
 */
int32_t tiledb_vcf_writer_delete_samples(
    tiledb_vcf_writer_t* writer, const char** samples, size_t nsamples) {
  // Validate the handle before touching any of the other arguments.
  if (sanity_check(writer) == TILEDB_VCF_ERR)
    return TILEDB_VCF_ERR;

  // A null array is only acceptable when there is nothing to read from it
  // (an empty std::vector may legitimately hand out a null data()).
  // NOTE(review): no error message is saved on the writer for these
  // argument errors -- confirm whether callers rely on one being set.
  if (nsamples > 0 && samples == nullptr)
    return TILEDB_VCF_ERR;

  std::vector<std::string> sample_names;
  sample_names.reserve(nsamples);
  for (size_t i = 0; i < nsamples; i++) {
    // Constructing std::string from a null pointer is undefined behaviour.
    if (samples[i] == nullptr)
      return TILEDB_VCF_ERR;
    sample_names.emplace_back(samples[i]);
  }

  if (SAVE_ERROR_CATCH(
          writer, writer->writer_->delete_samples(sample_names)))
    return TILEDB_VCF_ERR;

  return TILEDB_VCF_OK;
}

/* ********************************* */
/* ERROR */
/* ********************************* */
Expand Down
10 changes: 10 additions & 0 deletions libtiledbvcf/src/c_api/tiledbvcf.h
Original file line number Diff line number Diff line change
Expand Up @@ -1706,6 +1706,16 @@ tiledb_vcf_writer_set_compression_level(tiledb_vcf_writer_t* writer, int level);
TILEDBVCF_EXPORT int32_t tiledb_vcf_writer_set_variant_stats_version(
tiledb_vcf_writer_t* writer, uint8_t version);

/**
 * Deletes samples from dataset
 *
 * @param writer VCF writer object
 * @param samples Array of `nsamples` null-terminated strings naming the
 * samples to delete
 * @param nsamples number of samples to delete
 * @return `TILEDB_VCF_OK` for success or `TILEDB_VCF_ERR` for error
 */
TILEDBVCF_EXPORT int32_t tiledb_vcf_writer_delete_samples(

tiledb_vcf_writer_t* writer, const char** samples, size_t nsamples);

/* ********************************* */
/* ERROR */
/* ********************************* */
Expand Down
4 changes: 3 additions & 1 deletion libtiledbvcf/src/dataset/tiledbvcfdataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,9 @@ void TileDBVCFDataset::delete_samples(
const std::vector<std::string>& sample_names,
const std::vector<std::string>& tiledb_config) {
// Open dataset in read mode, required before calling `sample_exists`.
open(uri);
if (!open_) {
open(uri, tiledb_config);
}

// Define a function that deletes a sample from an array
auto delete_sample = [&](Array& array, const std::string& sample) {
Expand Down
5 changes: 5 additions & 0 deletions libtiledbvcf/src/write/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1484,5 +1484,10 @@ void Writer::set_variant_stats_array_version(uint8_t version) {
creation_params_.variant_stats_array_version = version;
}

void Writer::delete_samples(std::vector<std::string> samples) {
  // The dataset URI and the TileDB config both come from the ingestion
  // parameters already stored on this writer.
  auto& dataset_uri = ingestion_params_.uri;
  auto& config = ingestion_params_.tiledb_config;
  dataset_->delete_samples(dataset_uri, samples, config);
}

} // namespace vcf
} // namespace tiledb
5 changes: 5 additions & 0 deletions libtiledbvcf/src/write/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,11 @@ class Writer {
/** Set variant stats array version */
void set_variant_stats_array_version(uint8_t version);

/**
 * @brief Delete samples from the writer's dataset.
 *
 * @param samples Names of the samples to delete; forwarded to the dataset's
 * `delete_samples` together with the URI and TileDB config held in the
 * writer's ingestion params.
 */
void delete_samples(std::vector<std::string> samples);

private:
/* ********************************* */
/* PRIVATE ATTRIBUTES */
Expand Down
Loading