Skip to content

Commit

Permalink
Fix uncompressed signal python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
0x55555555 committed Jul 8, 2024
1 parent b2b62c7 commit 92e70f1
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 20 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ All notable changes, updates, and fixes to pod5 will be documented here
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.3.12]

## Fixed

- Fixed issues reading signal from uncompressed pod5 files.


## [0.3.11]

## Added
Expand Down
8 changes: 5 additions & 3 deletions c++/pod5_format_pybind/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ PYBIND11_MODULE(pod5_format_pybind, m)

m.doc() = "POD5 Format Raw Bindings";

auto thread_pool = pod5::make_thread_pool(std::thread::hardware_concurrency());
py::enum_<SignalType>(m, "SignalType", py::arithmetic(), "SignalType enum")
.value("UncompressedSignal", SignalType::UncompressedSignal, "Signal is not compressed")
.value("VbzSignal", SignalType::VbzSignal, "Signal is compressed using vbz")
.export_values();

py::class_<FileWriterOptions>(m, "FileWriterOptions")
.def(py::init([thread_pool]() {
.def(py::init([]() {
FileWriterOptions options;
options.set_thread_pool(thread_pool);
return options;
}))
.def_property(
Expand Down
4 changes: 2 additions & 2 deletions ci/setup_python_osx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ fi
# relocatable-python doesn't like this dir existing, it exits with error:
tmp_python_dir="/Users/cirunner/Library/Python_"
function cleanup {
mv $tmp_python_dir /Users/cirunner/Library/Python
mv $tmp_python_dir /Users/cirunner/Library/Python || true
}
trap cleanup EXIT

mv /Users/cirunner/Library/Python $tmp_python_dir
mv /Users/cirunner/Library/Python $tmp_python_dir || true

relocatable-python/make_relocatable_python_framework.py --python-version "${version}" --destination "${destination}" --upgrade-pip --os-version "${os_version}"

Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ sphinx-rtd-theme
sphinx==v5.3.0
myst-parser
# Paths are relative to project root for ReadTheDocs and docs/Makefile
pod5==0.3.11
pod5==0.3.12
2 changes: 2 additions & 0 deletions python/lib_pod5/src/lib_pod5/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Pod5RepackerOutput,
Pod5SignalCacheBatch,
Repacker,
SignalType,
compress_signal,
create_file,
recover_file,
Expand All @@ -33,6 +34,7 @@
"Pod5RepackerOutput",
"Pod5SignalCacheBatch",
"Repacker",
"SignalType",
"compress_signal",
"create_file",
"recover_file",
Expand Down
8 changes: 5 additions & 3 deletions python/pod5/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ classifiers=[
]

dependencies = [
"lib_pod5 == 0.3.11",
"lib_pod5 == 0.3.12",
"iso8601",
'importlib-metadata; python_version<"3.8"',
"more_itertools",
"numpy >= 1.21.0",
'typing-extensions; python_version<"3.8"',
"pyarrow ~= 16.1.0",
'typing-extensions; python_version<"3.10"',
# Avoid issues with pyarrow 16.1.0 on x64 Macos: https://github.com/apache/arrow/issues/41696
'pyarrow ~= 16.1.0; platform_system!="Darwin" or platform_machine!="x86_64" or python_version<"3.12"',
'pyarrow ~= 16.0.0; platform_system=="Darwin" and platform_machine=="x86_64" and python_version>="3.12"',
"pytz",
"packaging",
"polars~=0.19",
Expand Down
3 changes: 2 additions & 1 deletion python/pod5/src/pod5/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
vbz_decompress_signal_chunked,
vbz_decompress_signal_into,
)
from .writer import Writer
from .writer import SignalType, Writer

__all__ = (
"__version__",
Expand All @@ -56,6 +56,7 @@
"Reader",
"ReadRecord",
"ReadRecordBatch",
"SignalType",
"vbz_compress_signal",
"vbz_decompress_signal",
"vbz_decompress_signal_chunked",
Expand Down
13 changes: 10 additions & 3 deletions python/pod5/src/pod5/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def signal(self) -> npt.NDArray[np.int16]:
memoryview(signal[batch_row_index].as_buffer()), output_slice
)
else:
output_slice[:] = signal.to_numpy()
output_slice[:] = signal[batch_row_index].values
current_sample_index += current_row_count
return output

Expand Down Expand Up @@ -347,11 +347,16 @@ def map_signal_row(sig_row) -> SignalRowInfo:
sig_row = sig_row.as_py()

batch, batch_index, batch_row_index = self._find_signal_row_index(sig_row)
batch_length = 0
if isinstance(batch.signal, pa.lib.LargeListArray):
batch_length = len(batch.signal[batch_row_index])
else:
batch_length = len(batch.signal[batch_row_index].as_buffer())
return SignalRowInfo(
batch_index,
batch_row_index,
batch.samples[batch_row_index].as_py(),
len(batch.signal[batch_row_index].as_buffer()),
batch_length,
)

return [map_signal_row(r) for r in self._batch.columns.signal[self._row]]
Expand Down Expand Up @@ -402,7 +407,9 @@ def _get_signal_for_row(self, signal_row: int) -> npt.NDArray[np.int16]:
memoryview(signal[batch_row_index].as_buffer()), sample_count
)

return signal.to_numpy()
return signal.to_numpy()
else:
return np.array(signal[batch_row_index].values, dtype="int16")

def to_read(self) -> Read:
"""
Expand Down
27 changes: 25 additions & 2 deletions python/pod5/src/pod5/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@
TypeVar,
Union,
)
import sys

import lib_pod5 as p5b
import numpy as np
from pod5.reader import ReadRecord
import pytz

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

from pod5.api_utils import Pod5ApiException, safe_close
from pod5.pod5_types import (
BaseRead,
Expand All @@ -35,6 +41,7 @@

DEFAULT_SOFTWARE_NAME = "Python API"

SignalType: TypeAlias = p5b.SignalType
PoreType = str
T = TypeVar("T", bound=Union[EndReason, PoreType, RunInfo])

Expand Down Expand Up @@ -68,7 +75,12 @@ def timestamp_to_int(time_stamp: Union[datetime.datetime, int]) -> int:
class Writer:
"""Pod5 File Writer"""

def __init__(self, path: PathOrStr, software_name: str = DEFAULT_SOFTWARE_NAME):
def __init__(
self,
path: PathOrStr,
software_name: str = DEFAULT_SOFTWARE_NAME,
signal_compression_type: SignalType = SignalType.VbzSignal,
):
"""
Open a pod5 file for Writing.
Expand All @@ -78,17 +90,23 @@ def __init__(self, path: PathOrStr, software_name: str = DEFAULT_SOFTWARE_NAME):
The path to the pod5 file to create
software_name : str
The name of the application used to create this pod5 file
signal_compression_type : SignalType
The type of compression to use in the file. Defaults to Vbz.
"""
self._path = Path(path).absolute()
self._software_name = software_name
self._signal_compression_type = signal_compression_type

if self._path.is_file():
raise FileExistsError(
f"Input path already exists. Refusing to overwrite: {self._path}"
)

options = p5b.FileWriterOptions()
options.signal_compression_type = signal_compression_type

self._writer: Optional[p5b.FileWriter] = p5b.create_file(
str(self._path), software_name, None
str(self._path), software_name, options
)
if not self._writer:
raise Pod5ApiException(
Expand Down Expand Up @@ -134,6 +152,11 @@ def software_name(self) -> str:
"""Return the software name used to open this file"""
return self._software_name

@property
def signal_compression_type(self) -> SignalType:
"""Return the signal compression type used by this file"""
return self._signal_compression_type

def add(self, obj: Union[EndReason, PoreType, RunInfo]) -> int:
"""
Add a :py:class:`EndReason`, :py:class:`PoreType`, or
Expand Down
25 changes: 20 additions & 5 deletions python/pod5/src/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,13 @@ def get_random_str(prefix: str) -> str:


def run_writer_test(f: Writer):
writer_supports_compressed = f.signal_compression_type == p5.SignalType.VbzSignal

test_read = gen_test_read(0, compressed=False)
print("read", test_read.read_id, test_read.run_info.adc_max)
f.add_read(test_read)

test_read = gen_test_read(1, compressed=True)
test_read = gen_test_read(1, compressed=writer_supports_compressed)
print("read", test_read.read_id, test_read.run_info.adc_max)
f.add_read(test_read)

Expand All @@ -103,10 +105,10 @@ def run_writer_test(f: Writer):
f.add_reads(test_reads)

test_reads = [
gen_test_read(6, compressed=True),
gen_test_read(7, compressed=True),
gen_test_read(8, compressed=True),
gen_test_read(9, compressed=True),
gen_test_read(6, compressed=writer_supports_compressed),
gen_test_read(7, compressed=writer_supports_compressed),
gen_test_read(8, compressed=writer_supports_compressed),
gen_test_read(9, compressed=writer_supports_compressed),
]
f.add_reads(test_reads)
assert test_reads[0].sample_count > 0
Expand Down Expand Up @@ -247,6 +249,19 @@ def test_pyarrow_from_str():
run_reader_test(_fh)


@pytest.mark.filterwarnings("ignore: pod5.")
def test_pyarrow_from_pathlib_uncompressed():
with tempfile.TemporaryDirectory() as temp:
path = Path(temp) / "example.pod5"
with p5.Writer(
path, signal_compression_type=p5.SignalType.UncompressedSignal
) as _fh:
run_writer_test(_fh)

with p5.Reader(path) as _fh:
run_reader_test(_fh)


def test_read_id_packing():
"""
Assert pack_read_ids repacks and format_read_ids correctly unpacks collections
Expand Down

0 comments on commit 92e70f1

Please sign in to comment.