From 919a2fa1ff87295b33901ddcd57a46a1b806e311 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 15 Feb 2024 16:27:56 +0000 Subject: [PATCH 1/5] Initial package --- bio2zarr/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 bio2zarr/__init__.py diff --git a/bio2zarr/__init__.py b/bio2zarr/__init__.py new file mode 100644 index 0000000..e69de29 From 9b23bd0ebe77db3816cf2156caddecc309d29c96 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 15 Feb 2024 16:30:53 +0000 Subject: [PATCH 2/5] Initial copy from sgkit PR https://github.com/pystatgen/sgkit/pull/1185 --- bio2zarr/vcf.py | 1735 +++++++++++++++++++++++++++++++++++++++++ tests/test_vcf.py | 1891 +++++++++++++++++++++++++++++++++++++++++++++ vcf2zarr.py | 128 +++ 3 files changed, 3754 insertions(+) create mode 100644 bio2zarr/vcf.py create mode 100644 tests/test_vcf.py create mode 100644 vcf2zarr.py diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py new file mode 100644 index 0000000..906a4c3 --- /dev/null +++ b/bio2zarr/vcf.py @@ -0,0 +1,1735 @@ +import concurrent.futures as cf +import dataclasses +import multiprocessing +import functools +import threading +import pathlib +import time +import pickle +import sys +import shutil +import json +import math +import tempfile +from typing import Any + +import humanize +import cyvcf2 +import numcodecs +import numpy as np +import numpy.testing as nt +import tqdm +import zarr + +import bed_reader + + +# from sgkit.io.utils import FLOAT32_MISSING, str_is_int +from sgkit.io.utils import ( + # CHAR_FILL, + # CHAR_MISSING, + FLOAT32_FILL, + FLOAT32_MISSING, + FLOAT32_FILL_AS_INT32, + FLOAT32_MISSING_AS_INT32, + INT_FILL, + INT_MISSING, + # STR_FILL, + # STR_MISSING, + # str_is_int, +) + +# from sgkit.io.vcf import partition_into_regions + +# from sgkit.io.utils import INT_FILL, concatenate_and_rechunk, str_is_int +# from sgkit.utils import smallest_numpy_int_dtype + +numcodecs.blosc.use_threads = False + +default_compressor = numcodecs.Blosc( + cname="zstd", clevel=7, shuffle=numcodecs.Blosc.AUTOSHUFFLE +) + + +def assert_all_missing_float(a): + v = np.array(a, dtype=np.float32).view(np.int32) + assert np.all(v == FLOAT32_MISSING_AS_INT32) + + +def assert_prefix_integer_equal_1d(vcf_val, zarr_val): + v = np.array(vcf_val, dtype=np.int32, ndmin=1) + z = np.array(zarr_val, dtype=np.int32, ndmin=1) + v[v == VCF_INT_MISSING] = -1 + v[v == VCF_INT_FILL] = -2 + k = v.shape[0] + assert np.all(z[k:] == -2) + nt.assert_array_equal(v, z[:k]) + + +def assert_prefix_integer_equal_2d(vcf_val, zarr_val): + assert len(vcf_val.shape) == 2 + vcf_val[vcf_val == VCF_INT_MISSING] = -1 + vcf_val[vcf_val == VCF_INT_FILL] = -2 + if vcf_val.shape[1] == 1: + nt.assert_array_equal(vcf_val[:, 0], zarr_val) + else: + k = vcf_val.shape[1] + nt.assert_array_equal(vcf_val, zarr_val[:, :k]) + assert np.all(zarr_val[:, k:] == -2) + + +# FIXME these are sort-of working. It's not clear that we're +# handling the different dimensions and padded etc correctly. 
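+# The intended semantics are that the VCF value should equal a prefix of
+# the Zarr value, with everything after that prefix set to the fill value.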
+# Will need to hand-craft from examples to test +def assert_prefix_float_equal_1d(vcf_val, zarr_val): + v = np.array(vcf_val, dtype=np.float32, ndmin=1) + vi = v.view(np.int32) + z = np.array(zarr_val, dtype=np.float32, ndmin=1) + zi = z.view(np.int32) + assert np.sum(zi == FLOAT32_MISSING_AS_INT32) == 0 + k = v.shape[0] + assert np.all(zi[k:] == FLOAT32_FILL_AS_INT32) + # assert np.where(zi[:k] == FLOAT32_FILL_AS_INT32) + nt.assert_array_almost_equal(v, z[:k]) + # nt.assert_array_equal(v, z[:k]) + + +def assert_prefix_float_equal_2d(vcf_val, zarr_val): + assert len(vcf_val.shape) == 2 + if vcf_val.shape[1] == 1: + vcf_val = vcf_val[:, 0] + v = np.array(vcf_val, dtype=np.float32, ndmin=2) + vi = v.view(np.int32) + z = np.array(zarr_val, dtype=np.float32, ndmin=2) + zi = z.view(np.int32) + assert np.all((zi == FLOAT32_MISSING_AS_INT32) == (vi == FLOAT32_MISSING_AS_INT32)) + assert np.all((zi == FLOAT32_FILL_AS_INT32) == (vi == FLOAT32_FILL_AS_INT32)) + # print(vcf_val, zarr_val) + # assert np.sum(zi == FLOAT32_MISSING_AS_INT32) == 0 + k = v.shape[0] + # print("k", k) + assert np.all(zi[k:] == FLOAT32_FILL_AS_INT32) + # assert np.where(zi[:k] == FLOAT32_FILL_AS_INT32) + nt.assert_array_almost_equal(v, z[:k]) + # nt.assert_array_equal(v, z[:k]) + + +# TODO rename to wait_and_check_futures +def flush_futures(futures): + # Make sure previous futures have completed + for future in cf.as_completed(futures): + exception = future.exception() + if exception is not None: + raise exception + + +@dataclasses.dataclass +class VcfFieldSummary: + num_chunks: int = 0 + compressed_size: int = 0 + uncompressed_size: int = 0 + max_number: int = 0 # Corresponds to VCF Number field, depends on context + # Only defined for numeric fields + max_value: Any = -math.inf + min_value: Any = math.inf + + def update(self, other): + self.num_chunks += other.num_chunks + self.compressed_size += other.compressed_size + self.uncompressed_size = other.uncompressed_size + self.max_number = max(self.max_number, other.max_number) + self.min_value = min(self.min_value, other.min_value) + self.max_value = max(self.max_value, other.max_value) + + def asdict(self): + return dataclasses.asdict(self) + + +@dataclasses.dataclass +class VcfField: + category: str + name: str + vcf_number: str + vcf_type: str + description: str + summary: VcfFieldSummary + + @staticmethod + def from_header(definition): + category = definition["HeaderType"] + name = definition["ID"] + vcf_number = definition["Number"] + vcf_type = definition["Type"] + return VcfField( + category=category, + name=name, + vcf_number=vcf_number, + vcf_type=vcf_type, + description=definition["Description"].strip('"'), + summary=VcfFieldSummary(), + ) + + @staticmethod + def fromdict(d): + f = VcfField(**d) + f.summary = VcfFieldSummary(**d["summary"]) + return f + + @property + def full_name(self): + if self.category == "fixed": + return self.name + return f"{self.category}/{self.name}" + + # TODO add method here to choose a good set compressor and + # filters default here for this field. + + def smallest_dtype(self): + """ + Returns the smallest dtype suitable for this field based + on type, and values. 
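+        Integer fields are narrowed to the smallest of i1, i2 or i4 that
+        spans the observed value range (for example, values in [0, 200]
+        do not fit in i1, whose maximum is 127, so i2 is chosen); Float
+        maps to f4, Flag to bool and String to str.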
+ """ + s = self.summary + if self.vcf_type == "Float": + ret = "f4" + elif self.vcf_type == "Integer": + dtype = "i4" + for a_dtype in ["i1", "i2"]: + info = np.iinfo(a_dtype) + if info.min <= s.min_value and s.max_value <= info.max: + dtype = a_dtype + break + ret = dtype + elif self.vcf_type == "Flag": + ret = "bool" + else: + assert self.vcf_type == "String" + ret = "str" + # if s.max_number == 0: + # ret = "str" + # else: + # ret = "O" + # print("smallest dtype", self.name, self.vcf_type,":", ret) + return ret + + +@dataclasses.dataclass +class VcfPartition: + vcf_path: str + num_records: int + first_position: int + + +@dataclasses.dataclass +class VcfMetadata: + samples: list + contig_names: list + filters: list + fields: list + contig_lengths: list = None + partitions: list = None + + @property + def fixed_fields(self): + return [field for field in self.fields if field.category == "fixed"] + + @property + def info_fields(self): + return [field for field in self.fields if field.category == "INFO"] + + @property + def format_fields(self): + return [field for field in self.fields if field.category == "FORMAT"] + + @staticmethod + def fromdict(d): + fields = [VcfField.fromdict(fd) for fd in d["fields"]] + partitions = [VcfPartition(**pd) for pd in d["partitions"]] + d = d.copy() + d["fields"] = fields + d["partitions"] = partitions + return VcfMetadata(**d) + + def asdict(self): + return dataclasses.asdict(self) + + +def fixed_vcf_field_definitions(): + def make_field_def(name, vcf_type, vcf_number): + return VcfField( + category="fixed", + name=name, + vcf_type=vcf_type, + vcf_number=vcf_number, + description="", + summary=VcfFieldSummary(), + ) + + fields = [ + make_field_def("CHROM", "String", "1"), + make_field_def("POS", "Integer", "1"), + make_field_def("QUAL", "Float", "1"), + make_field_def("ID", "String", "."), + make_field_def("FILTERS", "String", "."), + make_field_def("REF", "String", "1"), + make_field_def("ALT", "String", "."), + ] + return fields + + +def scan_vcfs(paths, show_progress): + partitions = [] + vcf_metadata = None + for path in tqdm.tqdm(paths, desc="Scan ", disable=not show_progress): + vcf = cyvcf2.VCF(path) + + filters = [ + h["ID"] + for h in vcf.header_iter() + if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str) + ] + # Ensure PASS is the first filter if present + if "PASS" in filters: + filters.remove("PASS") + filters.insert(0, "PASS") + + fields = fixed_vcf_field_definitions() + for h in vcf.header_iter(): + if h["HeaderType"] in ["INFO", "FORMAT"]: + field = VcfField.from_header(h) + if field.name == "GT": + field.vcf_type = "Integer" + field.vcf_number = "." 
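+                    # GT is declared as String in the header, but is
+                    # stored here as the integer allele indexes parsed by
+                    # cyvcf2; Number "." presumably lets the maximum
+                    # ploidy be measured from the data.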
+ fields.append(field) + + metadata = VcfMetadata( + samples=vcf.samples, + contig_names=vcf.seqnames, + filters=filters, + fields=fields, + ) + try: + metadata.contig_lengths = vcf.seqlens + except AttributeError: + pass + + if vcf_metadata is None: + vcf_metadata = metadata + else: + if metadata != vcf_metadata: + raise ValueError("Incompatible VCF chunks") + record = next(vcf) + + partitions.append( + # Requires cyvcf2>=0.30.27 + VcfPartition( + vcf_path=str(path), + num_records=vcf.num_records, + first_position=(record.CHROM, record.POS), + ) + ) + partitions.sort(key=lambda x: x.first_position) + vcf_metadata.partitions = partitions + return vcf_metadata + + +def sanitise_value_bool(buff, j, value): + x = True + if value is None: + x = False + buff[j] = x + + +def sanitise_value_float_scalar(buff, j, value): + x = value + if value is None: + x = FLOAT32_MISSING + buff[j] = x + + +def sanitise_value_int_scalar(buff, j, value): + x = value + if value is None: + # print("MISSING", INT_MISSING, INT_FILL) + x = [INT_MISSING] + else: + x = sanitise_int_array([value], ndmin=1, dtype=np.int32) + buff[j] = x[0] + + +def sanitise_value_string_scalar(buff, j, value): + x = value + if value is None: + x = "." + # TODO check for missing values as well + buff[j] = x + + +def sanitise_value_string_1d(buff, j, value): + if value is None: + buff[j] = "." + else: + value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False) + value = drop_empty_second_dim(value) + buff[j] = "" + # TODO check for missing? + buff[j, : value.shape[0]] = value + + +def sanitise_value_string_2d(buff, j, value): + if value is None: + buff[j] = "." + else: + value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False) + value = drop_empty_second_dim(value) + buff[j] = "" + # TODO check for missing? + buff[j, : value.shape[0]] = value + + +def drop_empty_second_dim(value): + assert len(value.shape) == 1 or value.shape[1] == 1 + if len(value.shape) == 2 and value.shape[1] == 1: + value = value[..., 0] + return value + + +def sanitise_value_float_1d(buff, j, value): + if value is None: + buff[j] = FLOAT32_MISSING + else: + value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False) + value = drop_empty_second_dim(value) + buff[j] = FLOAT32_FILL + # TODO check for missing? + buff[j, : value.shape[0]] = value + + +def sanitise_value_float_2d(buff, j, value): + if value is None: + buff[j] = FLOAT32_MISSING + else: + value = np.array(value, dtype=buff.dtype, copy=False) + buff[j] = FLOAT32_FILL + # TODO check for missing? + buff[j, :, : value.shape[0]] = value + + +def sanitise_int_array(value, ndmin, dtype): + value = np.array(value, ndmin=ndmin, copy=False) + value[value == VCF_INT_MISSING] = -1 + value[value == VCF_INT_FILL] = -2 + # TODO watch out for clipping here! 
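+    # astype can silently wrap values that do not fit in the target dtype.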
+ return value.astype(dtype) + + +def sanitise_value_int_1d(buff, j, value): + if value is None: + buff[j] = -1 + else: + value = sanitise_int_array(value, 1, buff.dtype) + value = drop_empty_second_dim(value) + buff[j] = -2 + buff[j, : value.shape[0]] = value + + +def sanitise_value_int_2d(buff, j, value): + if value is None: + buff[j] = -1 + else: + value = sanitise_int_array(value, 2, buff.dtype) + buff[j] = -2 + buff[j, :, : value.shape[1]] = value + + +class PickleChunkedVcfField: + def __init__(self, vcf_field, base_path): + self.vcf_field = vcf_field + if vcf_field.category == "fixed": + self.path = base_path / vcf_field.name + else: + self.path = base_path / vcf_field.category / vcf_field.name + + self.compressor = numcodecs.Blosc(cname="zstd", clevel=7) + # TODO have a clearer way of defining this state between + # read and write mode. + self.num_partitions = None + self.num_records = None + self.partition_num_chunks = {} + + def num_chunks(self, partition_index): + if partition_index not in self.partition_num_chunks: + partition_path = self.path / f"p{partition_index}" + n = len(list(partition_path.iterdir())) + self.partition_num_chunks[partition_index] = n + return self.partition_num_chunks[partition_index] + + def __repr__(self): + # TODO add class name + return repr({"path": str(self.path), **self.vcf_field.summary.asdict()}) + + def write_chunk(self, partition_index, chunk_index, data): + path = self.path / f"p{partition_index}" / f"c{chunk_index}" + pkl = pickle.dumps(data) + # NOTE assuming that reusing the same compressor instance + # from multiple threads is OK! + compressed = self.compressor.encode(pkl) + with open(path, "wb") as f: + f.write(compressed) + + # Update the summary + self.vcf_field.summary.num_chunks += 1 + self.vcf_field.summary.compressed_size += len(compressed) + self.vcf_field.summary.uncompressed_size += len(pkl) + + def read_chunk(self, partition_index, chunk_index): + path = self.path / f"p{partition_index}" / f"c{chunk_index}" + with open(path, "rb") as f: + pkl = self.compressor.decode(f.read()) + return pickle.loads(pkl), len(pkl) + + def iter_values_bytes(self): + num_records = 0 + bytes_read = 0 + for partition_index in range(self.num_partitions): + for chunk_index in range(self.num_chunks(partition_index)): + chunk, chunk_bytes = self.read_chunk(partition_index, chunk_index) + bytes_read += chunk_bytes + for record in chunk: + yield record, bytes_read + num_records += 1 + if num_records != self.num_records: + raise ValueError( + f"Corruption detected: incorrect number of records in {str(self.path)}." + ) + + def values(self): + return [record for record, _ in self.iter_values_bytes()] + + def sanitiser_factory(self, shape): + """ + Return a function that sanitised values from this column + and writes into a buffer of the specified shape. 
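+        The dimensionality of the shape (1, 2 or 3) selects the scalar,
+        1D or 2D sanitiser for the field's VCF type.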
+ """ + assert len(shape) <= 3 + if self.vcf_field.vcf_type == "Flag": + assert len(shape) == 1 + return sanitise_value_bool + elif self.vcf_field.vcf_type == "Float": + if len(shape) == 1: + return sanitise_value_float_scalar + elif len(shape) == 2: + return sanitise_value_float_1d + else: + return sanitise_value_float_2d + elif self.vcf_field.vcf_type == "Integer": + if len(shape) == 1: + return sanitise_value_int_scalar + elif len(shape) == 2: + return sanitise_value_int_1d + else: + return sanitise_value_int_2d + else: + assert self.vcf_field.vcf_type == "String" + if len(shape) == 1: + return sanitise_value_string_scalar + elif len(shape) == 2: + return sanitise_value_string_1d + else: + return sanitise_value_string_2d + + +def update_bounds_float(summary, value, number_dim): + value = np.array(value, dtype=np.float32, copy=False) + # Map back to python types to avoid JSON issues later. Could + # be done more efficiently at the end. + summary.max_value = float(max(summary.max_value, np.max(value))) + summary.min_value = float(min(summary.min_value, np.min(value))) + number = 0 + assert len(value.shape) <= number_dim + 1 + if len(value.shape) == number_dim + 1: + number = value.shape[number_dim] + summary.max_number = max(summary.max_number, number) + + +MIN_INT_VALUE = np.iinfo(np.int32).min + 2 +VCF_INT_MISSING = np.iinfo(np.int32).min +VCF_INT_FILL = np.iinfo(np.int32).min + 1 + + +def update_bounds_integer(summary, value, number_dim): + # print("update bounds int", summary, value) + value = np.array(value, dtype=np.int32, copy=False) + # Mask out missing and fill values + a = value[value >= MIN_INT_VALUE] + summary.max_value = int(max(summary.max_value, np.max(a))) + summary.min_value = int(min(summary.min_value, np.min(a))) + number = 0 + assert len(value.shape) <= number_dim + 1 + if len(value.shape) == number_dim + 1: + number = value.shape[number_dim] + summary.max_number = max(summary.max_number, number) + + +def update_bounds_string(summary, value): + if isinstance(value, str): + number = 0 + else: + number = len(value) + summary.max_number = max(summary.max_number, number) + + +class PickleChunkedWriteBuffer: + def __init__(self, column, partition_index, executor, futures, chunk_size=1): + self.column = column + self.buffer = [] + self.buffered_bytes = 0 + # chunk_size is in megabytes + self.max_buffered_bytes = chunk_size * 2**20 + assert self.max_buffered_bytes > 0 + self.partition_index = partition_index + self.chunk_index = 0 + self.executor = executor + self.futures = futures + self._summary_bounds_update = None + vcf_type = column.vcf_field.vcf_type + number_dim = 0 + if column.vcf_field.category == "FORMAT": + number_dim = 1 + if vcf_type == "Float": + self._summary_bounds_update = functools.partial( + update_bounds_float, number_dim=number_dim + ) + elif vcf_type == "Integer": + self._summary_bounds_update = functools.partial( + update_bounds_integer, number_dim=number_dim + ) + elif vcf_type == "String": + self._summary_bounds_update = update_bounds_string + + def _update_bounds(self, value): + if value is not None: + summary = self.column.vcf_field.summary + # print("update", self.column.vcf_field.full_name, value) + if self._summary_bounds_update is not None: + self._summary_bounds_update(summary, value) + + def append(self, val): + self.buffer.append(val) + self._update_bounds(val) + val_bytes = sys.getsizeof(val) + self.buffered_bytes += val_bytes + if self.buffered_bytes >= self.max_buffered_bytes: + self.flush() + + def flush(self): + if len(self.buffer) > 0: 
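+            # Hand the full buffer to a worker thread for pickling,
+            # compression and writing; a fresh list is allocated below so
+            # that we never mutate a buffer that is being written out.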
+ future = self.executor.submit( + self.column.write_chunk, + self.partition_index, + self.chunk_index, + self.buffer, + ) + self.futures.add(future) + + self.chunk_index += 1 + self.buffer = [] + self.buffered_bytes = 0 + + +class PickleChunkedVcf: + def __init__(self, path, metadata): + self.path = path + self.metadata = metadata + + self.columns = {} + for field in self.metadata.fields: + self.columns[field.full_name] = PickleChunkedVcfField(field, path) + + for col in self.columns.values(): + col.num_partitions = self.num_partitions + col.num_records = self.num_records + + def summary_table(self): + def display_number(x): + ret = "n/a" + if math.isfinite(x): + ret = f"{x: 0.2g}" + return ret + + def display_size(n): + return humanize.naturalsize(n, binary=True) + + data = [] + for name, col in self.columns.items(): + summary = col.vcf_field.summary + d = { + "name": name, + "type": col.vcf_field.vcf_type, + "chunks": summary.num_chunks, + "size": display_size(summary.uncompressed_size), + "compressed": display_size(summary.compressed_size), + "max_n": summary.max_number, + "min_val": display_number(summary.min_value), + "max_val": display_number(summary.max_value), + } + + data.append(d) + return data + + @functools.cached_property + def total_uncompressed_bytes(self): + total = 0 + for col in self.columns.values(): + summary = col.vcf_field.summary + total += summary.uncompressed_size + return total + + @functools.cached_property + def num_records(self): + return sum(partition.num_records for partition in self.metadata.partitions) + + @property + def num_partitions(self): + return len(self.metadata.partitions) + + @property + def num_samples(self): + return len(self.metadata.samples) + + def mkdirs(self): + self.path.mkdir() + for col in self.columns.values(): + col.path.mkdir(parents=True) + for j in range(self.num_partitions): + part_path = col.path / f"p{j}" + part_path.mkdir() + + @staticmethod + def load(path): + path = pathlib.Path(path) + with open(path / "metadata.json") as f: + metadata = VcfMetadata.fromdict(json.load(f)) + return PickleChunkedVcf(path, metadata) + + @staticmethod + def convert( + vcfs, out_path, *, column_chunk_size=16, worker_processes=1, show_progress=False + ): + out_path = pathlib.Path(out_path) + vcf_metadata = scan_vcfs(vcfs, show_progress=show_progress) + pcvcf = PickleChunkedVcf(out_path, vcf_metadata) + pcvcf.mkdirs() + + total_variants = sum( + partition.num_records for partition in vcf_metadata.partitions + ) + + global progress_counter + progress_counter = multiprocessing.Value("Q", 0) + + # start update progress bar process + bar_thread = None + if show_progress: + bar_thread = threading.Thread( + target=update_bar, + args=(progress_counter, total_variants, "Explode", "vars"), + name="progress", + daemon=True, + ) + bar_thread.start() + + with cf.ProcessPoolExecutor( + max_workers=worker_processes, + initializer=init_workers, + initargs=(progress_counter,), + ) as executor: + futures = [] + for j, partition in enumerate(vcf_metadata.partitions): + futures.append( + executor.submit( + PickleChunkedVcf.convert_partition, + vcf_metadata, + j, + out_path, + column_chunk_size=column_chunk_size, + ) + ) + partition_summaries = [ + future.result() for future in cf.as_completed(futures) + ] + + assert progress_counter.value == total_variants + if bar_thread is not None: + bar_thread.join() + + for field in vcf_metadata.fields: + for summary in partition_summaries: + field.summary.update(summary[field.full_name]) + + with open(out_path / 
"metadata.json", "w") as f: + json.dump(vcf_metadata.asdict(), f, indent=4) + return pcvcf + + @staticmethod + def convert_partition( + vcf_metadata, + partition_index, + out_path, + *, + flush_threads=4, + column_chunk_size=16, + ): + partition = vcf_metadata.partitions[partition_index] + vcf = cyvcf2.VCF(partition.vcf_path) + futures = set() + + def service_futures(max_waiting=2 * flush_threads): + while len(futures) > max_waiting: + futures_done, _ = cf.wait(futures, return_when=cf.FIRST_COMPLETED) + for future in futures_done: + exception = future.exception() + if exception is not None: + raise exception + futures.remove(future) + + with cf.ThreadPoolExecutor(max_workers=flush_threads) as executor: + columns = {} + summaries = {} + info_fields = [] + format_fields = [] + for field in vcf_metadata.fields: + column = PickleChunkedVcfField(field, out_path) + write_buffer = PickleChunkedWriteBuffer( + column, partition_index, executor, futures, column_chunk_size + ) + columns[field.full_name] = write_buffer + summaries[field.full_name] = field.summary + + if field.category == "INFO": + info_fields.append((field.name, write_buffer)) + elif field.category == "FORMAT": + if field.name != "GT": + format_fields.append((field.name, write_buffer)) + + contig = columns["CHROM"] + pos = columns["POS"] + qual = columns["QUAL"] + vid = columns["ID"] + filters = columns["FILTERS"] + ref = columns["REF"] + alt = columns["ALT"] + gt = columns.get("FORMAT/GT", None) + + for variant in vcf: + contig.append(variant.CHROM) + pos.append(variant.POS) + qual.append(variant.QUAL) + vid.append(variant.ID) + filters.append(variant.FILTERS) + ref.append(variant.REF) + alt.append(variant.ALT) + if gt is not None: + gt.append(variant.genotype.array()) + + for name, buff in info_fields: + val = None + try: + val = variant.INFO[name] + except KeyError: + pass + buff.append(val) + + for name, buff in format_fields: + val = None + try: + val = variant.format(name) + except KeyError: + pass + buff.append(val) + + service_futures() + + with progress_counter.get_lock(): + progress_counter.value += 1 + + for col in columns.values(): + col.flush() + service_futures(0) + + return summaries + + +def update_bar(progress_counter, total, title, units): + pbar = tqdm.tqdm( + total=total, desc=title, unit_scale=True, unit=units, smoothing=0.1 + ) + + while (current := progress_counter.value) < total: + inc = current - pbar.n + pbar.update(inc) + time.sleep(0.1) + pbar.close() + + +def init_workers(counter): + global progress_counter + progress_counter = counter + + +def explode( + vcfs, + out_path, + *, + column_chunk_size=16, + worker_processes=1, + show_progress=False, +): + out_path = pathlib.Path(out_path) + if out_path.exists(): + shutil.rmtree(out_path) + + return PickleChunkedVcf.convert( + vcfs, + out_path, + column_chunk_size=column_chunk_size, + worker_processes=worker_processes, + show_progress=show_progress, + ) + + +@dataclasses.dataclass +class ZarrColumnSpec: + # TODO change to "variable_name" + name: str + dtype: str + shape: tuple + chunks: tuple + dimensions: list + description: str + vcf_field: str + compressor: dict + # TODO add filters + + +@dataclasses.dataclass +class ZarrConversionSpec: + chunk_width: int + chunk_length: int + dimensions: list + sample_id: list + contig_id: list + contig_length: list + filter_id: list + variables: list + + def asdict(self): + return dataclasses.asdict(self) + + @staticmethod + def fromdict(d): + ret = ZarrConversionSpec(**d) + ret.variables = [ZarrColumnSpec(**cd) for cd in 
d["variables"]] + return ret + + @staticmethod + def generate(pcvcf, chunk_length=None, chunk_width=None): + m = pcvcf.num_records + n = pcvcf.num_samples + # FIXME + if chunk_width is None: + chunk_width = 1000 + if chunk_length is None: + chunk_length = 10_000 + + compressor = default_compressor.get_config() + + def fixed_field_spec( + name, dtype, vcf_field=None, shape=(m,), dimensions=("variants",) + ): + return ZarrColumnSpec( + vcf_field=vcf_field, + name=name, + dtype=dtype, + shape=shape, + description="", + dimensions=dimensions, + chunks=[chunk_length], + compressor=compressor, + ) + + alt_col = pcvcf.columns["ALT"] + max_alleles = alt_col.vcf_field.summary.max_number + 1 + num_filters = len(pcvcf.metadata.filters) + + # # FIXME get dtype from lookup table + colspecs = [ + fixed_field_spec( + name="variant_contig", + dtype="i2", # FIXME + ), + fixed_field_spec( + name="variant_filter", + dtype="bool", + shape=(m, num_filters), + dimensions=["variants", "filters"], + ), + fixed_field_spec( + name="variant_allele", + dtype="str", + shape=[m, max_alleles], + dimensions=["variants", "alleles"], + ), + fixed_field_spec( + vcf_field="POS", + name="variant_position", + dtype="i4", + ), + fixed_field_spec( + vcf_field=None, + name="variant_id", + dtype="str", + ), + fixed_field_spec( + vcf_field=None, + name="variant_id_mask", + dtype="bool", + ), + fixed_field_spec( + vcf_field="QUAL", + name="variant_quality", + dtype="f4", + ), + ] + + gt_field = None + for field in pcvcf.metadata.fields: + if field.category == "fixed": + continue + if field.name == "GT": + gt_field = field + continue + shape = [m] + prefix = "variant_" + dimensions = ["variants"] + chunks = [chunk_length] + if field.category == "FORMAT": + prefix = "call_" + shape.append(n) + chunks.append(chunk_width), + dimensions.append("samples") + if field.summary.max_number > 1: + shape.append(field.summary.max_number) + dimensions.append(field.name) + variable_name = prefix + field.name + colspec = ZarrColumnSpec( + vcf_field=field.full_name, + name=variable_name, + dtype=field.smallest_dtype(), + shape=shape, + chunks=chunks, + dimensions=dimensions, + description=field.description, + compressor=compressor, + ) + colspecs.append(colspec) + + if gt_field is not None: + ploidy = gt_field.summary.max_number - 1 + shape = [m, n] + chunks = [chunk_length, chunk_width] + dimensions = ["variants", "samples"] + + colspecs.append( + ZarrColumnSpec( + vcf_field=None, + name="call_genotype_phased", + dtype="bool", + shape=list(shape), + chunks=list(chunks), + dimensions=list(dimensions), + description="", + compressor=compressor, + ) + ) + shape += [ploidy] + dimensions += ["ploidy"] + colspecs.append( + ZarrColumnSpec( + vcf_field=None, + name="call_genotype", + dtype=gt_field.smallest_dtype(), + shape=list(shape), + chunks=list(chunks), + dimensions=list(dimensions), + description="", + compressor=compressor, + ) + ) + colspecs.append( + ZarrColumnSpec( + vcf_field=None, + name="call_genotype_mask", + dtype="bool", + shape=list(shape), + chunks=list(chunks), + dimensions=list(dimensions), + description="", + compressor=compressor, + ) + ) + + return ZarrConversionSpec( + chunk_width=chunk_width, + chunk_length=chunk_length, + variables=colspecs, + dimensions=["variants", "samples", "ploidy", "alleles", "filters"], + sample_id=pcvcf.metadata.samples, + contig_id=pcvcf.metadata.contig_names, + contig_length=pcvcf.metadata.contig_lengths, + filter_id=pcvcf.metadata.filters, + ) + + +@dataclasses.dataclass +class BufferedArray: + array: 
Any + buff: Any + + def __init__(self, array): + self.array = array + dims = list(array.shape) + dims[0] = min(array.chunks[0], array.shape[0]) + self.buff = np.zeros(dims, dtype=array.dtype) + + def swap_buffers(self): + self.buff = np.zeros_like(self.buff) + + +class SgvcfZarr: + def __init__(self, path): + self.path = pathlib.Path(path) + self.root = None + + def create_array(self, variable): + # print("CREATE", variable) + a = self.root.empty( + variable.name, + shape=variable.shape, + chunks=variable.chunks, + dtype=variable.dtype, + compressor=numcodecs.get_codec(variable.compressor), + ) + a.attrs["_ARRAY_DIMENSIONS"] = variable.dimensions + + def encode_column(self, pcvcf, column): + source_col = pcvcf.columns[column.vcf_field] + array = self.root[column.name] + ba = BufferedArray(array) + sanitiser = source_col.sanitiser_factory(ba.buff.shape) + chunk_length = array.chunks[0] + + with cf.ThreadPoolExecutor(max_workers=4) as executor: + futures = [] + chunk_start = 0 + j = 0 + last_bytes_read = 0 + for value, bytes_read in source_col.iter_values_bytes(): + sanitiser(ba.buff, j, value) + j += 1 + if j == chunk_length: + flush_futures(futures) + futures.extend( + async_flush_array(executor, ba.buff, ba.array, chunk_start) + ) + ba.swap_buffers() + j = 0 + chunk_start += chunk_length + if last_bytes_read != bytes_read: + with progress_counter.get_lock(): + progress_counter.value += bytes_read - last_bytes_read + last_bytes_read = bytes_read + + if j != 0: + flush_futures(futures) + futures.extend( + async_flush_array(executor, ba.buff[:j], ba.array, chunk_start) + ) + flush_futures(futures) + + def encode_genotypes(self, pcvcf): + source_col = pcvcf.columns["FORMAT/GT"] + gt = BufferedArray(self.root["call_genotype"]) + gt_mask = BufferedArray(self.root["call_genotype_mask"]) + gt_phased = BufferedArray(self.root["call_genotype_phased"]) + chunk_length = gt.array.chunks[0] + + buffered_arrays = [gt, gt_phased, gt_mask] + + with cf.ThreadPoolExecutor(max_workers=4) as executor: + futures = [] + chunk_start = 0 + j = 0 + last_bytes_read = 0 + for value, bytes_read in source_col.iter_values_bytes(): + sanitise_value_int_2d(gt.buff, j, value[:, :-1]) + sanitise_value_int_1d(gt_phased.buff, j, value[:, -1]) + # TODO check is this the correct semantics when we are padding + # with mixed ploidies? + gt_mask.buff[j] = gt.buff[j] < 0 + + j += 1 + if j == chunk_length: + flush_futures(futures) + for ba in buffered_arrays: + futures.extend( + async_flush_array(executor, ba.buff, ba.array, chunk_start) + ) + ba.swap_buffers() + j = 0 + chunk_start += chunk_length + if last_bytes_read != bytes_read: + with progress_counter.get_lock(): + progress_counter.value += bytes_read - last_bytes_read + last_bytes_read = bytes_read + + if j != 0: + flush_futures(futures) + for ba in buffered_arrays: + futures.extend( + async_flush_array(executor, ba.buff[:j], ba.array, chunk_start) + ) + flush_futures(futures) + + def encode_alleles(self, pcvcf): + ref_col = pcvcf.columns["REF"] + alt_col = pcvcf.columns["ALT"] + ref_values = ref_col.values() + alt_values = alt_col.values() + allele_array = self.root["variant_allele"] + + # We could do this chunk-by-chunk, but it doesn't seem worth the bother. 
+ alleles = np.full(allele_array.shape, "", dtype="O") + for j, (ref, alt) in enumerate(zip(ref_values, alt_values)): + alleles[j, 0] = ref + alleles[j, 1 : 1 + len(alt)] = alt + allele_array[:] = alleles + + with progress_counter.get_lock(): + for col in [ref_col, alt_col]: + progress_counter.value += col.vcf_field.summary.uncompressed_size + + def encode_samples(self, pcvcf, sample_id, chunk_width): + if not np.array_equal(sample_id, pcvcf.metadata.samples): + raise ValueError("Subsetting or reordering samples not supported currently") + array = self.root.array( + "sample_id", + sample_id, + dtype="str", + compressor=default_compressor, + chunks=(chunk_width,), + ) + array.attrs["_ARRAY_DIMENSIONS"] = ["samples"] + + def encode_contig(self, pcvcf, contig_names, contig_lengths): + array = self.root.array( + "contig_id", + contig_names, + dtype="str", + compressor=default_compressor, + ) + array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"] + + if contig_lengths is not None: + array = self.root.array( + "contig_length", + contig_lengths, + dtype=np.int64, + ) + array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"] + + col = pcvcf.columns["CHROM"] + array = self.root["variant_contig"] + buff = np.zeros_like(array) + lookup = {v: j for j, v in enumerate(contig_names)} + for j, contig in enumerate(col.values()): + try: + buff[j] = lookup[contig] + except KeyError: + # TODO add advice about adding it to the spec + raise ValueError(f"Contig '{contig}' was not defined in the header.") + + array[:] = buff + + with progress_counter.get_lock(): + progress_counter.value += col.vcf_field.summary.uncompressed_size + + def encode_filters(self, pcvcf, filter_names): + self.root.attrs["filters"] = filter_names + array = self.root.array( + "filter_id", + filter_names, + dtype="str", + compressor=default_compressor, + ) + array.attrs["_ARRAY_DIMENSIONS"] = ["filters"] + + col = pcvcf.columns["FILTERS"] + array = self.root["variant_filter"] + buff = np.zeros_like(array) + + lookup = {v: j for j, v in enumerate(filter_names)} + for j, filters in enumerate(col.values()): + try: + for f in filters: + buff[j, lookup[f]] = True + except IndexError: + raise ValueError(f"Filter '{f}' was not defined in the header.") + + array[:] = buff + + with progress_counter.get_lock(): + progress_counter.value += col.vcf_field.summary.uncompressed_size + + def encode_id(self, pcvcf): + col = pcvcf.columns["ID"] + id_array = self.root["variant_id"] + id_mask_array = self.root["variant_id_mask"] + id_buff = np.full_like(id_array, "") + id_mask_buff = np.zeros_like(id_mask_array) + + for j, value in enumerate(col.values()): + if value is not None: + id_buff[j] = value + else: + id_buff[j] = "." # TODO is this correct?? 
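+                # "." is the VCF convention for a missing ID; the mask
+                # records which entries were missing.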
+ id_mask_buff[j] = True + + id_array[:] = id_buff + id_mask_array[:] = id_mask_buff + + with progress_counter.get_lock(): + progress_counter.value += col.vcf_field.summary.uncompressed_size + + @staticmethod + def convert( + pcvcf, path, conversion_spec, *, worker_processes=1, show_progress=False + ): + store = zarr.DirectoryStore(path) + # FIXME + sgvcf = SgvcfZarr(path) + sgvcf.root = zarr.group(store=store, overwrite=True) + for variable in conversion_spec.variables[:]: + sgvcf.create_array(variable) + + global progress_counter + progress_counter = multiprocessing.Value("Q", 0) + + # start update progress bar process + bar_thread = None + if show_progress: + bar_thread = threading.Thread( + target=update_bar, + args=(progress_counter, pcvcf.total_uncompressed_bytes, "Encode", "b"), + name="progress", + daemon=True, + ) + bar_thread.start() + + with cf.ProcessPoolExecutor( + max_workers=worker_processes, + initializer=init_workers, + initargs=(progress_counter,), + ) as executor: + futures = [ + executor.submit( + sgvcf.encode_samples, + pcvcf, + conversion_spec.sample_id, + conversion_spec.chunk_width, + ), + executor.submit(sgvcf.encode_alleles, pcvcf), + executor.submit(sgvcf.encode_id, pcvcf), + executor.submit( + sgvcf.encode_contig, + pcvcf, + conversion_spec.contig_id, + conversion_spec.contig_length, + ), + executor.submit(sgvcf.encode_filters, pcvcf, conversion_spec.filter_id), + ] + has_gt = False + for variable in conversion_spec.variables[:]: + if variable.vcf_field is not None: + # print("Encode", variable.name) + # TODO for large columns it's probably worth splitting up + # these into vertical chunks. Otherwise we tend to get a + # long wait for the largest GT columns to finish. + # Straightforward to do because we can chunk-align the work + # packages. + future = executor.submit(sgvcf.encode_column, pcvcf, variable) + futures.append(future) + else: + if variable.name == "call_genotype": + has_gt = True + if has_gt: + # TODO add mixed ploidy + futures.append(executor.submit(sgvcf.encode_genotypes, pcvcf)) + + flush_futures(futures) + + zarr.consolidate_metadata(path) + # FIXME can't join the bar_thread because we never get to the correct + # number of bytes + # if bar_thread is not None: + # bar_thread.join() + + +def sync_flush_array(np_buffer, zarr_array, offset): + zarr_array[offset : offset + np_buffer.shape[0]] = np_buffer + + +def async_flush_array(executor, np_buffer, zarr_array, offset): + """ + Flush the specified chunk aligned buffer to the specified zarr array. 
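+    Returns a list of futures: a single future for a 1D array, or one
+    per chunk in the second dimension for a 2D array.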
+ """ + assert zarr_array.shape[1:] == np_buffer.shape[1:] + # print("sync", zarr_array, np_buffer) + + if len(np_buffer.shape) == 1: + futures = [executor.submit(sync_flush_array, np_buffer, zarr_array, offset)] + else: + futures = async_flush_2d_array(executor, np_buffer, zarr_array, offset) + return futures + + +def async_flush_2d_array(executor, np_buffer, zarr_array, offset): + # Flush each of the chunks in the second dimension separately + s = slice(offset, offset + np_buffer.shape[0]) + + def flush_chunk(start, stop): + zarr_array[s, start:stop] = np_buffer[:, start:stop] + + chunk_width = zarr_array.chunks[1] + zarr_array_width = zarr_array.shape[1] + start = 0 + futures = [] + while start < zarr_array_width: + stop = min(start + chunk_width, zarr_array_width) + future = executor.submit(flush_chunk, start, stop) + futures.append(future) + start = stop + + return futures + + +def convert_vcf( + vcfs, + out_path, + *, + chunk_length=None, + chunk_width=None, + worker_processes=1, + show_progress=False, +): + with tempfile.TemporaryDirectory() as intermediate_form_dir: + explode( + vcfs, + intermediate_form_dir, + worker_processes=worker_processes, + show_progress=show_progress, + ) + + pcvcf = PickleChunkedVcf.load(intermediate_form_dir) + spec = ZarrConversionSpec.generate( + pcvcf, chunk_length=chunk_length, chunk_width=chunk_width + ) + SgvcfZarr.convert( + pcvcf, + out_path, + conversion_spec=spec, + worker_processes=worker_processes, + show_progress=show_progress, + ) + + +def encode_bed_partition_genotypes(bed_path, zarr_path, start_variant, end_variant): + bed = bed_reader.open_bed(bed_path, num_threads=1) + + store = zarr.DirectoryStore(zarr_path) + root = zarr.group(store=store) + gt = BufferedArray(root["call_genotype"]) + gt_mask = BufferedArray(root["call_genotype_mask"]) + gt_phased = BufferedArray(root["call_genotype_phased"]) + chunk_length = gt.array.chunks[0] + assert start_variant % chunk_length == 0 + + buffered_arrays = [gt, gt_phased, gt_mask] + + with cf.ThreadPoolExecutor(max_workers=8) as executor: + futures = [] + + start = start_variant + while start < end_variant: + stop = min(start + chunk_length, end_variant) + bed_chunk = bed.read(index=slice(start, stop), dtype="int8").T + # Note could do this without iterating over rows, but it's a bit + # simpler and the bottleneck is in the encoding step anyway. It's + # also nice to have updates on the progress monitor. 
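+            # bed_reader encodes each genotype as an allele count (0, 1
+            # or 2), with -127 for missing; map these onto diploid calls,
+            # using -1 for missing alleles.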
+ for j, values in enumerate(bed_chunk): + dest = gt.buff[j] + dest[values == -127] = -1 + dest[values == 2] = 1 + dest[values == 1, 0] = 1 + gt_phased.buff[j] = False + gt_mask.buff[j] = dest == -1 + with progress_counter.get_lock(): + progress_counter.value += 1 + + assert j <= chunk_length + flush_futures(futures) + for ba in buffered_arrays: + futures.extend( + async_flush_array(executor, ba.buff[:j], ba.array, start) + ) + ba.swap_buffers() + start = stop + flush_futures(futures) + + +def validate(vcf_path, zarr_path, show_progress): + store = zarr.DirectoryStore(zarr_path) + + root = zarr.group(store=store) + pos = root["variant_position"][:] + allele = root["variant_allele"][:] + chrom = root["contig_id"][:][root["variant_contig"][:]] + vid = root["variant_id"][:] + call_genotype = iter(root["call_genotype"]) + + vcf = cyvcf2.VCF(vcf_path) + format_headers = {} + info_headers = {} + for h in vcf.header_iter(): + if h["HeaderType"] == "FORMAT": + format_headers[h["ID"]] = h + if h["HeaderType"] == "INFO": + info_headers[h["ID"]] = h + + format_fields = {} + info_fields = {} + for colname in root.keys(): + if colname.startswith("call") and not colname.startswith("call_genotype"): + vcf_name = colname.split("_", 1)[1] + vcf_type = format_headers[vcf_name]["Type"] + format_fields[vcf_name] = vcf_type, iter(root[colname]) + if colname.startswith("variant"): + name = colname.split("_", 1)[1] + if name.isupper(): + vcf_type = info_headers[name]["Type"] + # print(root[colname]) + info_fields[name] = vcf_type, iter(root[colname]) + # print(info_fields) + + first_pos = next(vcf).POS + start_index = np.searchsorted(pos, first_pos) + assert pos[start_index] == first_pos + vcf = cyvcf2.VCF(vcf_path) + iterator = tqdm.tqdm(vcf, total=vcf.num_records) + for j, row in enumerate(iterator, start_index): + assert chrom[j] == row.CHROM + assert pos[j] == row.POS + assert vid[j] == ("." if row.ID is None else row.ID) + assert allele[j, 0] == row.REF + k = len(row.ALT) + nt.assert_array_equal(allele[j, 1 : k + 1], row.ALT), + assert np.all(allele[j, k + 1 :] == "") + # TODO FILTERS + + gt = row.genotype.array() + gt_zarr = next(call_genotype) + gt_vcf = gt[:, :-1] + # NOTE weirdly cyvcf2 seems to remap genotypes automatically + # into the same missing/pad encoding that sgkit uses. + # if np.any(gt_zarr < 0): + # print("MISSING") + # print(gt_zarr) + # print(gt_vcf) + nt.assert_array_equal(gt_zarr, gt_vcf) + + # TODO this is basically right, but the details about float padding + # need to be worked out in particular. Need to find examples of + # VCFs with Number=. Float fields. 
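+    # For each INFO field, a value absent from the VCF row must map to the
+    # missing sentinel for its type in the Zarr store; otherwise the VCF
+    # value must match a prefix of the Zarr value.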
+ for name, (vcf_type, zarr_iter) in info_fields.items(): + vcf_val = None + try: + vcf_val = row.INFO[name] + except KeyError: + pass + zarr_val = next(zarr_iter) + if vcf_val is None: + if vcf_type == "Integer": + assert np.all(zarr_val == -1) + elif vcf_type == "String": + assert np.all(zarr_val == ".") + elif vcf_type == "Flag": + assert zarr_val == False + elif vcf_type == "Float": + assert_all_missing_float(zarr_val) + else: + assert False + else: + # print(name, vcf_type, vcf_val, zarr_val, sep="\t") + if vcf_type == "Integer": + assert_prefix_integer_equal_1d(vcf_val, zarr_val) + elif vcf_type == "Float": + assert_prefix_float_equal_1d(vcf_val, zarr_val) + elif vcf_type == "Flag": + assert zarr_val == True + elif vcf_type == "String": + assert np.all(zarr_val == vcf_val) + else: + assert False + + for name, (vcf_type, zarr_iter) in format_fields.items(): + vcf_val = None + try: + vcf_val = row.format(name) + except KeyError: + pass + zarr_val = next(zarr_iter) + if vcf_val is None: + if vcf_type == "Integer": + assert np.all(zarr_val == -1) + elif vcf_type == "Float": + assert_all_missing_float(zarr_val) + elif vcf_type == "String": + assert np.all(zarr_val == ".") + else: + print("vcf_val", vcf_type, name, vcf_val) + assert False + else: + assert vcf_val.shape[0] == zarr_val.shape[0] + if vcf_type == "Integer": + assert_prefix_integer_equal_2d(vcf_val, zarr_val) + elif vcf_type == "Float": + assert_prefix_float_equal_2d(vcf_val, zarr_val) + elif vcf_type == "String": + nt.assert_array_equal(vcf_val, zarr_val) + + # assert_prefix_string_equal_2d(vcf_val, zarr_val) + else: + print(name) + print(vcf_val) + print(zarr_val) + assert False + + +def convert_plink( + bed_path, + zarr_path, + *, + show_progress, + worker_processes=1, + chunk_length=None, + chunk_width=None, +): + bed = bed_reader.open_bed(bed_path, num_threads=1) + n = bed.iid_count + m = bed.sid_count + del bed + + # FIXME + if chunk_width is None: + chunk_width = 1000 + if chunk_length is None: + chunk_length = 10_000 + + store = zarr.DirectoryStore(zarr_path) + root = zarr.group(store=store, overwrite=True) + + ploidy = 2 + shape = [m, n] + chunks = [chunk_length, chunk_width] + dimensions = ["variants", "samples"] + + a = root.empty( + "call_genotype_phased", + dtype="bool", + shape=list(shape), + chunks=list(chunks), + compressor=default_compressor, + ) + a.attrs["_ARRAY_DIMENSIONS"] = list(dimensions) + + shape += [ploidy] + dimensions += ["ploidy"] + a = root.empty( + "call_genotype", + dtype="i8", + shape=list(shape), + chunks=list(chunks), + compressor=default_compressor, + ) + a.attrs["_ARRAY_DIMENSIONS"] = list(dimensions) + + a = root.empty( + "call_genotype_mask", + dtype="bool", + shape=list(shape), + chunks=list(chunks), + compressor=default_compressor, + ) + a.attrs["_ARRAY_DIMENSIONS"] = list(dimensions) + + global progress_counter + progress_counter = multiprocessing.Value("Q", 0) + + # start update progress bar process + bar_thread = None + if show_progress: + bar_thread = threading.Thread( + target=update_bar, + args=(progress_counter, m, "Write", "vars"), + name="progress", + daemon=True, + ) + bar_thread.start() + + num_chunks = m // chunk_length + worker_processes = min(worker_processes, num_chunks) + if num_chunks == 1 or worker_processes == 1: + partitions = [(0, m)] + else: + # Generate num_workers partitions + # TODO finer grained might be better. 
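+        # Split the chunk indexes evenly across the workers, then map the
+        # chunk boundaries back to variant coordinates, catching any
+        # remainder in a final partition.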
+ partitions = [] + chunk_boundaries = [ + p[0] for p in np.array_split(np.arange(num_chunks), worker_processes) + ] + for j in range(len(chunk_boundaries) - 1): + start = chunk_boundaries[j] * chunk_length + end = chunk_boundaries[j + 1] * chunk_length + end = min(end, m) + partitions.append((start, end)) + last_stop = partitions[-1][-1] + if last_stop != m: + partitions.append((last_stop, m)) + # print(partitions) + + with cf.ProcessPoolExecutor( + max_workers=worker_processes, + initializer=init_workers, + initargs=(progress_counter,), + ) as executor: + futures = [ + executor.submit( + encode_bed_partition_genotypes, bed_path, zarr_path, start, end + ) + for start, end in partitions + ] + flush_futures(futures) + # print("progress counter = ", m, progress_counter.value) + assert progress_counter.value == m + + # print(root["call_genotype"][:]) diff --git a/tests/test_vcf.py b/tests/test_vcf.py new file mode 100644 index 0000000..4e14405 --- /dev/null +++ b/tests/test_vcf.py @@ -0,0 +1,1891 @@ +import os +import tempfile +from os import listdir +from os.path import join +from typing import MutableMapping + +import numpy as np +import pytest +import xarray as xr +import zarr +from numcodecs import Blosc, Delta, FixedScaleOffset, PackBits, VLenUTF8 +from numpy.testing import assert_allclose, assert_array_equal, assert_array_almost_equal + +from sgkit import load_dataset, save_dataset +from sgkit.io.utils import FLOAT32_FILL, FLOAT32_MISSING, INT_FILL, INT_MISSING +from sgkit.io.vcf import ( + MaxAltAllelesExceededWarning, + partition_into_regions, + read_vcf, + vcf_to_zarr, +) +from sgkit.io.vcf.vcf_reader import ( + FloatFormatFieldWarning, + merge_zarr_array_sizes, + zarr_array_sizes, +) +from sgkit.io.vcf.vcf_converter import convert_vcf, validate +from sgkit.model import get_contigs, get_filters, num_contigs +from sgkit.tests.io.test_dataset import assert_identical + +from .utils import path_for_test + + +@pytest.mark.parametrize( + "read_chunk_length", + [None, 1], +) +@pytest.mark.parametrize( + "is_path", + [True, False], +) +@pytest.mark.parametrize("method", ["to_zarr", "convert", "load"]) +@pytest.mark.filterwarnings("ignore::xarray.coding.variables.SerializationWarning") +def test_vcf_to_zarr__small_vcf( + shared_datadir, + is_path, + read_chunk_length, + tmp_path, + method, +): + path = path_for_test(shared_datadir, "sample.vcf.gz", is_path) + output = tmp_path.joinpath("vcf.zarr").as_posix() + fields = [ + "INFO/NS", + "INFO/AN", + "INFO/AA", + "INFO/DB", + "INFO/AC", + "INFO/AF", + "FORMAT/GT", + "FORMAT/DP", + "FORMAT/HQ", + ] + field_defs = { + "FORMAT/HQ": {"dimension": "ploidy"}, + "INFO/AF": {"Number": "2", "dimension": "AF"}, + "INFO/AC": {"Number": "2", "dimension": "AC"}, + } + if method == "to_zarr": + vcf_to_zarr( + path, + output, + max_alt_alleles=3, + chunk_length=5, + chunk_width=2, + read_chunk_length=read_chunk_length, + fields=fields, + field_defs=field_defs, + ) + ds = xr.open_zarr(output) + + elif method == "convert": + convert_vcf( + [path], + output, + chunk_length=5, + chunk_width=2, + ) + ds = xr.open_zarr(output) + else: + ds = read_vcf( + path, chunk_length=5, chunk_width=2, fields=fields, field_defs=field_defs + ) + + assert_array_equal(ds["filter_id"], ["PASS", "s50", "q10"]) + assert_array_equal( + ds["variant_filter"], + [ + [False, False, False], + [False, False, False], + [True, False, False], + [False, False, True], + [True, False, False], + [True, False, False], + [True, False, False], + [False, False, False], + [True, False, False], + ], + ) 
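+    # The coordinate arrays should round-trip exactly, with the requested
+    # chunk_length of 5 in the variants dimension.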
+ assert_array_equal(ds["contig_id"], ["19", "20", "X"]) + assert "contig_length" not in ds + assert_array_equal(ds["variant_contig"], [0, 0, 1, 1, 1, 1, 1, 1, 2]) + assert ds["variant_contig"].chunks[0][0] == 5 + + assert_array_equal( + ds["variant_position"], + [111, 112, 14370, 17330, 1110696, 1230237, 1234567, 1235237, 10], + ) + assert ds["variant_position"].chunks[0][0] == 5 + + assert_array_equal( + ds["variant_NS"], + [-1, -1, 3, 3, 2, 3, 3, -1, -1], + ) + assert ds["variant_NS"].chunks[0][0] == 5 + + assert_array_equal( + ds["variant_AN"], + [-1, -1, -1, -1, -1, -1, 6, -1, -1], + ) + assert ds["variant_AN"].chunks[0][0] == 5 + + assert_array_equal( + ds["variant_AA"], + [ + ".", + ".", + ".", + ".", + "T", + "T", + "G", + ".", + ".", + ], + ) + assert ds["variant_AN"].chunks[0][0] == 5 + + assert_array_equal( + ds["variant_DB"], + [ + False, + False, + True, + False, + True, + False, + False, + False, + False, + ], + ) + assert ds["variant_AN"].chunks[0][0] == 5 + + variant_AF = np.full((9, 2), FLOAT32_MISSING, dtype=np.float32) + variant_AF[2, 0] = 0.5 + variant_AF[3, 0] = 0.017 + variant_AF[4, 0] = 0.333 + variant_AF[4, 1] = 0.667 + assert_array_almost_equal(ds["variant_AF"], variant_AF, 3) + assert ds["variant_AF"].chunks[0][0] == 5 + + assert_array_equal( + ds["variant_AC"], + [ + [-1, -1], + [-1, -1], + [-1, -1], + [-1, -1], + [-1, -1], + [-1, -1], + [3, 1], + [-1, -1], + [-1, -1], + ], + ) + assert ds["variant_AC"].chunks[0][0] == 5 + + assert_array_equal( + ds["variant_allele"].values.tolist(), + [ + ["A", "C", "", ""], + ["A", "G", "", ""], + ["G", "A", "", ""], + ["T", "A", "", ""], + ["A", "G", "T", ""], + ["T", "", "", ""], + ["G", "GA", "GAC", ""], + ["T", "", "", ""], + ["AC", "A", "ATG", "C"], + ], + ) + assert ds["variant_allele"].chunks[0][0] == 5 + assert ds["variant_allele"].dtype == "O" + assert_array_equal( + ds["variant_id"].values.tolist(), + [".", ".", "rs6054257", ".", "rs6040355", ".", "microsat1", ".", "rsTest"], + ) + assert ds["variant_id"].chunks[0][0] == 5 + assert ds["variant_id"].dtype == "O" + assert_array_equal( + ds["variant_id_mask"], + [True, True, False, True, False, True, False, True, False], + ) + assert ds["variant_id_mask"].chunks[0][0] == 5 + + assert_array_equal(ds["sample_id"], ["NA00001", "NA00002", "NA00003"]) + assert ds["sample_id"].chunks[0][0] == 2 + + call_genotype = np.array( + [ + [[0, 0], [0, 0], [0, 1]], + [[0, 0], [0, 0], [0, 1]], + [[0, 0], [1, 0], [1, 1]], + [[0, 0], [0, 1], [0, 0]], + [[1, 2], [2, 1], [2, 2]], + [[0, 0], [0, 0], [0, 0]], + [[0, 1], [0, 2], [-1, -1]], + [[0, 0], [0, 0], [-1, -1]], + # NOTE: inconsistency here on pad vs missing. I think this is a + # pad value. 
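+            # (-2 is the fill/pad sentinel used for the short call in a
+            # mixed-ploidy genotype; -1 would mean missing.)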
+ [[0, -2], [0, 1], [0, 2]], + ], + dtype="i1", + ) + call_genotype_phased = np.array( + [ + [True, True, False], + [True, True, False], + [True, True, False], + [True, True, False], + [True, True, False], + [True, True, False], + [False, False, False], + [False, True, False], + [True, False, True], + ], + dtype=bool, + ) + call_DP = [ + [-1, -1, -1], + [-1, -1, -1], + [1, 8, 5], + [3, 5, 3], + [6, 0, 4], + [-1, 4, 2], + [4, 2, 3], + [-1, -1, -1], + [-1, -1, -1], + ] + call_HQ = [ + [[10, 15], [10, 10], [3, 3]], + [[10, 10], [10, 10], [3, 3]], + [[51, 51], [51, 51], [-1, -1]], + [[58, 50], [65, 3], [-1, -1]], + [[23, 27], [18, 2], [-1, -1]], + [[56, 60], [51, 51], [-1, -1]], + [[-1, -1], [-1, -1], [-1, -1]], + [[-1, -1], [-1, -1], [-1, -1]], + [[-1, -1], [-1, -1], [-1, -1]], + ] + + # print(np.array2string(ds["call_HQ"].values, separator=",")) + # print(np.array2string(ds["call_genotype"].values < 0, separator=",")) + + assert_array_equal(ds["call_genotype"], call_genotype) + assert_array_equal(ds["call_genotype_mask"], call_genotype < 0) + assert_array_equal(ds["call_genotype_phased"], call_genotype_phased) + assert_array_equal(ds["call_DP"], call_DP) + assert_array_equal(ds["call_HQ"], call_HQ) + + for name in ["call_genotype", "call_genotype_mask", "call_HQ"]: + assert ds[name].chunks == ((5, 4), (2, 1), (2,)) + + for name in ["call_genotype_phased", "call_DP"]: + assert ds[name].chunks == ((5, 4), (2, 1)) + + # save and load again to test https://github.com/pydata/xarray/issues/3476 + path2 = tmp_path / "ds2.zarr" + if not is_path: + path2 = str(path2) + save_dataset(ds, path2) + assert_identical(ds, load_dataset(path2)) + + +@pytest.mark.parametrize( + "is_path", + [True, False], +) +def test_vcf_to_zarr__max_alt_alleles(shared_datadir, is_path, tmp_path): + path = path_for_test(shared_datadir, "sample.vcf.gz", is_path) + output = tmp_path.joinpath("vcf.zarr").as_posix() + + with pytest.warns(MaxAltAllelesExceededWarning): + max_alt_alleles = 1 + vcf_to_zarr( + path, output, chunk_length=5, chunk_width=2, max_alt_alleles=max_alt_alleles + ) + ds = xr.open_zarr(output) + + # extra alt alleles are dropped + assert_array_equal( + ds["variant_allele"].values.tolist(), + [ + ["A", "C"], + ["A", "G"], + ["G", "A"], + ["T", "A"], + ["A", "G"], + ["T", ""], + ["G", "GA"], + ["T", ""], + ["AC", "A"], + ], + ) + + # genotype calls are truncated + assert np.all(ds["call_genotype"].values <= max_alt_alleles) + + # the maximum number of alt alleles actually seen is stored as an attribute + assert ds.attrs["max_alt_alleles_seen"] == 3 + + +@pytest.mark.parametrize( + "read_chunk_length", + [None, 1_000], +) +@pytest.mark.parametrize( + "is_path", + [True, False], +) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") +def test_vcf_to_zarr__large_vcf(shared_datadir, is_path, read_chunk_length, tmp_path): + path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) + output = tmp_path.joinpath("vcf.zarr").as_posix() + + vcf_to_zarr(path, output, chunk_length=5_000, read_chunk_length=read_chunk_length) + ds = xr.open_zarr(output) + + assert_array_equal(ds["contig_id"], ["20", "21"]) + assert_array_equal(ds["contig_length"], [63025520, 48129895]) + assert ds["sample_id"].shape == (1,) + assert ds["call_genotype"].shape == (19910, 1, 2) + assert ds["call_genotype_mask"].shape == (19910, 1, 2) + assert ds["call_genotype_phased"].shape == (19910, 1) + assert ds["variant_allele"].shape == (19910, 4) + assert ds["variant_contig"].shape == (19910,) + assert 
ds["variant_id"].shape == (19910,) + assert ds["variant_id_mask"].shape == (19910,) + assert ds["variant_position"].shape == (19910,) + + assert ds["variant_allele"].dtype == "O" + assert ds["variant_id"].dtype == "O" + + # check underlying zarr chunk size is 1 in samples dim + za = zarr.open(output) + assert za["sample_id"].chunks == (1,) + assert za["call_genotype"].chunks == (5000, 1, 2) + + +def test_vcf_to_zarr__plain_vcf_with_no_index(shared_datadir, tmp_path): + path = path_for_test( + shared_datadir, + "mixed.vcf", + ) + output = tmp_path.joinpath("vcf.zarr").as_posix() + + vcf_to_zarr(path, output, truncate_calls=True) + ds = xr.open_zarr(output) + assert ds["sample_id"].shape == (3,) + + +@pytest.mark.parametrize( + "is_path", + [True, False], +) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") +def test_vcf_to_zarr__mutable_mapping(shared_datadir, is_path): + path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) + output: MutableMapping[str, bytes] = {} + + vcf_to_zarr(path, output, chunk_length=5_000) + ds = xr.open_zarr(output) + + assert ds["sample_id"].shape == (1,) + assert ds["call_genotype"].shape == (19910, 1, 2) + assert ds["call_genotype_mask"].shape == (19910, 1, 2) + assert ds["call_genotype_phased"].shape == (19910, 1) + assert ds["variant_allele"].shape == (19910, 4) + assert ds["variant_contig"].shape == (19910,) + assert ds["variant_id"].shape == (19910,) + assert ds["variant_id_mask"].shape == (19910,) + assert ds["variant_position"].shape == (19910,) + + assert ds["variant_allele"].dtype == "O" + assert ds["variant_id"].dtype == "O" + + +@pytest.mark.parametrize( + "is_path", + [True, False], +) +def test_vcf_to_zarr__compressor_and_filters(shared_datadir, is_path, tmp_path): + path = path_for_test(shared_datadir, "sample.vcf.gz", is_path) + output = tmp_path.joinpath("vcf.zarr").as_posix() + + compressor = Blosc("zlib", 1, Blosc.NOSHUFFLE) + variant_id_compressor = Blosc("zlib", 2, Blosc.NOSHUFFLE) + encoding = dict( + variant_id=dict(compressor=variant_id_compressor), + variant_id_mask=dict(filters=None), + ) + vcf_to_zarr( + path, + output, + chunk_length=5, + chunk_width=2, + compressor=compressor, + encoding=encoding, + ) + + # look at actual Zarr store to check compressor and filters + z = zarr.open(output) + assert z["call_genotype"].compressor == compressor + assert z["call_genotype"].filters is None # sgkit default + assert z["call_genotype"].chunks == (5, 2, 2) + assert z["call_genotype_mask"].compressor == compressor + assert z["call_genotype_mask"].filters == [PackBits()] # sgkit default + assert z["call_genotype_mask"].chunks == (5, 2, 2) + + assert z["variant_id"].compressor == variant_id_compressor + assert z["variant_id"].filters == [VLenUTF8()] # sgkit default + assert z["variant_id"].chunks == (5,) + assert z["variant_id_mask"].compressor == compressor + assert z["variant_id_mask"].filters is None + assert z["variant_id_mask"].chunks == (5,) + + assert z["variant_position"].filters == [ + Delta(dtype="i4", astype="i4") + ] # sgkit default + + +@pytest.mark.parametrize( + "is_path", + [True, False], +) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") +def test_vcf_to_zarr__parallel_compressor_and_filters( + shared_datadir, is_path, tmp_path +): + path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) + output = tmp_path.joinpath("vcf_concat.zarr").as_posix() + regions = ["20", "21"] + + compressor = Blosc("zlib", 1, 
Blosc.NOSHUFFLE) + variant_id_compressor = Blosc("zlib", 2, Blosc.NOSHUFFLE) + encoding = dict( + variant_id=dict(compressor=variant_id_compressor), + variant_id_mask=dict(filters=None), + ) + vcf_to_zarr( + path, + output, + regions=regions, + chunk_length=5_000, + compressor=compressor, + encoding=encoding, + ) + + # look at actual Zarr store to check compressor and filters + z = zarr.open(output) + assert z["call_genotype"].compressor == compressor + assert z["call_genotype"].filters is None # sgkit default + assert z["call_genotype"].chunks == (5000, 1, 2) + assert z["call_genotype_mask"].compressor == compressor + assert z["call_genotype_mask"].filters == [PackBits()] # sgkit default + assert z["call_genotype_mask"].chunks == (5000, 1, 2) + + assert z["variant_id"].compressor == variant_id_compressor + assert z["variant_id"].filters == [VLenUTF8()] # sgkit default + assert z["variant_id"].chunks == (5000,) + assert z["variant_id_mask"].compressor == compressor + assert z["variant_id_mask"].filters is None + assert z["variant_id_mask"].chunks == (5000,) + + assert z["variant_position"].filters == [ + Delta(dtype="i4", astype="i4") + ] # sgkit default + + +def test_vcf_to_zarr__float_format_field_warning(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "simple.output.mixed_depth.likelihoods.vcf") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + with pytest.warns(FloatFormatFieldWarning): + vcf_to_zarr( + path, + output, + ploidy=4, + max_alt_alleles=3, + fields=["FORMAT/GL"], + ) + + +@pytest.mark.parametrize( + "is_path", + [True, False], +) +@pytest.mark.parametrize( + "output_is_path", + [True, False], +) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") +def test_vcf_to_zarr__parallel(shared_datadir, is_path, output_is_path, tmp_path): + path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) + output = tmp_path.joinpath("vcf_concat.zarr") + if not output_is_path: + output = output.as_posix() + + regions = ["20", "21"] + + vcf_to_zarr( + path, + output, + regions=regions, + chunk_length=5_000, + ) + ds = xr.open_zarr(output) + + assert ds["sample_id"].shape == (1,) + assert ds["call_genotype"].shape == (19910, 1, 2) + assert ds["call_genotype_mask"].shape == (19910, 1, 2) + assert ds["call_genotype_phased"].shape == (19910, 1) + assert ds["variant_allele"].shape == (19910, 4) + assert ds["variant_contig"].shape == (19910,) + assert ds["variant_id"].shape == (19910,) + assert ds["variant_id_mask"].shape == (19910,) + assert ds["variant_position"].shape == (19910,) + + assert ds["variant_allele"].dtype == "O" + assert ds["variant_id"].dtype == "O" + + +@pytest.mark.parametrize( + "is_path", + [True, False], +) +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_vcf_to_zarr__empty_region(shared_datadir, is_path, tmp_path): + path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) + output = tmp_path.joinpath("vcf_concat.zarr").as_posix() + regions = "23" + + vcf_to_zarr(path, output, regions=regions) + ds = xr.open_zarr(output) + + assert ds["sample_id"].shape == (1,) + assert ds["call_genotype"].shape == (0, 1, 2) + assert ds["call_genotype_mask"].shape == (0, 1, 2) + assert ds["call_genotype_phased"].shape == (0, 1) + assert ds["variant_allele"].shape == (0, 4) + assert ds["variant_contig"].shape == (0,) + assert ds["variant_id"].shape == (0,) + assert ds["variant_id_mask"].shape == (0,) + assert ds["variant_position"].shape == (0,) + + +@pytest.mark.parametrize( + 
"is_path", + [False], +) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") +def test_vcf_to_zarr__parallel_temp_chunk_length(shared_datadir, is_path, tmp_path): + path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) + output = tmp_path.joinpath("vcf_concat.zarr").as_posix() + regions = ["20", "21"] + + # Use a temp_chunk_length that is smaller than chunk_length + # Open the temporary parts to check that they have the right temp chunk length + with tempfile.TemporaryDirectory() as tempdir: + vcf_to_zarr( + path, + output, + regions=regions, + chunk_length=5_000, + temp_chunk_length=2_500, + tempdir=tempdir, + retain_temp_files=True, + ) + inner_temp_dir = join(tempdir, listdir(tempdir)[0]) + parts_dir = join(inner_temp_dir, listdir(inner_temp_dir)[0]) + part = xr.open_zarr(join(parts_dir, "part-0.zarr")) + assert part["call_genotype"].chunks[0][0] == 2_500 + assert part["variant_position"].chunks[0][0] == 2_500 + ds = xr.open_zarr(output) + + assert ds["sample_id"].shape == (1,) + assert ds["call_genotype"].shape == (19910, 1, 2) + assert ds["call_genotype"].chunks[0][0] == 5_000 + assert ds["call_genotype_mask"].shape == (19910, 1, 2) + assert ds["call_genotype_phased"].shape == (19910, 1) + assert ds["variant_allele"].shape == (19910, 4) + assert ds["variant_contig"].shape == (19910,) + assert ds["variant_id"].shape == (19910,) + assert ds["variant_id_mask"].shape == (19910,) + assert ds["variant_position"].shape == (19910,) + assert ds["variant_position"].chunks[0][0] == 5_000 + + assert ds["variant_allele"].dtype == "O" + assert ds["variant_id"].dtype == "O" + + +def test_vcf_to_zarr__parallel_temp_chunk_length_not_divisible( + shared_datadir, tmp_path +): + path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", False) + output = tmp_path.joinpath("vcf_concat.zarr").as_posix() + regions = ["20", "21"] + + with pytest.raises( + ValueError, + match=r"Temporary chunk length in variant dimension \(4000\) must evenly divide target chunk length 5000", + ): + # Use a temp_chunk_length that does not divide into chunk_length + vcf_to_zarr( + path, output, regions=regions, chunk_length=5_000, temp_chunk_length=4_000 + ) + + +@pytest.mark.parametrize( + "is_path", + [True, False], +) +def test_vcf_to_zarr__parallel_partitioned(shared_datadir, is_path, tmp_path): + path = path_for_test( + shared_datadir, + "1000G.phase3.broad.withGenotypes.chr20.10100000.vcf.gz", + is_path, + ) + output = tmp_path.joinpath("vcf_concat.zarr").as_posix() + + regions = partition_into_regions(path, num_parts=4) + + vcf_to_zarr(path, output, regions=regions, chunk_length=1_000, chunk_width=1_000) + ds = xr.open_zarr(output) + + assert ds["sample_id"].shape == (2535,) + assert ds["variant_id"].shape == (1406,) + + +@pytest.mark.parametrize( + "is_path", + [True, False], +) +def test_vcf_to_zarr__parallel_partitioned_by_size(shared_datadir, is_path, tmp_path): + path = path_for_test( + shared_datadir, + "1000G.phase3.broad.withGenotypes.chr20.10100000.vcf.gz", + is_path, + ) + output = tmp_path.joinpath("vcf_concat.zarr").as_posix() + + vcf_to_zarr( + path, output, target_part_size="4MB", chunk_length=1_000, chunk_width=1_000 + ) + ds = xr.open_zarr(output) + + assert ds["sample_id"].shape == (2535,) + assert ds["variant_id"].shape == (1406,) + + +@pytest.mark.parametrize( + "is_path", + [True, False], +) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") +def test_vcf_to_zarr__multiple(shared_datadir, is_path, 
tmp_path):
+    paths = [
+        path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
+        path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path),
+    ]
+    output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
+
+    vcf_to_zarr(paths, output, target_part_size=None, chunk_length=5_000)
+    ds = xr.open_zarr(output)
+
+    assert ds["sample_id"].shape == (1,)
+    assert ds["call_genotype"].shape == (19910, 1, 2)
+    assert ds["call_genotype_mask"].shape == (19910, 1, 2)
+    assert ds["call_genotype_phased"].shape == (19910, 1)
+    assert ds["variant_allele"].shape == (19910, 4)
+    assert ds["variant_contig"].shape == (19910,)
+    assert ds["variant_id"].shape == (19910,)
+    assert ds["variant_id_mask"].shape == (19910,)
+    assert ds["variant_position"].shape == (19910,)
+
+    assert ds.chunks["variants"] == (5000, 5000, 5000, 4910)
+
+
+@pytest.mark.parametrize(
+    "is_path",
+    [True, False],
+)
+@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
+def test_vcf_to_zarr__multiple_partitioned(shared_datadir, is_path, tmp_path):
+    paths = [
+        path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
+        path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path),
+    ]
+    output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
+
+    regions = [partition_into_regions(path, num_parts=2) for path in paths]
+
+    vcf_to_zarr(paths, output, regions=regions, chunk_length=5_000)
+    ds = xr.open_zarr(output)
+
+    assert ds["sample_id"].shape == (1,)
+    assert ds["call_genotype"].shape == (19910, 1, 2)
+    assert ds["call_genotype_mask"].shape == (19910, 1, 2)
+    assert ds["call_genotype_phased"].shape == (19910, 1)
+    assert ds["variant_allele"].shape == (19910, 4)
+    assert ds["variant_contig"].shape == (19910,)
+    assert ds["variant_id"].shape == (19910,)
+    assert ds["variant_id_mask"].shape == (19910,)
+    assert ds["variant_position"].shape == (19910,)
+
+    assert ds.chunks["variants"] == (5000, 5000, 5000, 4910)
+
+
+@pytest.mark.parametrize(
+    "is_path",
+    [True, False],
+)
+@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
+def test_vcf_to_zarr__multiple_partitioned_by_size(shared_datadir, is_path, tmp_path):
+    paths = [
+        path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
+        path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path),
+    ]
+    output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
+
+    vcf_to_zarr(paths, output, target_part_size="40KB", chunk_length=5_000)
+    ds = xr.open_zarr(output)
+
+    assert ds["sample_id"].shape == (1,)
+    assert ds["call_genotype"].shape == (19910, 1, 2)
+    assert ds["call_genotype_mask"].shape == (19910, 1, 2)
+    assert ds["call_genotype_phased"].shape == (19910, 1)
+    assert ds["variant_allele"].shape == (19910, 4)
+    assert ds["variant_contig"].shape == (19910,)
+    assert ds["variant_id"].shape == (19910,)
+    assert ds["variant_id_mask"].shape == (19910,)
+    assert ds["variant_position"].shape == (19910,)
+
+    assert ds.chunks["variants"] == (5000, 5000, 5000, 4910)
+
+
+@pytest.mark.parametrize(
+    "is_path",
+    [True, False],
+)
+def test_vcf_to_zarr__multiple_partitioned_invalid_regions(
+    shared_datadir, is_path, tmp_path
+):
+    paths = [
+        path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path),
+        path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path),
+    ]
+    output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
+
+    # invalid regions: should be a sequence of sequences
+    regions = partition_into_regions(paths[0], num_parts=2)
+
+    with 
pytest.raises( + ValueError, + match=r"multiple input regions must be a sequence of sequence of strings", + ): + vcf_to_zarr(paths, output, regions=regions, chunk_length=5_000) + + +@pytest.mark.parametrize( + "is_path", + [True, False], +) +def test_vcf_to_zarr__multiple_max_alt_alleles(shared_datadir, is_path, tmp_path): + paths = [ + path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), + path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path), + ] + output = tmp_path.joinpath("vcf_concat.zarr").as_posix() + + with pytest.warns(MaxAltAllelesExceededWarning): + vcf_to_zarr( + paths, + output, + target_part_size="40KB", + chunk_length=5_000, + max_alt_alleles=1, + ) + ds = xr.open_zarr(output) + + # the maximum number of alt alleles actually seen is stored as an attribute + assert ds.attrs["max_alt_alleles_seen"] == 7 + + +@pytest.mark.parametrize( + "max_alt_alleles,dtype,warning", + [ + (2, np.int8, True), + (127, np.int8, True), + (128, np.int16, True), + (145, np.int16, True), + (164, np.int16, False), + ], +) +def test_vcf_to_zarr__call_genotype_dtype( + shared_datadir, tmp_path, max_alt_alleles, dtype, warning +): + path = path_for_test(shared_datadir, "allele_overflow.vcf.gz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + if warning: + with pytest.warns(MaxAltAllelesExceededWarning): + vcf_to_zarr(path, output, max_alt_alleles=max_alt_alleles) + else: + vcf_to_zarr(path, output, max_alt_alleles=max_alt_alleles) + ds = load_dataset(output) + assert ds.call_genotype.dtype == dtype + assert ds.call_genotype.values.max() <= max_alt_alleles + + +@pytest.mark.parametrize( + "ploidy,mixed_ploidy,truncate_calls,regions", + [ + (2, False, True, None), + (4, False, False, None), + (4, False, False, ["CHR1:0-5", "CHR1:5-10"]), + (4, True, False, None), + (4, True, False, ["CHR1:0-5", "CHR1:5-10"]), + (5, True, False, None), + ], +) +def test_vcf_to_zarr__mixed_ploidy_vcf( + shared_datadir, tmp_path, ploidy, mixed_ploidy, truncate_calls, regions +): + path = path_for_test(shared_datadir, "mixed.vcf.gz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + vcf_to_zarr( + path, + output, + regions=regions, + chunk_length=5, + chunk_width=2, + ploidy=ploidy, + mixed_ploidy=mixed_ploidy, + truncate_calls=truncate_calls, + ) + ds = load_dataset(output) + + variant_dtype = "O" + assert_array_equal(ds["contig_id"], ["CHR1", "CHR2", "CHR3"]) + assert_array_equal(ds["variant_contig"], [0, 0]) + assert_array_equal(ds["variant_position"], [2, 7]) + assert_array_equal( + ds["variant_allele"].values.tolist(), + np.array( + [ + ["A", "T", "", ""], + ["A", "C", "", ""], + ], + dtype=variant_dtype, + ), + ) + assert ds["variant_allele"].dtype == variant_dtype # type: ignore[comparison-overlap] + assert_array_equal( + ds["variant_id"], + np.array([".", "."], dtype=variant_dtype), + ) + assert ds["variant_id"].dtype == variant_dtype # type: ignore[comparison-overlap] + assert_array_equal( + ds["variant_id_mask"], + [True, True], + ) + assert_array_equal(ds["sample_id"], ["SAMPLE1", "SAMPLE2", "SAMPLE3"]) + + assert ds["call_genotype"].attrs["mixed_ploidy"] == mixed_ploidy + pad = -2 if mixed_ploidy else -1 # -2 indicates a fill (non-allele) value + call_genotype = np.array( + [ + [[0, 0, 1, 1, pad], [0, 0, pad, pad, pad], [0, 0, 0, 1, pad]], + [[0, 0, 1, 1, pad], [0, 1, pad, pad, pad], [0, 1, -1, -1, pad]], + ], + dtype="i1", + ) + # truncate row vectors if lower ploidy + call_genotype = call_genotype[:, :, 0:ploidy] + + assert_array_equal(ds["call_genotype"], 
call_genotype) + assert_array_equal(ds["call_genotype_mask"], call_genotype < 0) + if mixed_ploidy: + assert_array_equal(ds["call_genotype_fill"], call_genotype < -1) + + +@pytest.mark.parametrize( + "ploidy,mixed_ploidy,truncate_calls", + [ + (2, False, False), + (3, True, False), + ], +) +def test_vcf_to_zarr__mixed_ploidy_vcf_exception( + shared_datadir, tmp_path, ploidy, mixed_ploidy, truncate_calls +): + path = path_for_test(shared_datadir, "mixed.vcf.gz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + with pytest.raises(ValueError) as excinfo: + vcf_to_zarr( + path, + output, + ploidy=ploidy, + mixed_ploidy=mixed_ploidy, + truncate_calls=truncate_calls, + ) + assert "Genotype call longer than ploidy." == str(excinfo.value) + + +def test_vcf_to_zarr__no_genotypes(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "no_genotypes.vcf") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + vcf_to_zarr(path, output) + + ds = xr.open_zarr(output) + + assert "call_genotype" not in ds + assert "call_genotype_mask" not in ds + assert "call_genotype_phased" not in ds + + assert ds["sample_id"].shape == (0,) + assert ds["variant_allele"].shape == (26, 4) + assert ds["variant_contig"].shape == (26,) + assert ds["variant_id"].shape == (26,) + assert ds["variant_id_mask"].shape == (26,) + assert ds["variant_position"].shape == (26,) + + +def test_vcf_to_zarr__no_genotypes_with_gt_header(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "no_genotypes_with_gt_header.vcf") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + vcf_to_zarr(path, output) + + ds = xr.open_zarr(output) + + assert_array_equal(ds["call_genotype"], -1) + assert_array_equal(ds["call_genotype_mask"], 1) + assert_array_equal(ds["call_genotype_phased"], 0) + + assert ds["sample_id"].shape == (0,) + assert ds["variant_allele"].shape == (26, 4) + assert ds["variant_contig"].shape == (26,) + assert ds["variant_id"].shape == (26,) + assert ds["variant_id_mask"].shape == (26,) + assert ds["variant_position"].shape == (26,) + + +def test_vcf_to_zarr__contig_not_defined_in_header(shared_datadir, tmp_path): + # sample.vcf does not define the contigs in the header, and isn't indexed + path = path_for_test(shared_datadir, "sample.vcf") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + with pytest.raises( + ValueError, + match=r"Contig '19' is not defined in the header.", + ): + vcf_to_zarr(path, output) + + +def test_vcf_to_zarr__filter_not_defined_in_header(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "no_filter_defined.vcf") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + with pytest.raises( + ValueError, + match=r"Filter 'FAIL' is not defined in the header.", + ): + vcf_to_zarr(path, output) + + +def test_vcf_to_zarr__info_name_clash(shared_datadir, tmp_path): + # info_name_clash.vcf has an info field called 'id' which would be mapped to + # 'variant_id', clashing with the fixed field of the same name + path = path_for_test(shared_datadir, "info_name_clash.vcf") + output = tmp_path.joinpath("info_name_clash.zarr").as_posix() + + vcf_to_zarr(path, output) # OK if problematic field is ignored + + with pytest.raises( + ValueError, + match=r"Generated name for INFO field 'id' clashes with 'variant_id' from fixed VCF fields.", + ): + vcf_to_zarr(path, output, fields=["INFO/id"]) + + +def test_vcf_to_zarr__large_number_of_contigs(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "Homo_sapiens_assembly38.headerOnly.vcf.gz") + output = 
tmp_path.joinpath("vcf.zarr").as_posix() + + vcf_to_zarr(path, output) + + ds = xr.open_zarr(output) + + assert len(ds["contig_id"]) == 3366 + assert ds["variant_contig"].dtype == np.int16 # needs larger dtype than np.int8 + + +def test_vcf_to_zarr__fields(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "sample.vcf.gz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + vcf_to_zarr( + path, + output, + chunk_length=5, + chunk_width=2, + fields=["INFO/DP", "INFO/AA", "INFO/DB", "FORMAT/DP"], + ) + ds = xr.open_zarr(output) + + missing, fill = INT_MISSING, INT_FILL + assert_array_equal(ds["variant_DP"], [fill, fill, 14, 11, 10, 13, 9, fill, fill]) + assert ds["variant_DP"].attrs["comment"] == "Total Depth" + + assert_array_equal( + ds["variant_AA"], + np.array(["", "", "", "", "T", "T", "G", "", ""], dtype="O"), + ) + assert ds["variant_AA"].attrs["comment"] == "Ancestral Allele" + + assert_array_equal( + ds["variant_DB"], [False, False, True, False, True, False, False, False, False] + ) + assert ds["variant_DB"].attrs["comment"] == "dbSNP membership, build 129" + + dp = np.array( + [ + [fill, fill, fill], + [fill, fill, fill], + [1, 8, 5], + [3, 5, 3], + [6, 0, 4], + [missing, 4, 2], + [4, 2, 3], + [fill, fill, fill], + [fill, fill, fill], + ], + dtype="i4", + ) + assert_array_equal(ds["call_DP"], dp) + assert ds["call_DP"].attrs["comment"] == "Read Depth" + + +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") +def test_vcf_to_zarr__parallel_with_fields(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + regions = ["20", "21"] + + vcf_to_zarr( + path, + output, + regions=regions, + chunk_length=5_000, + temp_chunk_length=2_500, + fields=["INFO/MQ", "FORMAT/PGT"], + ) + ds = xr.open_zarr(output) + + # select a small region to check + ds = ds.set_index(variants=("variant_contig", "variant_position")).sel( + variants=slice((0, 10001661), (0, 10001670)) + ) + + # check strings have not been truncated after concat_zarrs + assert_array_equal( + ds["variant_allele"], + np.array( + [ + ["T", "C", "", ""], + ["T", "", "", ""], + ["T", "G", "", ""], + ], + dtype="O", + ), + ) + + # convert floats to ints to check nan type + fill = FLOAT32_FILL + assert_allclose( + ds["variant_MQ"].values.view("i4"), + np.array([58.33, fill, 57.45], dtype="f4").view("i4"), + ) + assert ds["variant_MQ"].attrs["comment"] == "RMS Mapping Quality" + + assert_array_equal(ds["call_PGT"], np.array([["0|1"], [""], ["0|1"]], dtype="O")) + assert ( + ds["call_PGT"].attrs["comment"] + == "Physical phasing haplotype information, describing how the alternate alleles are phased in relation to one another" + ) + + +def test_vcf_to_zarr__field_defs(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "sample.vcf.gz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + vcf_to_zarr( + path, + output, + fields=["INFO/DP"], + field_defs={"INFO/DP": {"Description": "Combined depth across samples"}}, + ) + ds = xr.open_zarr(output) + + fill = INT_FILL + assert_array_equal(ds["variant_DP"], [fill, fill, 14, 11, 10, 13, 9, fill, fill]) + assert ds["variant_DP"].attrs["comment"] == "Combined depth across samples" + + vcf_to_zarr( + path, + output, + fields=["INFO/DP"], + field_defs={"INFO/DP": {"Description": ""}}, # blank description + ) + ds = xr.open_zarr(output) + + assert_array_equal(ds["variant_DP"], [fill, fill, 14, 11, 10, 13, 9, fill, fill]) + assert "comment" 
not in ds["variant_DP"].attrs + + +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") +def test_vcf_to_zarr__field_number_A(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "sample.vcf.gz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + vcf_to_zarr( + path, + output, + max_alt_alleles=2, + fields=["INFO/AC"], + field_defs={"INFO/AC": {"Number": "A"}}, + ) + ds = xr.open_zarr(output) + + fill = INT_FILL + assert_array_equal( + ds["variant_AC"], + [ + [fill, fill], + [fill, fill], + [fill, fill], + [fill, fill], + [fill, fill], + [fill, fill], + [3, 1], + [fill, fill], + [fill, fill], + ], + ) + assert ( + ds["variant_AC"].attrs["comment"] + == "Allele count in genotypes, for each ALT allele, in the same order as listed" + ) + + +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") +def test_vcf_to_zarr__field_number_R(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + vcf_to_zarr( + path, + output, + fields=["FORMAT/AD"], + field_defs={"FORMAT/AD": {"Number": "R"}}, + ) + ds = xr.open_zarr(output) + + # select a small region to check + ds = ds.set_index(variants="variant_position").sel( + variants=slice(10002764, 10002793) + ) + + fill = INT_FILL + ad = np.array( + [ + [[40, 14, 0, fill]], + [[fill, fill, fill, fill]], + [[65, 8, 5, 0]], + [[fill, fill, fill, fill]], + ], + ) + assert_array_equal(ds["call_AD"], ad) + assert ( + ds["call_AD"].attrs["comment"] + == "Allelic depths for the ref and alt alleles in the order listed" + ) + + +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") +def test_vcf_to_zarr__field_number_G(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + vcf_to_zarr(path, output, fields=["FORMAT/PL"]) + ds = xr.open_zarr(output) + + # select a small region to check + ds = ds.set_index(variants="variant_position").sel( + variants=slice(10002764, 10002793) + ) + + fill = INT_FILL + pl = np.array( + [ + [[319, 0, 1316, 440, 1358, 1798, fill, fill, fill, fill]], + [[0, 120, 1800, fill, fill, fill, fill, fill, fill, fill]], + [[8, 0, 1655, 103, 1743, 2955, 184, 1653, 1928, 1829]], + [[0, 0, 2225, fill, fill, fill, fill, fill, fill, fill]], + ], + ) + assert_array_equal(ds["call_PL"], pl) + assert ( + ds["call_PL"].attrs["comment"] + == "Normalized, Phred-scaled likelihoods for genotypes as defined in the VCF specification" + ) + + +def test_vcf_to_zarr__field_number_G_non_diploid(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "simple.output.mixed_depth.likelihoods.vcf") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + # store GL field as 2dp + encoding = { + "call_GL": { + "filters": [FixedScaleOffset(offset=0, scale=100, dtype="f4", astype="u1")] + } + } + vcf_to_zarr( + path, + output, + ploidy=4, + max_alt_alleles=3, + fields=["FORMAT/GL"], + encoding=encoding, + ) + ds = xr.open_zarr(output) + + # comb(n_alleles + ploidy - 1, ploidy) = comb(4 + 4 - 1, 4) = comb(7, 4) = 35 + assert_array_equal(ds["call_GL"].shape, (4, 3, 35)) + assert ds["call_GL"].attrs["comment"] == "Genotype likelihoods" + + +@pytest.mark.filterwarnings( + "ignore::sgkit.io.vcfzarr_reader.DimensionNameForFixedFormatFieldWarning" +) +def test_vcf_to_zarr__field_number_fixed(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "sample.vcf.gz") + output = 
tmp_path.joinpath("vcf.zarr").as_posix() + + # HQ Number is 2, and a dimension is automatically assigned (FORMAT_HQ_dim) + vcf_to_zarr( + path, + output, + fields=["FORMAT/HQ"], + ) + ds = xr.open_zarr(output) + + missing, fill = INT_MISSING, INT_FILL + assert_array_equal( + ds["call_HQ"], + [ + [[10, 15], [10, 10], [3, 3]], + [[10, 10], [10, 10], [3, 3]], + [[51, 51], [51, 51], [missing, missing]], + [[58, 50], [65, 3], [missing, missing]], + [[23, 27], [18, 2], [missing, missing]], + [[56, 60], [51, 51], [missing, missing]], + [[fill, fill], [fill, fill], [fill, fill]], + [[fill, fill], [fill, fill], [fill, fill]], + [[fill, fill], [fill, fill], [fill, fill]], + ], + ) + assert ds["call_HQ"].dims == ("variants", "samples", "FORMAT_HQ_dim") + assert ds["call_HQ"].attrs["comment"] == "Haplotype Quality" + + +def test_vcf_to_zarr__fields_errors(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "sample.vcf.gz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + with pytest.raises( + ValueError, + match=r"VCF field must be prefixed with 'INFO/' or 'FORMAT/'", + ): + vcf_to_zarr(path, output, fields=["DP"]) + + with pytest.raises( + ValueError, + match=r"INFO field 'XX' is not defined in the header.", + ): + vcf_to_zarr(path, output, fields=["INFO/XX"]) + + with pytest.raises( + ValueError, + match=r"FORMAT field 'XX' is not defined in the header.", + ): + vcf_to_zarr(path, output, fields=["FORMAT/XX"]) + + with pytest.raises( + ValueError, + match=r"FORMAT field 'XX' is not defined in the header.", + ): + vcf_to_zarr(path, output, exclude_fields=["FORMAT/XX"]) + + with pytest.raises( + ValueError, + match=r"INFO field 'AC' is defined as Number '.', which is not supported. Consider specifying `field_defs` to provide a concrete size for this field.", + ): + vcf_to_zarr(path, output, fields=["INFO/AC"]) + + with pytest.raises( + ValueError, + match=r"INFO field 'AN' is defined as Type 'Blah', which is not supported.", + ): + vcf_to_zarr( + path, + output, + fields=["INFO/AN"], + field_defs={"INFO/AN": {"Type": "Blah"}}, + ) + + +@pytest.mark.parametrize( + "vcf_file, expected_sizes", + [ + ( + "sample.vcf.gz", + { + "max_alt_alleles": 3, + "field_defs": {"INFO/AC": {"Number": 2}, "INFO/AF": {"Number": 2}}, + "ploidy": 2, + }, + ), + ("mixed.vcf.gz", {"max_alt_alleles": 1, "ploidy": 4}), + ("no_genotypes.vcf", {"max_alt_alleles": 1}), + ( + "CEUTrio.20.21.gatk3.4.g.vcf.bgz", + { + "max_alt_alleles": 7, + "field_defs": {"FORMAT/AD": {"Number": 8}}, + "ploidy": 2, + }, + ), + ], +) +def test_zarr_array_sizes(shared_datadir, vcf_file, expected_sizes): + path = path_for_test(shared_datadir, vcf_file) + sizes = zarr_array_sizes(path) + assert sizes == expected_sizes + + +def test_zarr_array_sizes__parallel(shared_datadir): + path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz") + regions = ["20", "21"] + sizes = zarr_array_sizes(path, regions=regions) + assert sizes == { + "max_alt_alleles": 7, + "field_defs": {"FORMAT/AD": {"Number": 8}}, + "ploidy": 2, + } + + +def test_zarr_array_sizes__multiple(shared_datadir): + paths = [ + path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz"), + path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz"), + ] + sizes = zarr_array_sizes(paths, target_part_size=None) + assert sizes == { + "max_alt_alleles": 7, + "field_defs": {"FORMAT/AD": {"Number": 8}}, + "ploidy": 2, + } + + +def test_zarr_array_sizes__parallel_partitioned_by_size(shared_datadir): + path = path_for_test( + shared_datadir, + 
"1000G.phase3.broad.withGenotypes.chr20.10100000.vcf.gz", + ) + sizes = zarr_array_sizes(path, target_part_size="4MB") + assert sizes == { + "max_alt_alleles": 3, + "field_defs": {"FORMAT/AD": {"Number": 4}}, + "ploidy": 2, + } + + +@pytest.mark.parametrize( + "all_kwargs, expected_sizes", + [ + ([{"max_alt_alleles": 1}, {"max_alt_alleles": 2}], {"max_alt_alleles": 2}), + ( + [{"max_alt_alleles": 1, "ploidy": 3}, {"max_alt_alleles": 2}], + {"max_alt_alleles": 2, "ploidy": 3}, + ), + ( + [ + {"max_alt_alleles": 1, "field_defs": {"FORMAT/AD": {"Number": 8}}}, + {"max_alt_alleles": 2, "field_defs": {"FORMAT/AD": {"Number": 6}}}, + ], + {"max_alt_alleles": 2, "field_defs": {"FORMAT/AD": {"Number": 8}}}, + ), + ], +) +def test_merge_zarr_array_sizes(all_kwargs, expected_sizes): + assert merge_zarr_array_sizes(all_kwargs) == expected_sizes + + +def check_field(group, name, ndim, shape, dimension_names, dtype): + assert group[name].ndim == ndim + assert group[name].shape == shape + assert group[name].attrs["_ARRAY_DIMENSIONS"] == dimension_names + if dtype == str: + assert group[name].dtype == np.object_ + assert VLenUTF8() in group[name].filters + else: + assert group[name].dtype == dtype + + +@pytest.mark.filterwarnings( + "ignore::sgkit.io.vcfzarr_reader.DimensionNameForFixedFormatFieldWarning" +) +def test_spec(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "sample_multiple_filters.vcf.gz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + kwargs = zarr_array_sizes(path) + vcf_to_zarr( + path, + output, + chunk_length=5, + fields=["INFO/*", "FORMAT/*"], + mixed_ploidy=True, + **kwargs, + ) + + variants = 9 + alt_alleles = 3 + samples = 3 + ploidy = 2 + + group = zarr.open_group(output) + + # VCF Zarr group attributes + assert group.attrs["vcf_zarr_version"] == "0.2" + assert group.attrs["vcf_header"].startswith("##fileformat=VCFv4.0") + assert group.attrs["contigs"] == ["19", "20", "X"] + + # VCF Zarr arrays + assert set(list(group.array_keys())) == set( + [ + "variant_contig", + "variant_position", + "variant_id", + "variant_id_mask", + "variant_allele", + "variant_quality", + "variant_filter", + "variant_AA", + "variant_AC", + "variant_AF", + "variant_AN", + "variant_DB", + "variant_DP", + "variant_H2", + "variant_NS", + "call_DP", + "call_GQ", + "call_genotype", + "call_genotype_mask", + "call_genotype_fill", + "call_genotype_phased", + "call_HQ", + "contig_id", + "filter_id", + "sample_id", + ] + ) + + # Fixed fields + check_field( + group, + "variant_contig", + ndim=1, + shape=(variants,), + dimension_names=["variants"], + dtype=np.int8, + ) + check_field( + group, + "variant_position", + ndim=1, + shape=(variants,), + dimension_names=["variants"], + dtype=np.int32, + ) + check_field( + group, + "variant_id", + ndim=1, + shape=(variants,), + dimension_names=["variants"], + dtype=str, + ) + check_field( + group, + "variant_allele", + ndim=2, + shape=(variants, alt_alleles + 1), + dimension_names=["variants", "alleles"], + dtype=str, + ) + check_field( + group, + "variant_quality", + ndim=1, + shape=(variants,), + dimension_names=["variants"], + dtype=np.float32, + ) + check_field( + group, + "variant_filter", + ndim=2, + shape=(variants, 3), + dimension_names=["variants", "filters"], + dtype=bool, + ) + + # INFO fields + check_field( + group, + "variant_AA", + ndim=1, + shape=(variants,), + dimension_names=["variants"], + dtype=str, + ) + check_field( + group, + "variant_AC", + ndim=2, + shape=(variants, 2), + dimension_names=["variants", "INFO_AC_dim"], + 
dtype=np.int32, + ) + check_field( + group, + "variant_AF", + ndim=2, + shape=(variants, 2), + dimension_names=["variants", "INFO_AF_dim"], + dtype=np.float32, + ) + check_field( + group, + "variant_AN", + ndim=1, + shape=(variants,), + dimension_names=["variants"], + dtype=np.int32, + ) + check_field( + group, + "variant_DB", + ndim=1, + shape=(variants,), + dimension_names=["variants"], + dtype=bool, + ) + check_field( + group, + "variant_DP", + ndim=1, + shape=(variants,), + dimension_names=["variants"], + dtype=np.int32, + ) + check_field( + group, + "variant_H2", + ndim=1, + shape=(variants,), + dimension_names=["variants"], + dtype=bool, + ) + check_field( + group, + "variant_NS", + ndim=1, + shape=(variants,), + dimension_names=["variants"], + dtype=np.int32, + ) + + # FORMAT fields + check_field( + group, + "call_DP", + ndim=2, + shape=(variants, samples), + dimension_names=["variants", "samples"], + dtype=np.int32, + ) + check_field( + group, + "call_GQ", + ndim=2, + shape=(variants, samples), + dimension_names=["variants", "samples"], + dtype=np.int32, + ) + check_field( + group, + "call_HQ", + ndim=3, + shape=(variants, samples, 2), + dimension_names=["variants", "samples", "FORMAT_HQ_dim"], + dtype=np.int32, + ) + check_field( + group, + "call_genotype", + ndim=3, + shape=(variants, samples, ploidy), + dimension_names=["variants", "samples", "ploidy"], + dtype=np.int8, + ) + check_field( + group, + "call_genotype_phased", + ndim=2, + shape=(variants, samples), + dimension_names=["variants", "samples"], + dtype=bool, + ) + + # Sample information + check_field( + group, + "sample_id", + ndim=1, + shape=(samples,), + dimension_names=["samples"], + dtype=str, + ) + + # Array values + assert_array_equal(group["variant_contig"], [0, 0, 1, 1, 1, 1, 1, 1, 2]) + assert_array_equal( + group["variant_position"], + [111, 112, 14370, 17330, 1110696, 1230237, 1234567, 1235237, 10], + ) + assert_array_equal( + group["variant_id"], + [".", ".", "rs6054257", ".", "rs6040355", ".", "microsat1", ".", "rsTest"], + ) + assert_array_equal( + group["variant_allele"], + [ + ["A", "C", "", ""], + ["A", "G", "", ""], + ["G", "A", "", ""], + ["T", "A", "", ""], + ["A", "G", "T", ""], + ["T", "", "", ""], + ["G", "GA", "GAC", ""], + ["T", "", "", ""], + ["AC", "A", "ATG", "C"], + ], + ) + assert_allclose( + group["variant_quality"], [9.6, 10.0, 29.0, 3.0, 67.0, 47.0, 50.0, np.nan, 10.0] + ) + assert ( + group["variant_quality"][:].view(np.int32)[7] + == np.array([0x7F800001], dtype=np.int32).item() + ) # missing nan + assert_array_equal( + group["variant_filter"], + [ + [False, False, False], + [False, False, False], + [True, False, False], + [False, True, True], + [True, False, False], + [True, False, False], + [True, False, False], + [False, False, False], + [True, False, False], + ], + ) + + assert_array_equal( + group["variant_NS"], + [INT_FILL, INT_FILL, 3, 3, 2, 3, 3, INT_FILL, INT_FILL], + ) + + assert_array_equal( + group["call_DP"], + [ + [INT_FILL, INT_FILL, INT_FILL], + [INT_FILL, INT_FILL, INT_FILL], + [1, 8, 5], + [3, 5, 3], + [6, 0, 4], + [INT_MISSING, 4, 2], + [4, 2, 3], + [INT_FILL, INT_FILL, INT_FILL], + [INT_FILL, INT_FILL, INT_FILL], + ], + ) + assert_array_equal( + group["call_genotype"], + [ + [[0, 0], [0, 0], [0, 1]], + [[0, 0], [0, 0], [0, 1]], + [[0, 0], [1, 0], [1, 1]], + [[0, 0], [0, 1], [0, 0]], + [[1, 2], [2, 1], [2, 2]], + [[0, 0], [0, 0], [0, 0]], + [[0, 1], [0, 2], [-1, -1]], + [[0, 0], [0, 0], [-1, -1]], + [[0, -2], [0, 1], [0, 2]], + ], + ) + assert_array_equal( + 
group["call_genotype_phased"], + [ + [True, True, False], + [True, True, False], + [True, True, False], + [True, True, False], + [True, True, False], + [True, True, False], + [False, False, False], + [False, True, False], + [True, False, True], + ], + ) + + assert_array_equal(group["sample_id"], ["NA00001", "NA00002", "NA00003"]) + + +@pytest.mark.parametrize( + "retain_temp_files", + [True, False], +) +def test_vcf_to_zarr__retain_files(shared_datadir, tmp_path, retain_temp_files): + path = path_for_test(shared_datadir, "sample.vcf.gz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + temp_path = tmp_path.joinpath("temp").as_posix() + + vcf_to_zarr( + path, + output, + chunk_length=5, + chunk_width=2, + tempdir=temp_path, + retain_temp_files=retain_temp_files, + target_part_size="500B", + ) + ds = xr.open_zarr(output) + assert_array_equal(ds["contig_id"], ["19", "20", "X"]) + assert (len(os.listdir(temp_path)) == 0) != retain_temp_files + + +def test_vcf_to_zarr__legacy_contig_and_filter_attrs(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "sample.vcf.gz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + + vcf_to_zarr(path, output, chunk_length=5, chunk_width=2) + ds = xr.open_zarr(output) + + # drop new contig_id and filter_id variables + ds = ds.drop_vars(["contig_id", "filter_id"]) + + # check that contigs and filters can still be retrieved (with a warning) + assert num_contigs(ds) == 3 + with pytest.warns(DeprecationWarning): + assert_array_equal(get_contigs(ds), np.array(["19", "20", "X"], dtype="S")) + with pytest.warns(DeprecationWarning): + assert_array_equal(get_filters(ds), np.array(["PASS", "s50", "q10"], dtype="S")) + + +def test_vcf_to_zarr__no_samples(shared_datadir, tmp_path): + path = path_for_test(shared_datadir, "no_samples.vcf.gz") + output = tmp_path.joinpath("vcf.zarr").as_posix() + vcf_to_zarr(path, output) + # Run with many parts to test concat_zarrs path also accepts no samples + vcf_to_zarr(path, output, target_part_size="1k") + ds = xr.open_zarr(output) + assert_array_equal(ds["sample_id"], []) + assert_array_equal(ds["contig_id"], ["1"]) + assert ds.sizes["variants"] == 973 + + +# TODO take out some of these, they take far too long +@pytest.mark.parametrize( + "vcf_name", + [ + "1000G.phase3.broad.withGenotypes.chr20.10100000.vcf.gz", + "CEUTrio.20.21.gatk3.4.csi.g.vcf.bgz", + "CEUTrio.20.21.gatk3.4.g.bcf", + "CEUTrio.20.21.gatk3.4.g.vcf.bgz", + "CEUTrio.20.gatk3.4.g.vcf.bgz", + "CEUTrio.21.gatk3.4.g.vcf.bgz", + "sample_multiple_filters.vcf.gz", + "sample.vcf.gz", + "allele_overflow.vcf.gz", + ], +) +def test_compare_vcf_to_zarr_convert(shared_datadir, tmp_path, vcf_name): + vcf_path = path_for_test(shared_datadir, vcf_name) + zarr1_path = tmp_path.joinpath("vcf1.zarr").as_posix() + zarr2_path = tmp_path.joinpath("vcf2.zarr").as_posix() + + # Convert gets the actual number of alleles by default, so use this as the + # input for + convert_vcf([vcf_path], zarr2_path) + ds2 = load_dataset(zarr2_path) + vcf_to_zarr( + vcf_path, + zarr1_path, + mixed_ploidy=True, + max_alt_alleles=ds2.variant_allele.shape[1] - 1, + ) + ds1 = load_dataset(zarr1_path) + + # convert reads all variables by default. 
+    base_vars = list(ds1)
+    ds2 = load_dataset(zarr2_path)
+    # print(ds1.call_genotype.values)
+    # print(ds2.call_genotype.values)
+    xr.testing.assert_equal(ds1, ds2[base_vars])
+
+
+@pytest.mark.parametrize(
+    "vcf_name",
+    [
+        "1000G.phase3.broad.withGenotypes.chr20.10100000.vcf.gz",
+        "CEUTrio.20.21.gatk3.4.csi.g.vcf.bgz",
+        "CEUTrio.20.21.gatk3.4.g.bcf",
+        "CEUTrio.20.21.gatk3.4.g.vcf.bgz",
+        "CEUTrio.20.gatk3.4.g.vcf.bgz",
+        "CEUTrio.21.gatk3.4.g.vcf.bgz",
+        "sample_multiple_filters.vcf.gz",
+        "sample.vcf.gz",
+        "allele_overflow.vcf.gz",
+    ],
+)
+def test_validate_vcf(shared_datadir, tmp_path, vcf_name):
+    vcf_path = path_for_test(shared_datadir, vcf_name)
+    zarr_path = os.path.join("tmp/converted", vcf_name + ".vcf.zarr")
+    # zarr_path = tmp_path.joinpath("vcf.zarr").as_posix()
+    print("converting", zarr_path)
+    convert_vcf([vcf_path], zarr_path)
+    # validate([vcf_path], zarr_path)
diff --git a/vcf2zarr.py b/vcf2zarr.py
new file mode 100644
index 0000000..8ecea04
--- /dev/null
+++ b/vcf2zarr.py
@@ -0,0 +1,128 @@
+import json
+
+import click
+import yaml
+import tabulate
+
+import sgkit.io.vcf.vcf_converter as cnv
+
+# from sgkit import load_dataset
+
+
+@click.command
+@click.argument("vcfs", nargs=-1, required=True)
+def scan(vcfs):
+    progress = False
+    spec = cnv.scan_vcfs(vcfs, show_progress=progress)
+    spec = spec.vcf_metadata
+    converted = yaml.dump(spec.asdict())
+    # converted = json.dumps(spec.asdict(), indent=4)
+
+    print(converted)
+    # spec2 = cnv.VcfMetadata.fromdict(yaml.load(converted))
+    # print(spec2)
+
+
+@click.command
+@click.argument("vcfs", nargs=-1, required=True)
+@click.argument("out_path", type=click.Path())
+@click.option("-p", "--worker-processes", type=int, default=1)
+@click.option("-c", "--column-chunk-size", type=int, default=64)
+def explode(vcfs, out_path, worker_processes, column_chunk_size):
+    cnv.explode(
+        vcfs,
+        out_path,
+        worker_processes=worker_processes,
+        column_chunk_size=column_chunk_size,
+        show_progress=True,
+    )
+
+
+@click.command
+@click.argument("columnarised", type=click.Path())
+def summarise(columnarised):
+    pcvcf = cnv.PickleChunkedVcf.load(columnarised)
+    data = pcvcf.summary_table()
+    print(tabulate.tabulate(data, headers="keys"))
+
+
+@click.command
+@click.argument("columnarised", type=click.Path())
+# @click.argument("specfile", type=click.Path())
+def genspec(columnarised):
+    pcvcf = cnv.PickleChunkedVcf.load(columnarised)
+    spec = cnv.ZarrConversionSpec.generate(pcvcf)
+    # with open(specfile, "w") as f:
+    stream = click.get_text_stream("stdout")
+    json.dump(spec.asdict(), stream, indent=4)
+
+
+@click.command
+@click.argument("columnarised", type=click.Path())
+@click.argument("zarr_path", type=click.Path())
+@click.option("-s", "--conversion-spec", default=None)
+@click.option("-p", "--worker-processes", type=int, default=1)
+def to_zarr(columnarised, zarr_path, conversion_spec, worker_processes):
+    pcvcf = cnv.PickleChunkedVcf.load(columnarised)
+    if conversion_spec is None:
+        spec = cnv.ZarrConversionSpec.generate(pcvcf)
+    else:
+        with open(conversion_spec, "r") as f:
+            d = json.load(f)
+        spec = cnv.ZarrConversionSpec.fromdict(d)
+
+    cnv.SgvcfZarr.convert(
+        pcvcf,
+        zarr_path,
+        conversion_spec=spec,
+        worker_processes=worker_processes,
+        show_progress=True,
+    )
+
+
+@click.command
+@click.argument("vcfs", nargs=-1, required=True)
+@click.argument("out_path", type=click.Path())
+@click.option("-p", "--worker-processes", type=int, default=1)
+def convert(vcfs, out_path, worker_processes):
+    cnv.convert_vcf(vcfs, 
out_path, show_progress=True, worker_processes=worker_processes) + +@click.command +@click.argument("vcfs", nargs=-1, required=True) +@click.argument("out_path", type=click.Path()) +def validate(vcfs, out_path): + cnv.validate(vcfs[0], out_path, show_progress=True) + + +@click.command +@click.argument("plink", type=click.Path()) +@click.argument("out_path", type=click.Path()) +@click.option("-p", "--worker-processes", type=int, default=1) +@click.option("--chunk-width", type=int, default=None) +@click.option("--chunk-length", type=int, default=None) +def convert_plink(plink, out_path, worker_processes, chunk_width, chunk_length): + cnv.convert_plink( + plink, + out_path, + show_progress=True, + worker_processes=worker_processes, + chunk_width=chunk_width, + chunk_length=chunk_length, + ) + + +@click.group() +def cli(): + pass + + +cli.add_command(explode) +cli.add_command(summarise) +cli.add_command(genspec) +cli.add_command(to_zarr) +cli.add_command(convert) +cli.add_command(validate) +cli.add_command(convert_plink) + +if __name__ == "__main__": + cli() From e044702dc505593339131b7bb5bfda40df34039d Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 15 Feb 2024 16:40:24 +0000 Subject: [PATCH 3/5] Minimal changes to make conversion run --- bio2zarr/vcf.py | 27 +++++++++------------------ vcf2zarr.py | 23 +++++------------------ 2 files changed, 14 insertions(+), 36 deletions(-) diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py index 906a4c3..8a7f3a7 100644 --- a/bio2zarr/vcf.py +++ b/bio2zarr/vcf.py @@ -23,26 +23,17 @@ import bed_reader +INT_MISSING = -1 +INT_FILL = -2 +STR_MISSING = "." +STR_FILL = "" -# from sgkit.io.utils import FLOAT32_MISSING, str_is_int -from sgkit.io.utils import ( - # CHAR_FILL, - # CHAR_MISSING, - FLOAT32_FILL, - FLOAT32_MISSING, - FLOAT32_FILL_AS_INT32, - FLOAT32_MISSING_AS_INT32, - INT_FILL, - INT_MISSING, - # STR_FILL, - # STR_MISSING, - # str_is_int, +FLOAT32_MISSING, FLOAT32_FILL = np.array([0x7F800001, 0x7F800002], dtype=np.int32).view( + np.float32 +) +FLOAT32_MISSING_AS_INT32, FLOAT32_FILL_AS_INT32 = np.array( + [0x7F800001, 0x7F800002], dtype=np.int32 ) - -# from sgkit.io.vcf import partition_into_regions - -# from sgkit.io.utils import INT_FILL, concatenate_and_rechunk, str_is_int -# from sgkit.utils import smallest_numpy_int_dtype numcodecs.blosc.use_threads = False diff --git a/vcf2zarr.py b/vcf2zarr.py index 8ecea04..145badf 100644 --- a/vcf2zarr.py +++ b/vcf2zarr.py @@ -4,23 +4,7 @@ import yaml import tabulate -import sgkit.io.vcf.vcf_converter as cnv - -# from sgkit import load_dataset - - -@click.command -@click.argument("vcfs", nargs=-1, required=True) -def scan(vcfs): - progress = False - spec = cnv.scan_vcfs(vcfs, show_progress=progress) - spec = spec.vcf_metadata - converted = yaml.dump(spec.asdict()) - # converted = json.dumps(spec.asdict(), indent=4) - - print(converted) - # spec2 = cnv.VcfMetadata.fromdict(yaml.load(converted)) - # print(spec2) +import bio2zarr.vcf as cnv # fixme @click.command @@ -85,7 +69,10 @@ def to_zarr(columnarised, zarr_path, conversion_spec, worker_processes): @click.argument("out_path", type=click.Path()) @click.option("-p", "--worker-processes", type=int, default=1) def convert(vcfs, out_path, worker_processes): - cnv.convert_vcf(vcfs, out_path, show_progress=True, worker_processes=worker_processes) + cnv.convert_vcf( + vcfs, out_path, show_progress=True, worker_processes=worker_processes + ) + @click.command @click.argument("vcfs", nargs=-1, required=True) From 8eefcf097955d08beb9695a34a4f3472be45c3f7 Mon Sep 
17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 15 Feb 2024 17:16:54 +0000 Subject: [PATCH 4/5] Basic tests working --- bio2zarr/vcf.py | 2 +- tests/test_vcf.py | 2029 ++++----------------------------------------- 2 files changed, 180 insertions(+), 1851 deletions(-) diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py index 8a7f3a7..8194877 100644 --- a/bio2zarr/vcf.py +++ b/bio2zarr/vcf.py @@ -74,7 +74,7 @@ def assert_prefix_integer_equal_2d(vcf_val, zarr_val): # Will need to hand-craft from examples to test def assert_prefix_float_equal_1d(vcf_val, zarr_val): v = np.array(vcf_val, dtype=np.float32, ndmin=1) - vi = v.view(np.int32) + # vi = v.view(np.int32) z = np.array(zarr_val, dtype=np.float32, ndmin=1) zi = z.view(np.int32) assert np.sum(zi == FLOAT32_MISSING_AS_INT32) == 0 diff --git a/tests/test_vcf.py b/tests/test_vcf.py index 4e14405..d2964c8 100644 --- a/tests/test_vcf.py +++ b/tests/test_vcf.py @@ -1,1891 +1,220 @@ -import os -import tempfile -from os import listdir -from os.path import join -from typing import MutableMapping - import numpy as np +import numpy.testing as nt import pytest -import xarray as xr -import zarr -from numcodecs import Blosc, Delta, FixedScaleOffset, PackBits, VLenUTF8 -from numpy.testing import assert_allclose, assert_array_equal, assert_array_almost_equal - -from sgkit import load_dataset, save_dataset -from sgkit.io.utils import FLOAT32_FILL, FLOAT32_MISSING, INT_FILL, INT_MISSING -from sgkit.io.vcf import ( - MaxAltAllelesExceededWarning, - partition_into_regions, - read_vcf, - vcf_to_zarr, -) -from sgkit.io.vcf.vcf_reader import ( - FloatFormatFieldWarning, - merge_zarr_array_sizes, - zarr_array_sizes, -) -from sgkit.io.vcf.vcf_converter import convert_vcf, validate -from sgkit.model import get_contigs, get_filters, num_contigs -from sgkit.tests.io.test_dataset import assert_identical +import sgkit as sg -from .utils import path_for_test +from bio2zarr import vcf -@pytest.mark.parametrize( - "read_chunk_length", - [None, 1], -) -@pytest.mark.parametrize( - "is_path", - [True, False], -) -@pytest.mark.parametrize("method", ["to_zarr", "convert", "load"]) -@pytest.mark.filterwarnings("ignore::xarray.coding.variables.SerializationWarning") -def test_vcf_to_zarr__small_vcf( - shared_datadir, - is_path, - read_chunk_length, - tmp_path, - method, -): - path = path_for_test(shared_datadir, "sample.vcf.gz", is_path) - output = tmp_path.joinpath("vcf.zarr").as_posix() - fields = [ - "INFO/NS", - "INFO/AN", - "INFO/AA", - "INFO/DB", - "INFO/AC", - "INFO/AF", - "FORMAT/GT", - "FORMAT/DP", - "FORMAT/HQ", - ] - field_defs = { - "FORMAT/HQ": {"dimension": "ploidy"}, - "INFO/AF": {"Number": "2", "dimension": "AF"}, - "INFO/AC": {"Number": "2", "dimension": "AC"}, - } - if method == "to_zarr": - vcf_to_zarr( - path, - output, - max_alt_alleles=3, - chunk_length=5, - chunk_width=2, - read_chunk_length=read_chunk_length, - fields=fields, - field_defs=field_defs, - ) - ds = xr.open_zarr(output) +class TestSmallExampleValues: + @pytest.fixture(scope="class") + def ds(self, tmp_path_factory): + path = "tests/data/vcf/sample.vcf.gz" + out = tmp_path_factory.mktemp("data") / "example.vcf.zarr" + vcf.convert_vcf([path], out) + return sg.load_dataset(out) - elif method == "convert": - convert_vcf( - [path], - output, - chunk_length=5, - chunk_width=2, - ) - ds = xr.open_zarr(output) - else: - ds = read_vcf( - path, chunk_length=5, chunk_width=2, fields=fields, field_defs=field_defs + def test_filters(self, ds): + nt.assert_array_equal(ds["filter_id"], ["PASS", "s50", "q10"]) + 
nt.assert_array_equal( + ds["variant_filter"], + [ + [False, False, False], + [False, False, False], + [True, False, False], + [False, False, True], + [True, False, False], + [True, False, False], + [True, False, False], + [False, False, False], + [True, False, False], + ], ) - assert_array_equal(ds["filter_id"], ["PASS", "s50", "q10"]) - assert_array_equal( - ds["variant_filter"], - [ - [False, False, False], - [False, False, False], - [True, False, False], - [False, False, True], - [True, False, False], - [True, False, False], - [True, False, False], - [False, False, False], - [True, False, False], - ], - ) - assert_array_equal(ds["contig_id"], ["19", "20", "X"]) - assert "contig_length" not in ds - assert_array_equal(ds["variant_contig"], [0, 0, 1, 1, 1, 1, 1, 1, 2]) - assert ds["variant_contig"].chunks[0][0] == 5 - - assert_array_equal( - ds["variant_position"], - [111, 112, 14370, 17330, 1110696, 1230237, 1234567, 1235237, 10], - ) - assert ds["variant_position"].chunks[0][0] == 5 - - assert_array_equal( - ds["variant_NS"], - [-1, -1, 3, 3, 2, 3, 3, -1, -1], - ) - assert ds["variant_NS"].chunks[0][0] == 5 - - assert_array_equal( - ds["variant_AN"], - [-1, -1, -1, -1, -1, -1, 6, -1, -1], - ) - assert ds["variant_AN"].chunks[0][0] == 5 - - assert_array_equal( - ds["variant_AA"], - [ - ".", - ".", - ".", - ".", - "T", - "T", - "G", - ".", - ".", - ], - ) - assert ds["variant_AN"].chunks[0][0] == 5 - - assert_array_equal( - ds["variant_DB"], - [ - False, - False, - True, - False, - True, - False, - False, - False, - False, - ], - ) - assert ds["variant_AN"].chunks[0][0] == 5 - - variant_AF = np.full((9, 2), FLOAT32_MISSING, dtype=np.float32) - variant_AF[2, 0] = 0.5 - variant_AF[3, 0] = 0.017 - variant_AF[4, 0] = 0.333 - variant_AF[4, 1] = 0.667 - assert_array_almost_equal(ds["variant_AF"], variant_AF, 3) - assert ds["variant_AF"].chunks[0][0] == 5 - - assert_array_equal( - ds["variant_AC"], - [ - [-1, -1], - [-1, -1], - [-1, -1], - [-1, -1], - [-1, -1], - [-1, -1], - [3, 1], - [-1, -1], - [-1, -1], - ], - ) - assert ds["variant_AC"].chunks[0][0] == 5 - - assert_array_equal( - ds["variant_allele"].values.tolist(), - [ - ["A", "C", "", ""], - ["A", "G", "", ""], - ["G", "A", "", ""], - ["T", "A", "", ""], - ["A", "G", "T", ""], - ["T", "", "", ""], - ["G", "GA", "GAC", ""], - ["T", "", "", ""], - ["AC", "A", "ATG", "C"], - ], - ) - assert ds["variant_allele"].chunks[0][0] == 5 - assert ds["variant_allele"].dtype == "O" - assert_array_equal( - ds["variant_id"].values.tolist(), - [".", ".", "rs6054257", ".", "rs6040355", ".", "microsat1", ".", "rsTest"], - ) - assert ds["variant_id"].chunks[0][0] == 5 - assert ds["variant_id"].dtype == "O" - assert_array_equal( - ds["variant_id_mask"], - [True, True, False, True, False, True, False, True, False], - ) - assert ds["variant_id_mask"].chunks[0][0] == 5 - - assert_array_equal(ds["sample_id"], ["NA00001", "NA00002", "NA00003"]) - assert ds["sample_id"].chunks[0][0] == 2 - - call_genotype = np.array( - [ - [[0, 0], [0, 0], [0, 1]], - [[0, 0], [0, 0], [0, 1]], - [[0, 0], [1, 0], [1, 1]], - [[0, 0], [0, 1], [0, 0]], - [[1, 2], [2, 1], [2, 2]], - [[0, 0], [0, 0], [0, 0]], - [[0, 1], [0, 2], [-1, -1]], - [[0, 0], [0, 0], [-1, -1]], - # NOTE: inconsistency here on pad vs missing. I think this is a - # pad value. 
- [[0, -2], [0, 1], [0, 2]], - ], - dtype="i1", - ) - call_genotype_phased = np.array( - [ - [True, True, False], - [True, True, False], - [True, True, False], - [True, True, False], - [True, True, False], - [True, True, False], - [False, False, False], - [False, True, False], - [True, False, True], - ], - dtype=bool, - ) - call_DP = [ - [-1, -1, -1], - [-1, -1, -1], - [1, 8, 5], - [3, 5, 3], - [6, 0, 4], - [-1, 4, 2], - [4, 2, 3], - [-1, -1, -1], - [-1, -1, -1], - ] - call_HQ = [ - [[10, 15], [10, 10], [3, 3]], - [[10, 10], [10, 10], [3, 3]], - [[51, 51], [51, 51], [-1, -1]], - [[58, 50], [65, 3], [-1, -1]], - [[23, 27], [18, 2], [-1, -1]], - [[56, 60], [51, 51], [-1, -1]], - [[-1, -1], [-1, -1], [-1, -1]], - [[-1, -1], [-1, -1], [-1, -1]], - [[-1, -1], [-1, -1], [-1, -1]], - ] + def test_contigs(self, ds): + nt.assert_array_equal(ds["contig_id"], ["19", "20", "X"]) + assert "contig_length" not in ds + nt.assert_array_equal(ds["variant_contig"], [0, 0, 1, 1, 1, 1, 1, 1, 2]) - # print(np.array2string(ds["call_HQ"].values, separator=",")) - # print(np.array2string(ds["call_genotype"].values < 0, separator=",")) - - assert_array_equal(ds["call_genotype"], call_genotype) - assert_array_equal(ds["call_genotype_mask"], call_genotype < 0) - assert_array_equal(ds["call_genotype_phased"], call_genotype_phased) - assert_array_equal(ds["call_DP"], call_DP) - assert_array_equal(ds["call_HQ"], call_HQ) - - for name in ["call_genotype", "call_genotype_mask", "call_HQ"]: - assert ds[name].chunks == ((5, 4), (2, 1), (2,)) - - for name in ["call_genotype_phased", "call_DP"]: - assert ds[name].chunks == ((5, 4), (2, 1)) - - # save and load again to test https://github.com/pydata/xarray/issues/3476 - path2 = tmp_path / "ds2.zarr" - if not is_path: - path2 = str(path2) - save_dataset(ds, path2) - assert_identical(ds, load_dataset(path2)) - - -@pytest.mark.parametrize( - "is_path", - [True, False], -) -def test_vcf_to_zarr__max_alt_alleles(shared_datadir, is_path, tmp_path): - path = path_for_test(shared_datadir, "sample.vcf.gz", is_path) - output = tmp_path.joinpath("vcf.zarr").as_posix() + def test_position(self, ds): + nt.assert_array_equal( + ds["variant_position"], + [111, 112, 14370, 17330, 1110696, 1230237, 1234567, 1235237, 10], + ) - with pytest.warns(MaxAltAllelesExceededWarning): - max_alt_alleles = 1 - vcf_to_zarr( - path, output, chunk_length=5, chunk_width=2, max_alt_alleles=max_alt_alleles + def test_int_info_fields(self, ds): + nt.assert_array_equal( + ds["variant_NS"], + [-1, -1, 3, 3, 2, 3, 3, -1, -1], + ) + nt.assert_array_equal( + ds["variant_AN"], + [-1, -1, -1, -1, -1, -1, 6, -1, -1], ) - ds = xr.open_zarr(output) - # extra alt alleles are dropped - assert_array_equal( - ds["variant_allele"].values.tolist(), + nt.assert_array_equal( + ds["variant_AC"], [ - ["A", "C"], - ["A", "G"], - ["G", "A"], - ["T", "A"], - ["A", "G"], - ["T", ""], - ["G", "GA"], - ["T", ""], - ["AC", "A"], + [-1, -1], + [-1, -1], + [-1, -1], + [-1, -1], + [-1, -1], + [-1, -1], + [3, 1], + [-1, -1], + [-1, -1], ], ) - # genotype calls are truncated - assert np.all(ds["call_genotype"].values <= max_alt_alleles) - - # the maximum number of alt alleles actually seen is stored as an attribute - assert ds.attrs["max_alt_alleles_seen"] == 3 - - -@pytest.mark.parametrize( - "read_chunk_length", - [None, 1_000], -) -@pytest.mark.parametrize( - "is_path", - [True, False], -) -@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") -def test_vcf_to_zarr__large_vcf(shared_datadir, is_path, 
read_chunk_length, tmp_path): - path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) - output = tmp_path.joinpath("vcf.zarr").as_posix() - - vcf_to_zarr(path, output, chunk_length=5_000, read_chunk_length=read_chunk_length) - ds = xr.open_zarr(output) - - assert_array_equal(ds["contig_id"], ["20", "21"]) - assert_array_equal(ds["contig_length"], [63025520, 48129895]) - assert ds["sample_id"].shape == (1,) - assert ds["call_genotype"].shape == (19910, 1, 2) - assert ds["call_genotype_mask"].shape == (19910, 1, 2) - assert ds["call_genotype_phased"].shape == (19910, 1) - assert ds["variant_allele"].shape == (19910, 4) - assert ds["variant_contig"].shape == (19910,) - assert ds["variant_id"].shape == (19910,) - assert ds["variant_id_mask"].shape == (19910,) - assert ds["variant_position"].shape == (19910,) - - assert ds["variant_allele"].dtype == "O" - assert ds["variant_id"].dtype == "O" - - # check underlying zarr chunk size is 1 in samples dim - za = zarr.open(output) - assert za["sample_id"].chunks == (1,) - assert za["call_genotype"].chunks == (5000, 1, 2) - - -def test_vcf_to_zarr__plain_vcf_with_no_index(shared_datadir, tmp_path): - path = path_for_test( - shared_datadir, - "mixed.vcf", - ) - output = tmp_path.joinpath("vcf.zarr").as_posix() - - vcf_to_zarr(path, output, truncate_calls=True) - ds = xr.open_zarr(output) - assert ds["sample_id"].shape == (3,) - - -@pytest.mark.parametrize( - "is_path", - [True, False], -) -@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") -def test_vcf_to_zarr__mutable_mapping(shared_datadir, is_path): - path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) - output: MutableMapping[str, bytes] = {} - - vcf_to_zarr(path, output, chunk_length=5_000) - ds = xr.open_zarr(output) - - assert ds["sample_id"].shape == (1,) - assert ds["call_genotype"].shape == (19910, 1, 2) - assert ds["call_genotype_mask"].shape == (19910, 1, 2) - assert ds["call_genotype_phased"].shape == (19910, 1) - assert ds["variant_allele"].shape == (19910, 4) - assert ds["variant_contig"].shape == (19910,) - assert ds["variant_id"].shape == (19910,) - assert ds["variant_id_mask"].shape == (19910,) - assert ds["variant_position"].shape == (19910,) - - assert ds["variant_allele"].dtype == "O" - assert ds["variant_id"].dtype == "O" - - -@pytest.mark.parametrize( - "is_path", - [True, False], -) -def test_vcf_to_zarr__compressor_and_filters(shared_datadir, is_path, tmp_path): - path = path_for_test(shared_datadir, "sample.vcf.gz", is_path) - output = tmp_path.joinpath("vcf.zarr").as_posix() - - compressor = Blosc("zlib", 1, Blosc.NOSHUFFLE) - variant_id_compressor = Blosc("zlib", 2, Blosc.NOSHUFFLE) - encoding = dict( - variant_id=dict(compressor=variant_id_compressor), - variant_id_mask=dict(filters=None), - ) - vcf_to_zarr( - path, - output, - chunk_length=5, - chunk_width=2, - compressor=compressor, - encoding=encoding, - ) - - # look at actual Zarr store to check compressor and filters - z = zarr.open(output) - assert z["call_genotype"].compressor == compressor - assert z["call_genotype"].filters is None # sgkit default - assert z["call_genotype"].chunks == (5, 2, 2) - assert z["call_genotype_mask"].compressor == compressor - assert z["call_genotype_mask"].filters == [PackBits()] # sgkit default - assert z["call_genotype_mask"].chunks == (5, 2, 2) - - assert z["variant_id"].compressor == variant_id_compressor - assert z["variant_id"].filters == [VLenUTF8()] # sgkit default - assert 
z["variant_id"].chunks == (5,) - assert z["variant_id_mask"].compressor == compressor - assert z["variant_id_mask"].filters is None - assert z["variant_id_mask"].chunks == (5,) - - assert z["variant_position"].filters == [ - Delta(dtype="i4", astype="i4") - ] # sgkit default - - -@pytest.mark.parametrize( - "is_path", - [True, False], -) -@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") -def test_vcf_to_zarr__parallel_compressor_and_filters( - shared_datadir, is_path, tmp_path -): - path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) - output = tmp_path.joinpath("vcf_concat.zarr").as_posix() - regions = ["20", "21"] - - compressor = Blosc("zlib", 1, Blosc.NOSHUFFLE) - variant_id_compressor = Blosc("zlib", 2, Blosc.NOSHUFFLE) - encoding = dict( - variant_id=dict(compressor=variant_id_compressor), - variant_id_mask=dict(filters=None), - ) - vcf_to_zarr( - path, - output, - regions=regions, - chunk_length=5_000, - compressor=compressor, - encoding=encoding, - ) - - # look at actual Zarr store to check compressor and filters - z = zarr.open(output) - assert z["call_genotype"].compressor == compressor - assert z["call_genotype"].filters is None # sgkit default - assert z["call_genotype"].chunks == (5000, 1, 2) - assert z["call_genotype_mask"].compressor == compressor - assert z["call_genotype_mask"].filters == [PackBits()] # sgkit default - assert z["call_genotype_mask"].chunks == (5000, 1, 2) - - assert z["variant_id"].compressor == variant_id_compressor - assert z["variant_id"].filters == [VLenUTF8()] # sgkit default - assert z["variant_id"].chunks == (5000,) - assert z["variant_id_mask"].compressor == compressor - assert z["variant_id_mask"].filters is None - assert z["variant_id_mask"].chunks == (5000,) - - assert z["variant_position"].filters == [ - Delta(dtype="i4", astype="i4") - ] # sgkit default - - -def test_vcf_to_zarr__float_format_field_warning(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "simple.output.mixed_depth.likelihoods.vcf") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - with pytest.warns(FloatFormatFieldWarning): - vcf_to_zarr( - path, - output, - ploidy=4, - max_alt_alleles=3, - fields=["FORMAT/GL"], + def test_float_info_fields(self, ds): + missing = vcf.FLOAT32_MISSING + fill = vcf.FLOAT32_FILL + variant_AF = np.array( + [ + [missing, missing], + [missing, missing], + [0.5, fill], + [0.017, fill], + [0.333, 0.667], + [missing, missing], + [missing, missing], + [missing, missing], + [missing, missing], + ], + dtype=np.float32, ) - - -@pytest.mark.parametrize( - "is_path", - [True, False], -) -@pytest.mark.parametrize( - "output_is_path", - [True, False], -) -@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") -def test_vcf_to_zarr__parallel(shared_datadir, is_path, output_is_path, tmp_path): - path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) - output = tmp_path.joinpath("vcf_concat.zarr") - if not output_is_path: - output = output.as_posix() - - regions = ["20", "21"] - - vcf_to_zarr( - path, - output, - regions=regions, - chunk_length=5_000, - ) - ds = xr.open_zarr(output) - - assert ds["sample_id"].shape == (1,) - assert ds["call_genotype"].shape == (19910, 1, 2) - assert ds["call_genotype_mask"].shape == (19910, 1, 2) - assert ds["call_genotype_phased"].shape == (19910, 1) - assert ds["variant_allele"].shape == (19910, 4) - assert ds["variant_contig"].shape == (19910,) - assert ds["variant_id"].shape == (19910,) - 
assert ds["variant_id_mask"].shape == (19910,) - assert ds["variant_position"].shape == (19910,) - - assert ds["variant_allele"].dtype == "O" - assert ds["variant_id"].dtype == "O" - - -@pytest.mark.parametrize( - "is_path", - [True, False], -) -@pytest.mark.filterwarnings("ignore::UserWarning") -def test_vcf_to_zarr__empty_region(shared_datadir, is_path, tmp_path): - path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) - output = tmp_path.joinpath("vcf_concat.zarr").as_posix() - regions = "23" - - vcf_to_zarr(path, output, regions=regions) - ds = xr.open_zarr(output) - - assert ds["sample_id"].shape == (1,) - assert ds["call_genotype"].shape == (0, 1, 2) - assert ds["call_genotype_mask"].shape == (0, 1, 2) - assert ds["call_genotype_phased"].shape == (0, 1) - assert ds["variant_allele"].shape == (0, 4) - assert ds["variant_contig"].shape == (0,) - assert ds["variant_id"].shape == (0,) - assert ds["variant_id_mask"].shape == (0,) - assert ds["variant_position"].shape == (0,) - - -@pytest.mark.parametrize( - "is_path", - [False], -) -@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") -def test_vcf_to_zarr__parallel_temp_chunk_length(shared_datadir, is_path, tmp_path): - path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) - output = tmp_path.joinpath("vcf_concat.zarr").as_posix() - regions = ["20", "21"] - - # Use a temp_chunk_length that is smaller than chunk_length - # Open the temporary parts to check that they have the right temp chunk length - with tempfile.TemporaryDirectory() as tempdir: - vcf_to_zarr( - path, - output, - regions=regions, - chunk_length=5_000, - temp_chunk_length=2_500, - tempdir=tempdir, - retain_temp_files=True, + values = ds["variant_AF"].values + nt.assert_array_almost_equal(values, variant_AF, 3) + nans = np.isnan(variant_AF) + nt.assert_array_equal( + variant_AF.view(np.int32)[nans], values.view(np.int32)[nans] ) - inner_temp_dir = join(tempdir, listdir(tempdir)[0]) - parts_dir = join(inner_temp_dir, listdir(inner_temp_dir)[0]) - part = xr.open_zarr(join(parts_dir, "part-0.zarr")) - assert part["call_genotype"].chunks[0][0] == 2_500 - assert part["variant_position"].chunks[0][0] == 2_500 - ds = xr.open_zarr(output) - - assert ds["sample_id"].shape == (1,) - assert ds["call_genotype"].shape == (19910, 1, 2) - assert ds["call_genotype"].chunks[0][0] == 5_000 - assert ds["call_genotype_mask"].shape == (19910, 1, 2) - assert ds["call_genotype_phased"].shape == (19910, 1) - assert ds["variant_allele"].shape == (19910, 4) - assert ds["variant_contig"].shape == (19910,) - assert ds["variant_id"].shape == (19910,) - assert ds["variant_id_mask"].shape == (19910,) - assert ds["variant_position"].shape == (19910,) - assert ds["variant_position"].chunks[0][0] == 5_000 - assert ds["variant_allele"].dtype == "O" - assert ds["variant_id"].dtype == "O" - - -def test_vcf_to_zarr__parallel_temp_chunk_length_not_divisible( - shared_datadir, tmp_path -): - path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", False) - output = tmp_path.joinpath("vcf_concat.zarr").as_posix() - regions = ["20", "21"] - - with pytest.raises( - ValueError, - match=r"Temporary chunk length in variant dimension \(4000\) must evenly divide target chunk length 5000", - ): - # Use a temp_chunk_length that does not divide into chunk_length - vcf_to_zarr( - path, output, regions=regions, chunk_length=5_000, temp_chunk_length=4_000 + def test_string_info_fields(self, ds): + nt.assert_array_equal( + 
ds["variant_AA"], + [ + ".", + ".", + ".", + ".", + "T", + "T", + "G", + ".", + ".", + ], ) - -@pytest.mark.parametrize( - "is_path", - [True, False], -) -def test_vcf_to_zarr__parallel_partitioned(shared_datadir, is_path, tmp_path): - path = path_for_test( - shared_datadir, - "1000G.phase3.broad.withGenotypes.chr20.10100000.vcf.gz", - is_path, - ) - output = tmp_path.joinpath("vcf_concat.zarr").as_posix() - - regions = partition_into_regions(path, num_parts=4) - - vcf_to_zarr(path, output, regions=regions, chunk_length=1_000, chunk_width=1_000) - ds = xr.open_zarr(output) - - assert ds["sample_id"].shape == (2535,) - assert ds["variant_id"].shape == (1406,) - - -@pytest.mark.parametrize( - "is_path", - [True, False], -) -def test_vcf_to_zarr__parallel_partitioned_by_size(shared_datadir, is_path, tmp_path): - path = path_for_test( - shared_datadir, - "1000G.phase3.broad.withGenotypes.chr20.10100000.vcf.gz", - is_path, - ) - output = tmp_path.joinpath("vcf_concat.zarr").as_posix() - - vcf_to_zarr( - path, output, target_part_size="4MB", chunk_length=1_000, chunk_width=1_000 - ) - ds = xr.open_zarr(output) - - assert ds["sample_id"].shape == (2535,) - assert ds["variant_id"].shape == (1406,) - - -@pytest.mark.parametrize( - "is_path", - [True, False], -) -@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") -def test_vcf_to_zarr__multiple(shared_datadir, is_path, tmp_path): - paths = [ - path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), - path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path), - ] - output = tmp_path.joinpath("vcf_concat.zarr").as_posix() - - vcf_to_zarr(paths, output, target_part_size=None, chunk_length=5_000) - ds = xr.open_zarr(output) - - assert ds["sample_id"].shape == (1,) - assert ds["call_genotype"].shape == (19910, 1, 2) - assert ds["call_genotype_mask"].shape == (19910, 1, 2) - assert ds["call_genotype_phased"].shape == (19910, 1) - assert ds["variant_allele"].shape == (19910, 4) - assert ds["variant_contig"].shape == (19910,) - assert ds["variant_id"].shape == (19910,) - assert ds["variant_id_mask"].shape == (19910,) - assert ds["variant_position"].shape == (19910,) - - assert ds.chunks["variants"] == (5000, 5000, 5000, 4910) - - -@pytest.mark.parametrize( - "is_path", - [True, False], -) -@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") -def test_vcf_to_zarr__multiple_partitioned(shared_datadir, is_path, tmp_path): - paths = [ - path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), - path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path), - ] - output = tmp_path.joinpath("vcf_concat.zarr").as_posix() - - regions = [partition_into_regions(path, num_parts=2) for path in paths] - - vcf_to_zarr(paths, output, regions=regions, chunk_length=5_000) - ds = xr.open_zarr(output) - - assert ds["sample_id"].shape == (1,) - assert ds["call_genotype"].shape == (19910, 1, 2) - assert ds["call_genotype_mask"].shape == (19910, 1, 2) - assert ds["call_genotype_phased"].shape == (19910, 1) - assert ds["variant_allele"].shape == (19910, 4) - assert ds["variant_contig"].shape == (19910,) - assert ds["variant_id"].shape == (19910,) - assert ds["variant_id_mask"].shape == (19910,) - assert ds["variant_position"].shape == (19910,) - - assert ds.chunks["variants"] == (5000, 5000, 5000, 4910) - - -@pytest.mark.parametrize( - "is_path", - [True, False], -) -@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") -def 
test_vcf_to_zarr__multiple_partitioned_by_size(shared_datadir, is_path, tmp_path): - paths = [ - path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), - path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path), - ] - output = tmp_path.joinpath("vcf_concat.zarr").as_posix() - - vcf_to_zarr(paths, output, target_part_size="40KB", chunk_length=5_000) - ds = xr.open_zarr(output) - - assert ds["sample_id"].shape == (1,) - assert ds["call_genotype"].shape == (19910, 1, 2) - assert ds["call_genotype_mask"].shape == (19910, 1, 2) - assert ds["call_genotype_phased"].shape == (19910, 1) - assert ds["variant_allele"].shape == (19910, 4) - assert ds["variant_contig"].shape == (19910,) - assert ds["variant_id"].shape == (19910,) - assert ds["variant_id_mask"].shape == (19910,) - assert ds["variant_position"].shape == (19910,) - - assert ds.chunks["variants"] == (5000, 5000, 5000, 4910) - - -@pytest.mark.parametrize( - "is_path", - [True, False], -) -def test_vcf_to_zarr__mutiple_partitioned_invalid_regions( - shared_datadir, is_path, tmp_path -): - paths = [ - path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), - path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path), - ] - output = tmp_path.joinpath("vcf_concat.zarr").as_posix() - - # invalid regions, should be a sequence of sequences - regions = partition_into_regions(paths[0], num_parts=2) - - with pytest.raises( - ValueError, - match=r"multiple input regions must be a sequence of sequence of strings", - ): - vcf_to_zarr(paths, output, regions=regions, chunk_length=5_000) - - -@pytest.mark.parametrize( - "is_path", - [True, False], -) -def test_vcf_to_zarr__multiple_max_alt_alleles(shared_datadir, is_path, tmp_path): - paths = [ - path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), - path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path), - ] - output = tmp_path.joinpath("vcf_concat.zarr").as_posix() - - with pytest.warns(MaxAltAllelesExceededWarning): - vcf_to_zarr( - paths, - output, - target_part_size="40KB", - chunk_length=5_000, - max_alt_alleles=1, + def test_flag_info_fields(self, ds): + nt.assert_array_equal( + ds["variant_DB"], + [ + False, + False, + True, + False, + True, + False, + False, + False, + False, + ], ) - ds = xr.open_zarr(output) - - # the maximum number of alt alleles actually seen is stored as an attribute - assert ds.attrs["max_alt_alleles_seen"] == 7 - -@pytest.mark.parametrize( - "max_alt_alleles,dtype,warning", - [ - (2, np.int8, True), - (127, np.int8, True), - (128, np.int16, True), - (145, np.int16, True), - (164, np.int16, False), - ], -) -def test_vcf_to_zarr__call_genotype_dtype( - shared_datadir, tmp_path, max_alt_alleles, dtype, warning -): - path = path_for_test(shared_datadir, "allele_overflow.vcf.gz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - if warning: - with pytest.warns(MaxAltAllelesExceededWarning): - vcf_to_zarr(path, output, max_alt_alleles=max_alt_alleles) - else: - vcf_to_zarr(path, output, max_alt_alleles=max_alt_alleles) - ds = load_dataset(output) - assert ds.call_genotype.dtype == dtype - assert ds.call_genotype.values.max() <= max_alt_alleles - - -@pytest.mark.parametrize( - "ploidy,mixed_ploidy,truncate_calls,regions", - [ - (2, False, True, None), - (4, False, False, None), - (4, False, False, ["CHR1:0-5", "CHR1:5-10"]), - (4, True, False, None), - (4, True, False, ["CHR1:0-5", "CHR1:5-10"]), - (5, True, False, None), - ], -) -def test_vcf_to_zarr__mixed_ploidy_vcf( - 
shared_datadir, tmp_path, ploidy, mixed_ploidy, truncate_calls, regions -): - path = path_for_test(shared_datadir, "mixed.vcf.gz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - vcf_to_zarr( - path, - output, - regions=regions, - chunk_length=5, - chunk_width=2, - ploidy=ploidy, - mixed_ploidy=mixed_ploidy, - truncate_calls=truncate_calls, - ) - ds = load_dataset(output) - - variant_dtype = "O" - assert_array_equal(ds["contig_id"], ["CHR1", "CHR2", "CHR3"]) - assert_array_equal(ds["variant_contig"], [0, 0]) - assert_array_equal(ds["variant_position"], [2, 7]) - assert_array_equal( - ds["variant_allele"].values.tolist(), - np.array( + def test_allele(self, ds): + fill = vcf.STR_FILL + nt.assert_array_equal( + ds["variant_allele"].values.tolist(), [ - ["A", "T", "", ""], - ["A", "C", "", ""], + ["A", "C", fill, fill], + ["A", "G", fill, fill], + ["G", "A", fill, fill], + ["T", "A", fill, fill], + ["A", "G", "T", fill], + ["T", fill, fill, fill], + ["G", "GA", "GAC", fill], + ["T", fill, fill, fill], + ["AC", "A", "ATG", "C"], ], - dtype=variant_dtype, - ), - ) - assert ds["variant_allele"].dtype == variant_dtype # type: ignore[comparison-overlap] - assert_array_equal( - ds["variant_id"], - np.array([".", "."], dtype=variant_dtype), - ) - assert ds["variant_id"].dtype == variant_dtype # type: ignore[comparison-overlap] - assert_array_equal( - ds["variant_id_mask"], - [True, True], - ) - assert_array_equal(ds["sample_id"], ["SAMPLE1", "SAMPLE2", "SAMPLE3"]) - - assert ds["call_genotype"].attrs["mixed_ploidy"] == mixed_ploidy - pad = -2 if mixed_ploidy else -1 # -2 indicates a fill (non-allele) value - call_genotype = np.array( - [ - [[0, 0, 1, 1, pad], [0, 0, pad, pad, pad], [0, 0, 0, 1, pad]], - [[0, 0, 1, 1, pad], [0, 1, pad, pad, pad], [0, 1, -1, -1, pad]], - ], - dtype="i1", - ) - # truncate row vectors if lower ploidy - call_genotype = call_genotype[:, :, 0:ploidy] - - assert_array_equal(ds["call_genotype"], call_genotype) - assert_array_equal(ds["call_genotype_mask"], call_genotype < 0) - if mixed_ploidy: - assert_array_equal(ds["call_genotype_fill"], call_genotype < -1) - - -@pytest.mark.parametrize( - "ploidy,mixed_ploidy,truncate_calls", - [ - (2, False, False), - (3, True, False), - ], -) -def test_vcf_to_zarr__mixed_ploidy_vcf_exception( - shared_datadir, tmp_path, ploidy, mixed_ploidy, truncate_calls -): - path = path_for_test(shared_datadir, "mixed.vcf.gz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - with pytest.raises(ValueError) as excinfo: - vcf_to_zarr( - path, - output, - ploidy=ploidy, - mixed_ploidy=mixed_ploidy, - truncate_calls=truncate_calls, ) - assert "Genotype call longer than ploidy." 
== str(excinfo.value) - - -def test_vcf_to_zarr__no_genotypes(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "no_genotypes.vcf") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - vcf_to_zarr(path, output) - - ds = xr.open_zarr(output) + assert ds["variant_allele"].dtype == "O" - assert "call_genotype" not in ds - assert "call_genotype_mask" not in ds - assert "call_genotype_phased" not in ds - - assert ds["sample_id"].shape == (0,) - assert ds["variant_allele"].shape == (26, 4) - assert ds["variant_contig"].shape == (26,) - assert ds["variant_id"].shape == (26,) - assert ds["variant_id_mask"].shape == (26,) - assert ds["variant_position"].shape == (26,) - - -def test_vcf_to_zarr__no_genotypes_with_gt_header(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "no_genotypes_with_gt_header.vcf") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - vcf_to_zarr(path, output) - - ds = xr.open_zarr(output) - - assert_array_equal(ds["call_genotype"], -1) - assert_array_equal(ds["call_genotype_mask"], 1) - assert_array_equal(ds["call_genotype_phased"], 0) - - assert ds["sample_id"].shape == (0,) - assert ds["variant_allele"].shape == (26, 4) - assert ds["variant_contig"].shape == (26,) - assert ds["variant_id"].shape == (26,) - assert ds["variant_id_mask"].shape == (26,) - assert ds["variant_position"].shape == (26,) - - -def test_vcf_to_zarr__contig_not_defined_in_header(shared_datadir, tmp_path): - # sample.vcf does not define the contigs in the header, and isn't indexed - path = path_for_test(shared_datadir, "sample.vcf") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - with pytest.raises( - ValueError, - match=r"Contig '19' is not defined in the header.", - ): - vcf_to_zarr(path, output) - - -def test_vcf_to_zarr__filter_not_defined_in_header(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "no_filter_defined.vcf") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - with pytest.raises( - ValueError, - match=r"Filter 'FAIL' is not defined in the header.", - ): - vcf_to_zarr(path, output) - - -def test_vcf_to_zarr__info_name_clash(shared_datadir, tmp_path): - # info_name_clash.vcf has an info field called 'id' which would be mapped to - # 'variant_id', clashing with the fixed field of the same name - path = path_for_test(shared_datadir, "info_name_clash.vcf") - output = tmp_path.joinpath("info_name_clash.zarr").as_posix() - - vcf_to_zarr(path, output) # OK if problematic field is ignored - - with pytest.raises( - ValueError, - match=r"Generated name for INFO field 'id' clashes with 'variant_id' from fixed VCF fields.", - ): - vcf_to_zarr(path, output, fields=["INFO/id"]) - - -def test_vcf_to_zarr__large_number_of_contigs(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "Homo_sapiens_assembly38.headerOnly.vcf.gz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - vcf_to_zarr(path, output) - - ds = xr.open_zarr(output) - - assert len(ds["contig_id"]) == 3366 - assert ds["variant_contig"].dtype == np.int16 # needs larger dtype than np.int8 - - -def test_vcf_to_zarr__fields(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "sample.vcf.gz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - vcf_to_zarr( - path, - output, - chunk_length=5, - chunk_width=2, - fields=["INFO/DP", "INFO/AA", "INFO/DB", "FORMAT/DP"], - ) - ds = xr.open_zarr(output) - - missing, fill = INT_MISSING, INT_FILL - assert_array_equal(ds["variant_DP"], [fill, fill, 14, 11, 10, 13, 9, fill, fill]) - assert 
ds["variant_DP"].attrs["comment"] == "Total Depth" - - assert_array_equal( - ds["variant_AA"], - np.array(["", "", "", "", "T", "T", "G", "", ""], dtype="O"), - ) - assert ds["variant_AA"].attrs["comment"] == "Ancestral Allele" - - assert_array_equal( - ds["variant_DB"], [False, False, True, False, True, False, False, False, False] - ) - assert ds["variant_DB"].attrs["comment"] == "dbSNP membership, build 129" - - dp = np.array( - [ - [fill, fill, fill], - [fill, fill, fill], - [1, 8, 5], - [3, 5, 3], - [6, 0, 4], - [missing, 4, 2], - [4, 2, 3], - [fill, fill, fill], - [fill, fill, fill], - ], - dtype="i4", - ) - assert_array_equal(ds["call_DP"], dp) - assert ds["call_DP"].attrs["comment"] == "Read Depth" - - -@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") -def test_vcf_to_zarr__parallel_with_fields(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - regions = ["20", "21"] - - vcf_to_zarr( - path, - output, - regions=regions, - chunk_length=5_000, - temp_chunk_length=2_500, - fields=["INFO/MQ", "FORMAT/PGT"], - ) - ds = xr.open_zarr(output) + def test_id(self, ds): + nt.assert_array_equal( + ds["variant_id"].values.tolist(), + [".", ".", "rs6054257", ".", "rs6040355", ".", "microsat1", ".", "rsTest"], + ) + assert ds["variant_id"].dtype == "O" + nt.assert_array_equal( + ds["variant_id_mask"], + [True, True, False, True, False, True, False, True, False], + ) - # select a small region to check - ds = ds.set_index(variants=("variant_contig", "variant_position")).sel( - variants=slice((0, 10001661), (0, 10001670)) - ) + def test_samples(self, ds): + nt.assert_array_equal(ds["sample_id"], ["NA00001", "NA00002", "NA00003"]) - # check strings have not been truncated after concat_zarrs - assert_array_equal( - ds["variant_allele"], - np.array( + def test_call_genotype(self, ds): + call_genotype = np.array( [ - ["T", "C", "", ""], - ["T", "", "", ""], - ["T", "G", "", ""], + [[0, 0], [0, 0], [0, 1]], + [[0, 0], [0, 0], [0, 1]], + [[0, 0], [1, 0], [1, 1]], + [[0, 0], [0, 1], [0, 0]], + [[1, 2], [2, 1], [2, 2]], + [[0, 0], [0, 0], [0, 0]], + [[0, 1], [0, 2], [-1, -1]], + [[0, 0], [0, 0], [-1, -1]], + # FIXME this depends on "mixed ploidy" interpretation. 
+ [[0, -2], [0, 1], [0, 2]], ], - dtype="O", - ), - ) - - # convert floats to ints to check nan type - fill = FLOAT32_FILL - assert_allclose( - ds["variant_MQ"].values.view("i4"), - np.array([58.33, fill, 57.45], dtype="f4").view("i4"), - ) - assert ds["variant_MQ"].attrs["comment"] == "RMS Mapping Quality" - - assert_array_equal(ds["call_PGT"], np.array([["0|1"], [""], ["0|1"]], dtype="O")) - assert ( - ds["call_PGT"].attrs["comment"] - == "Physical phasing haplotype information, describing how the alternate alleles are phased in relation to one another" - ) - - -def test_vcf_to_zarr__field_defs(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "sample.vcf.gz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - vcf_to_zarr( - path, - output, - fields=["INFO/DP"], - field_defs={"INFO/DP": {"Description": "Combined depth across samples"}}, - ) - ds = xr.open_zarr(output) - - fill = INT_FILL - assert_array_equal(ds["variant_DP"], [fill, fill, 14, 11, 10, 13, 9, fill, fill]) - assert ds["variant_DP"].attrs["comment"] == "Combined depth across samples" - - vcf_to_zarr( - path, - output, - fields=["INFO/DP"], - field_defs={"INFO/DP": {"Description": ""}}, # blank description - ) - ds = xr.open_zarr(output) - - assert_array_equal(ds["variant_DP"], [fill, fill, 14, 11, 10, 13, 9, fill, fill]) - assert "comment" not in ds["variant_DP"].attrs - - -@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") -def test_vcf_to_zarr__field_number_A(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "sample.vcf.gz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - vcf_to_zarr( - path, - output, - max_alt_alleles=2, - fields=["INFO/AC"], - field_defs={"INFO/AC": {"Number": "A"}}, - ) - ds = xr.open_zarr(output) - - fill = INT_FILL - assert_array_equal( - ds["variant_AC"], - [ - [fill, fill], - [fill, fill], - [fill, fill], - [fill, fill], - [fill, fill], - [fill, fill], - [3, 1], - [fill, fill], - [fill, fill], - ], - ) - assert ( - ds["variant_AC"].attrs["comment"] - == "Allele count in genotypes, for each ALT allele, in the same order as listed" - ) - - -@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") -def test_vcf_to_zarr__field_number_R(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - vcf_to_zarr( - path, - output, - fields=["FORMAT/AD"], - field_defs={"FORMAT/AD": {"Number": "R"}}, - ) - ds = xr.open_zarr(output) - - # select a small region to check - ds = ds.set_index(variants="variant_position").sel( - variants=slice(10002764, 10002793) - ) - - fill = INT_FILL - ad = np.array( - [ - [[40, 14, 0, fill]], - [[fill, fill, fill, fill]], - [[65, 8, 5, 0]], - [[fill, fill, fill, fill]], - ], - ) - assert_array_equal(ds["call_AD"], ad) - assert ( - ds["call_AD"].attrs["comment"] - == "Allelic depths for the ref and alt alleles in the order listed" - ) - - -@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") -def test_vcf_to_zarr__field_number_G(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - vcf_to_zarr(path, output, fields=["FORMAT/PL"]) - ds = xr.open_zarr(output) - - # select a small region to check - ds = ds.set_index(variants="variant_position").sel( - variants=slice(10002764, 10002793) - ) - - fill = INT_FILL - pl = np.array( - [ - [[319, 0, 1316, 440, 1358, 1798, 
fill, fill, fill, fill]], - [[0, 120, 1800, fill, fill, fill, fill, fill, fill, fill]], - [[8, 0, 1655, 103, 1743, 2955, 184, 1653, 1928, 1829]], - [[0, 0, 2225, fill, fill, fill, fill, fill, fill, fill]], - ], - ) - assert_array_equal(ds["call_PL"], pl) - assert ( - ds["call_PL"].attrs["comment"] - == "Normalized, Phred-scaled likelihoods for genotypes as defined in the VCF specification" - ) - - -def test_vcf_to_zarr__field_number_G_non_diploid(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "simple.output.mixed_depth.likelihoods.vcf") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - # store GL field as 2dp - encoding = { - "call_GL": { - "filters": [FixedScaleOffset(offset=0, scale=100, dtype="f4", astype="u1")] - } - } - vcf_to_zarr( - path, - output, - ploidy=4, - max_alt_alleles=3, - fields=["FORMAT/GL"], - encoding=encoding, - ) - ds = xr.open_zarr(output) - - # comb(n_alleles + ploidy - 1, ploidy) = comb(4 + 4 - 1, 4) = comb(7, 4) = 35 - assert_array_equal(ds["call_GL"].shape, (4, 3, 35)) - assert ds["call_GL"].attrs["comment"] == "Genotype likelihoods" - - -@pytest.mark.filterwarnings( - "ignore::sgkit.io.vcfzarr_reader.DimensionNameForFixedFormatFieldWarning" -) -def test_vcf_to_zarr__field_number_fixed(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "sample.vcf.gz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - # HQ Number is 2, and a dimension is automatically assigned (FORMAT_HQ_dim) - vcf_to_zarr( - path, - output, - fields=["FORMAT/HQ"], - ) - ds = xr.open_zarr(output) - - missing, fill = INT_MISSING, INT_FILL - assert_array_equal( - ds["call_HQ"], - [ - [[10, 15], [10, 10], [3, 3]], - [[10, 10], [10, 10], [3, 3]], - [[51, 51], [51, 51], [missing, missing]], - [[58, 50], [65, 3], [missing, missing]], - [[23, 27], [18, 2], [missing, missing]], - [[56, 60], [51, 51], [missing, missing]], - [[fill, fill], [fill, fill], [fill, fill]], - [[fill, fill], [fill, fill], [fill, fill]], - [[fill, fill], [fill, fill], [fill, fill]], - ], - ) - assert ds["call_HQ"].dims == ("variants", "samples", "FORMAT_HQ_dim") - assert ds["call_HQ"].attrs["comment"] == "Haplotype Quality" - - -def test_vcf_to_zarr__fields_errors(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "sample.vcf.gz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - with pytest.raises( - ValueError, - match=r"VCF field must be prefixed with 'INFO/' or 'FORMAT/'", - ): - vcf_to_zarr(path, output, fields=["DP"]) - - with pytest.raises( - ValueError, - match=r"INFO field 'XX' is not defined in the header.", - ): - vcf_to_zarr(path, output, fields=["INFO/XX"]) - - with pytest.raises( - ValueError, - match=r"FORMAT field 'XX' is not defined in the header.", - ): - vcf_to_zarr(path, output, fields=["FORMAT/XX"]) - - with pytest.raises( - ValueError, - match=r"FORMAT field 'XX' is not defined in the header.", - ): - vcf_to_zarr(path, output, exclude_fields=["FORMAT/XX"]) - - with pytest.raises( - ValueError, - match=r"INFO field 'AC' is defined as Number '.', which is not supported. 
Consider specifying `field_defs` to provide a concrete size for this field.", - ): - vcf_to_zarr(path, output, fields=["INFO/AC"]) - - with pytest.raises( - ValueError, - match=r"INFO field 'AN' is defined as Type 'Blah', which is not supported.", - ): - vcf_to_zarr( - path, - output, - fields=["INFO/AN"], - field_defs={"INFO/AN": {"Type": "Blah"}}, + dtype="i1", ) + nt.assert_array_equal(ds["call_genotype"], call_genotype) + nt.assert_array_equal(ds["call_genotype_mask"], call_genotype < 0) - -@pytest.mark.parametrize( - "vcf_file, expected_sizes", - [ - ( - "sample.vcf.gz", - { - "max_alt_alleles": 3, - "field_defs": {"INFO/AC": {"Number": 2}, "INFO/AF": {"Number": 2}}, - "ploidy": 2, - }, - ), - ("mixed.vcf.gz", {"max_alt_alleles": 1, "ploidy": 4}), - ("no_genotypes.vcf", {"max_alt_alleles": 1}), - ( - "CEUTrio.20.21.gatk3.4.g.vcf.bgz", - { - "max_alt_alleles": 7, - "field_defs": {"FORMAT/AD": {"Number": 8}}, - "ploidy": 2, - }, - ), - ], -) -def test_zarr_array_sizes(shared_datadir, vcf_file, expected_sizes): - path = path_for_test(shared_datadir, vcf_file) - sizes = zarr_array_sizes(path) - assert sizes == expected_sizes - - -def test_zarr_array_sizes__parallel(shared_datadir): - path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz") - regions = ["20", "21"] - sizes = zarr_array_sizes(path, regions=regions) - assert sizes == { - "max_alt_alleles": 7, - "field_defs": {"FORMAT/AD": {"Number": 8}}, - "ploidy": 2, - } - - -def test_zarr_array_sizes__multiple(shared_datadir): - paths = [ - path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz"), - path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz"), - ] - sizes = zarr_array_sizes(paths, target_part_size=None) - assert sizes == { - "max_alt_alleles": 7, - "field_defs": {"FORMAT/AD": {"Number": 8}}, - "ploidy": 2, - } - - -def test_zarr_array_sizes__parallel_partitioned_by_size(shared_datadir): - path = path_for_test( - shared_datadir, - "1000G.phase3.broad.withGenotypes.chr20.10100000.vcf.gz", - ) - sizes = zarr_array_sizes(path, target_part_size="4MB") - assert sizes == { - "max_alt_alleles": 3, - "field_defs": {"FORMAT/AD": {"Number": 4}}, - "ploidy": 2, - } - - -@pytest.mark.parametrize( - "all_kwargs, expected_sizes", - [ - ([{"max_alt_alleles": 1}, {"max_alt_alleles": 2}], {"max_alt_alleles": 2}), - ( - [{"max_alt_alleles": 1, "ploidy": 3}, {"max_alt_alleles": 2}], - {"max_alt_alleles": 2, "ploidy": 3}, - ), - ( + def test_call_genotype_phased(self, ds): + call_genotype_phased = np.array( [ - {"max_alt_alleles": 1, "field_defs": {"FORMAT/AD": {"Number": 8}}}, - {"max_alt_alleles": 2, "field_defs": {"FORMAT/AD": {"Number": 6}}}, + [True, True, False], + [True, True, False], + [True, True, False], + [True, True, False], + [True, True, False], + [True, True, False], + [False, False, False], + [False, True, False], + [True, False, True], ], - {"max_alt_alleles": 2, "field_defs": {"FORMAT/AD": {"Number": 8}}}, - ), - ], -) -def test_merge_zarr_array_sizes(all_kwargs, expected_sizes): - assert merge_zarr_array_sizes(all_kwargs) == expected_sizes - - -def check_field(group, name, ndim, shape, dimension_names, dtype): - assert group[name].ndim == ndim - assert group[name].shape == shape - assert group[name].attrs["_ARRAY_DIMENSIONS"] == dimension_names - if dtype == str: - assert group[name].dtype == np.object_ - assert VLenUTF8() in group[name].filters - else: - assert group[name].dtype == dtype - - -@pytest.mark.filterwarnings( - "ignore::sgkit.io.vcfzarr_reader.DimensionNameForFixedFormatFieldWarning" -) 
-def test_spec(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "sample_multiple_filters.vcf.gz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - - kwargs = zarr_array_sizes(path) - vcf_to_zarr( - path, - output, - chunk_length=5, - fields=["INFO/*", "FORMAT/*"], - mixed_ploidy=True, - **kwargs, - ) - - variants = 9 - alt_alleles = 3 - samples = 3 - ploidy = 2 - - group = zarr.open_group(output) - - # VCF Zarr group attributes - assert group.attrs["vcf_zarr_version"] == "0.2" - assert group.attrs["vcf_header"].startswith("##fileformat=VCFv4.0") - assert group.attrs["contigs"] == ["19", "20", "X"] - - # VCF Zarr arrays - assert set(list(group.array_keys())) == set( - [ - "variant_contig", - "variant_position", - "variant_id", - "variant_id_mask", - "variant_allele", - "variant_quality", - "variant_filter", - "variant_AA", - "variant_AC", - "variant_AF", - "variant_AN", - "variant_DB", - "variant_DP", - "variant_H2", - "variant_NS", - "call_DP", - "call_GQ", - "call_genotype", - "call_genotype_mask", - "call_genotype_fill", - "call_genotype_phased", - "call_HQ", - "contig_id", - "filter_id", - "sample_id", - ] - ) - - # Fixed fields - check_field( - group, - "variant_contig", - ndim=1, - shape=(variants,), - dimension_names=["variants"], - dtype=np.int8, - ) - check_field( - group, - "variant_position", - ndim=1, - shape=(variants,), - dimension_names=["variants"], - dtype=np.int32, - ) - check_field( - group, - "variant_id", - ndim=1, - shape=(variants,), - dimension_names=["variants"], - dtype=str, - ) - check_field( - group, - "variant_allele", - ndim=2, - shape=(variants, alt_alleles + 1), - dimension_names=["variants", "alleles"], - dtype=str, - ) - check_field( - group, - "variant_quality", - ndim=1, - shape=(variants,), - dimension_names=["variants"], - dtype=np.float32, - ) - check_field( - group, - "variant_filter", - ndim=2, - shape=(variants, 3), - dimension_names=["variants", "filters"], - dtype=bool, - ) - - # INFO fields - check_field( - group, - "variant_AA", - ndim=1, - shape=(variants,), - dimension_names=["variants"], - dtype=str, - ) - check_field( - group, - "variant_AC", - ndim=2, - shape=(variants, 2), - dimension_names=["variants", "INFO_AC_dim"], - dtype=np.int32, - ) - check_field( - group, - "variant_AF", - ndim=2, - shape=(variants, 2), - dimension_names=["variants", "INFO_AF_dim"], - dtype=np.float32, - ) - check_field( - group, - "variant_AN", - ndim=1, - shape=(variants,), - dimension_names=["variants"], - dtype=np.int32, - ) - check_field( - group, - "variant_DB", - ndim=1, - shape=(variants,), - dimension_names=["variants"], - dtype=bool, - ) - check_field( - group, - "variant_DP", - ndim=1, - shape=(variants,), - dimension_names=["variants"], - dtype=np.int32, - ) - check_field( - group, - "variant_H2", - ndim=1, - shape=(variants,), - dimension_names=["variants"], - dtype=bool, - ) - check_field( - group, - "variant_NS", - ndim=1, - shape=(variants,), - dimension_names=["variants"], - dtype=np.int32, - ) - - # FORMAT fields - check_field( - group, - "call_DP", - ndim=2, - shape=(variants, samples), - dimension_names=["variants", "samples"], - dtype=np.int32, - ) - check_field( - group, - "call_GQ", - ndim=2, - shape=(variants, samples), - dimension_names=["variants", "samples"], - dtype=np.int32, - ) - check_field( - group, - "call_HQ", - ndim=3, - shape=(variants, samples, 2), - dimension_names=["variants", "samples", "FORMAT_HQ_dim"], - dtype=np.int32, - ) - check_field( - group, - "call_genotype", - ndim=3, - shape=(variants, 
samples, ploidy), - dimension_names=["variants", "samples", "ploidy"], - dtype=np.int8, - ) - check_field( - group, - "call_genotype_phased", - ndim=2, - shape=(variants, samples), - dimension_names=["variants", "samples"], - dtype=bool, - ) - - # Sample information - check_field( - group, - "sample_id", - ndim=1, - shape=(samples,), - dimension_names=["samples"], - dtype=str, - ) - - # Array values - assert_array_equal(group["variant_contig"], [0, 0, 1, 1, 1, 1, 1, 1, 2]) - assert_array_equal( - group["variant_position"], - [111, 112, 14370, 17330, 1110696, 1230237, 1234567, 1235237, 10], - ) - assert_array_equal( - group["variant_id"], - [".", ".", "rs6054257", ".", "rs6040355", ".", "microsat1", ".", "rsTest"], - ) - assert_array_equal( - group["variant_allele"], - [ - ["A", "C", "", ""], - ["A", "G", "", ""], - ["G", "A", "", ""], - ["T", "A", "", ""], - ["A", "G", "T", ""], - ["T", "", "", ""], - ["G", "GA", "GAC", ""], - ["T", "", "", ""], - ["AC", "A", "ATG", "C"], - ], - ) - assert_allclose( - group["variant_quality"], [9.6, 10.0, 29.0, 3.0, 67.0, 47.0, 50.0, np.nan, 10.0] - ) - assert ( - group["variant_quality"][:].view(np.int32)[7] - == np.array([0x7F800001], dtype=np.int32).item() - ) # missing nan - assert_array_equal( - group["variant_filter"], - [ - [False, False, False], - [False, False, False], - [True, False, False], - [False, True, True], - [True, False, False], - [True, False, False], - [True, False, False], - [False, False, False], - [True, False, False], - ], - ) - - assert_array_equal( - group["variant_NS"], - [INT_FILL, INT_FILL, 3, 3, 2, 3, 3, INT_FILL, INT_FILL], - ) + dtype=bool, + ) + nt.assert_array_equal(ds["call_genotype_phased"], call_genotype_phased) - assert_array_equal( - group["call_DP"], - [ - [INT_FILL, INT_FILL, INT_FILL], - [INT_FILL, INT_FILL, INT_FILL], + def test_call_DP(self, ds): + call_DP = [ + [-1, -1, -1], + [-1, -1, -1], [1, 8, 5], [3, 5, 3], [6, 0, 4], - [INT_MISSING, 4, 2], + [-1, 4, 2], [4, 2, 3], - [INT_FILL, INT_FILL, INT_FILL], - [INT_FILL, INT_FILL, INT_FILL], - ], - ) - assert_array_equal( - group["call_genotype"], - [ - [[0, 0], [0, 0], [0, 1]], - [[0, 0], [0, 0], [0, 1]], - [[0, 0], [1, 0], [1, 1]], - [[0, 0], [0, 1], [0, 0]], - [[1, 2], [2, 1], [2, 2]], - [[0, 0], [0, 0], [0, 0]], - [[0, 1], [0, 2], [-1, -1]], - [[0, 0], [0, 0], [-1, -1]], - [[0, -2], [0, 1], [0, 2]], - ], - ) - assert_array_equal( - group["call_genotype_phased"], - [ - [True, True, False], - [True, True, False], - [True, True, False], - [True, True, False], - [True, True, False], - [True, True, False], - [False, False, False], - [False, True, False], - [True, False, True], - ], - ) - - assert_array_equal(group["sample_id"], ["NA00001", "NA00002", "NA00003"]) - - -@pytest.mark.parametrize( - "retain_temp_files", - [True, False], -) -def test_vcf_to_zarr__retain_files(shared_datadir, tmp_path, retain_temp_files): - path = path_for_test(shared_datadir, "sample.vcf.gz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - temp_path = tmp_path.joinpath("temp").as_posix() - - vcf_to_zarr( - path, - output, - chunk_length=5, - chunk_width=2, - tempdir=temp_path, - retain_temp_files=retain_temp_files, - target_part_size="500B", - ) - ds = xr.open_zarr(output) - assert_array_equal(ds["contig_id"], ["19", "20", "X"]) - assert (len(os.listdir(temp_path)) == 0) != retain_temp_files - - -def test_vcf_to_zarr__legacy_contig_and_filter_attrs(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "sample.vcf.gz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - 
- vcf_to_zarr(path, output, chunk_length=5, chunk_width=2) - ds = xr.open_zarr(output) - - # drop new contig_id and filter_id variables - ds = ds.drop_vars(["contig_id", "filter_id"]) - - # check that contigs and filters can still be retrieved (with a warning) - assert num_contigs(ds) == 3 - with pytest.warns(DeprecationWarning): - assert_array_equal(get_contigs(ds), np.array(["19", "20", "X"], dtype="S")) - with pytest.warns(DeprecationWarning): - assert_array_equal(get_filters(ds), np.array(["PASS", "s50", "q10"], dtype="S")) - - -def test_vcf_to_zarr__no_samples(shared_datadir, tmp_path): - path = path_for_test(shared_datadir, "no_samples.vcf.gz") - output = tmp_path.joinpath("vcf.zarr").as_posix() - vcf_to_zarr(path, output) - # Run with many parts to test concat_zarrs path also accepts no samples - vcf_to_zarr(path, output, target_part_size="1k") - ds = xr.open_zarr(output) - assert_array_equal(ds["sample_id"], []) - assert_array_equal(ds["contig_id"], ["1"]) - assert ds.sizes["variants"] == 973 - - -# TODO take out some of these, they take far too long -@pytest.mark.parametrize( - "vcf_name", - [ - "1000G.phase3.broad.withGenotypes.chr20.10100000.vcf.gz", - "CEUTrio.20.21.gatk3.4.csi.g.vcf.bgz", - "CEUTrio.20.21.gatk3.4.g.bcf", - "CEUTrio.20.21.gatk3.4.g.vcf.bgz", - "CEUTrio.20.gatk3.4.g.vcf.bgz", - "CEUTrio.21.gatk3.4.g.vcf.bgz", - "sample_multiple_filters.vcf.gz", - "sample.vcf.gz", - "allele_overflow.vcf.gz", - ], -) -def test_compare_vcf_to_zarr_convert(shared_datadir, tmp_path, vcf_name): - vcf_path = path_for_test(shared_datadir, vcf_name) - zarr1_path = tmp_path.joinpath("vcf1.zarr").as_posix() - zarr2_path = tmp_path.joinpath("vcf2.zarr").as_posix() - - # Convert gets the actual number of alleles by default, so use this as the - # input for - convert_vcf([vcf_path], zarr2_path) - ds2 = load_dataset(zarr2_path) - vcf_to_zarr( - vcf_path, - zarr1_path, - mixed_ploidy=True, - max_alt_alleles=ds2.variant_allele.shape[1] - 1, - ) - ds1 = load_dataset(zarr1_path) - - # convert reads all variables by default. 
- base_vars = list(ds1)
- ds2 = load_dataset(zarr2_path)
- # print(ds1.call_genotype.values)
- # print(ds2.call_genotype.values)
- xr.testing.assert_equal(ds1, ds2[base_vars])
-
-
-@pytest.mark.parametrize(
- "vcf_name",
- [
- "1000G.phase3.broad.withGenotypes.chr20.10100000.vcf.gz",
- "CEUTrio.20.21.gatk3.4.csi.g.vcf.bgz",
- "CEUTrio.20.21.gatk3.4.g.bcf",
- "CEUTrio.20.21.gatk3.4.g.vcf.bgz",
- "CEUTrio.20.gatk3.4.g.vcf.bgz",
- "CEUTrio.21.gatk3.4.g.vcf.bgz",
- "sample_multiple_filters.vcf.gz",
- "sample.vcf.gz",
- "allele_overflow.vcf.gz",
- ],
-)
-def test_validate_vcf(shared_datadir, tmp_path, vcf_name):
- vcf_path = path_for_test(shared_datadir, vcf_name)
- zarr_path = os.path.join("tmp/converted/", vcf_name, ".vcf.zarr")
- # zarr_path = tmp_path.joinpath("vcf.zarr").as_posix()
- print("converting", zarr_path)
- convert_vcf([vcf_path], zarr_path)
- # validate([vcf_path], zarr_path)
+ [-1, -1, -1],
+ [-1, -1, -1],
+ ]
+ nt.assert_array_equal(ds["call_DP"], call_DP)

+ def test_call_HQ(self, ds):
+ call_HQ = [
+ [[10, 15], [10, 10], [3, 3]],
+ [[10, 10], [10, 10], [3, 3]],
+ [[51, 51], [51, 51], [-1, -1]],
+ [[58, 50], [65, 3], [-1, -1]],
+ [[23, 27], [18, 2], [-1, -1]],
+ [[56, 60], [51, 51], [-1, -1]],
+ [[-1, -1], [-1, -1], [-1, -1]],
+ [[-1, -1], [-1, -1], [-1, -1]],
+ [[-1, -1], [-1, -1], [-1, -1]],
+ ]
+ nt.assert_array_equal(ds["call_HQ"], call_HQ)

From e943f58232800605855cc1ae437e1f6105c4375d Mon Sep 17 00:00:00 2001
From: Jerome Kelleher
Date: Thu, 15 Feb 2024 17:34:39 +0000
Subject: [PATCH 5/5] Basic packaging details

---
 bio2zarr/vcf.py | 4 ++--
 setup.cfg | 59 +++++++++++++++++++++++++++++++++++++++++++++++++
 setup.py | 9 ++++++++
 3 files changed, 70 insertions(+), 2 deletions(-)
 create mode 100644 setup.cfg
 create mode 100644 setup.py

diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py
index 8194877..2bff6ab 100644
--- a/bio2zarr/vcf.py
+++ b/bio2zarr/vcf.py
@@ -1564,7 +1564,7 @@ def validate(vcf_path, zarr_path, show_progress):
 elif vcf_type == "String":
 assert np.all(zarr_val == ".")
 elif vcf_type == "Flag":
- assert zarr_val == False
+ assert zarr_val == False # noqa: E712
 elif vcf_type == "Float":
 assert_all_missing_float(zarr_val)
 else:
@@ -1576,7 +1576,7 @@
 elif vcf_type == "Float":
 assert_prefix_float_equal_1d(vcf_val, zarr_val)
 elif vcf_type == "Flag":
- assert zarr_val == True
+ assert zarr_val == True # noqa: E712
 elif vcf_type == "String":
 assert np.all(zarr_val == vcf_val)
 else:
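A note on the two Flag branches above: `zarr_val` comes back from Zarr as a NumPy `bool_`, which is never identical to the Python `True`/`False` singletons, so the identity comparison that E712 normally recommends would always fail here; the equality test is deliberate and only the lint warning is silenced. The Float branches rely on the sentinel convention used throughout this file: "missing" and "fill" are two distinct NaNs that can only be told apart through an int32 view of their bits. A minimal sketch of that convention (the 0x7F800001 missing payload matches the view check in test_spec above; 0x7F800002 for the fill payload is an assumption for illustration):

    import numpy as np

    # Two NaNs that differ only in payload bits: np.isnan() sees both as NaN,
    # but an int32 view of the same memory distinguishes missing from fill.
    bits = np.array([0x7F800001, 0x7F800002, 0x3F800000], dtype=np.int32)
    vals = bits.view(np.float32)  # [missing NaN, fill NaN, 1.0]

    print(np.isnan(vals))                     # [ True  True False]
    print(vals.view(np.int32) == 0x7F800001)  # [ True False False]: missing only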
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..0a240ba
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,59 @@
+[metadata]
+name = bio2zarr
+author = sgkit Developers
+author_email = project@pystatgen.org
+license = Apache
+description = FIXME
+long_description_content_type=text/x-rst
+long_description =
+ FIXME
+url = https://github.com/pystatgen/bio2zarr
+classifiers =
+ Development Status :: 3 - Alpha
+ License :: OSI Approved :: Apache Software License
+ Operating System :: OS Independent
+ Intended Audience :: Science/Research
+ Programming Language :: Python
+ Programming Language :: Python :: 3
+ Programming Language :: Python :: 3.8
+ Programming Language :: Python :: 3.9
+ Programming Language :: Python :: 3.10
+ Programming Language :: Python :: 3.11
+ Topic :: Scientific/Engineering
+
+[options]
+packages = bio2zarr
+zip_safe = False # https://mypy.readthedocs.io/en/latest/installed_packages.html
+include_package_data = True
+python_requires = >=3.8
+install_requires =
+ numpy
+ zarr >= 2.10.0, != 2.11.0, != 2.11.1, != 2.11.2
+ cyvcf2
+ bed_reader
+setup_requires =
+ setuptools >= 41.2
+ setuptools_scm
+
+[flake8]
+ignore =
+ # whitespace before ':' - doesn't work well with black
+ E203
+ E402
+ # line too long - let black worry about that
+ E501
+ # do not assign a lambda expression, use a def
+ E731
+ # line break before binary operator
+ W503
+
+[isort]
+profile = black
+default_section = THIRDPARTY
+known_first_party = bio2zarr
+known_third_party = hypothesis,msprime,numpy,pandas,pytest,setuptools,sgkit,zarr
+multi_line_output = 3
+include_trailing_comma = True
+force_grid_wrap = 0
+use_parentheses = True
+line_length = 88
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..d5fecb5
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python
+from setuptools import setup
+
+setup(
+ # The package name along with all the other metadata is specified in setup.cfg
+ # However, GitHub's dependency graph can't see the package unless we put this here.
+ name="bio2zarr",
+ use_scm_version=True,
+)
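As a quick smoke test of the packaging above, something along these lines should work from a fresh checkout after `pip install -e .` (a sketch only: the `convert_vcf` signature of a list of VCF paths plus an output path follows the test code above, and the input file name is a placeholder):

    import zarr

    from bio2zarr.vcf import convert_vcf

    # Convert a small VCF and inspect a couple of the resulting Zarr arrays.
    convert_vcf(["sample.vcf.gz"], "sample.vcf.zarr")
    root = zarr.open("sample.vcf.zarr")
    print(root["variant_position"][:])  # POS column
    print(root["call_genotype"].shape)  # (variants, samples, ploidy)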