diff --git a/Cargo.toml b/Cargo.toml index 7a67713..9855988 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rustfrc" -version = "1.0.2" +version = "1.0.3" edition = "2018" [lib] @@ -8,10 +8,10 @@ name = "rustfrc" crate-type = ["cdylib"] [dependencies] -numpy = "^0.14" -ndarray = { version = "^0.15", features = ["rayon"] } -ndarray-rand = "^0.14" +numpy = "~0.14" +ndarray = { version = "~0.15", features = ["rayon"] } +ndarray-rand = "~0.14" [dependencies.pyo3] -version = "^0.14" +version = "~0.14" features = ["extension-module"] diff --git a/README.md b/README.md index cfa38a6..2588f74 100644 --- a/README.md +++ b/README.md @@ -5,16 +5,16 @@ rustfrc is a Python package with some fast Rust functions for use with FRC resolution determination for microscopy. It is in development for use during a Bachelor end project for the TU Delft in 2021-2022. -Since rustfrc contains compiled extensions and is not pure Python, it is not available for all platforms, but only for those with available compiled wheels. As of version 1.0.1, they are available for Windows (x86_64), macOS (x86_64 and universal2, which includes Apple Silicon) and Linux (x86_64). However, since Rust and Python are supported on many platforms, it is not difficult to compile for other platforms (see below). +Since rustfrc contains compiled extensions and is not pure Python, it is not available for all platforms, but only for those with available compiled wheels. They are available for Windows (x86_64), macOS (x86_64 and universal2, which includes Apple Silicon) and Linux (x86_64). However, since Rust and Python are supported on many platforms, it is not difficult to compile for other platforms (see below). ## Features -Currently, rustfrc does not have many features. The primary one is `binom_split(x: np.ndarray) -> np.ndarray` which samples _Binom ~ (n, 0.5)_ with n as the array element value. +Currently, rustfrc does not have many features. The primary one is `binom_split(x: np.ndarray) -> np.ndarray` which samples _Binom ~ (n, 0.5)_ with n as the array element value. The operation is fully parallelized and somewhere between 3-10x faster than sampling using NumPy. ## Requirements * Python 3.7 or greater -* numpy 1.18 or greater +* NumPy 1.18 or greater ## Installation @@ -37,7 +37,7 @@ Build a wheel file like this (if using poetry, append `poetry run` before the co maturin build --release ``` -If you want to choose which versions of Python to build for, you can append e.g. `maturin build --release -i python3.9 python3.8 python3.7`. Here `python3.7` should be an available Python command installed on your computer. +If you want to choose which versions of Python to build for, you can write e.g. `maturin build --release -i python3.9 python3.8 python3.7`. Here, for example '`python3.7`' should be an available Python command installed on your computer. This generates `.whl` files in `\target\wheels`. Then, create a Python environment of your choosing (with `numpy ^1.18` and `python ^3.7`), drop the `.whl` file in it and run `pip install <.whl filename>`, for example: `pip install rustfrc-0.1.0-cp39-none-win_amd64.whl`. Then, use `import rustfrc` in your Python script to be able to use the Rust functions. This should be generally valid for all platforms. The only real requirement is the availability of a Rust toolchain and Python for your platform. diff --git a/poetry.lock b/poetry.lock index 593369e..9cdbd67 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,6 +1,61 @@ +[[package]] +name = "atomicwrites" +version = "1.4.0" +description = "Atomic file writes." +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[[package]] +name = "attrs" +version = "21.2.0" +description = "Classes Without Boilerplate" +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + +[package.extras] +dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit"] +docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] +tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface"] +tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins"] + +[[package]] +name = "colorama" +version = "0.4.4" +description = "Cross-platform colored terminal text." +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + +[[package]] +name = "importlib-metadata" +version = "4.8.1" +description = "Read metadata from Python packages" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""} +zipp = ">=0.5" + +[package.extras] +docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] +perf = ["ipython"] +testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pep517", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy", "importlib-resources (>=1.3)"] + +[[package]] +name = "iniconfig" +version = "1.1.1" +description = "iniconfig: brain-dead simple config-ini parsing" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "maturin" -version = "0.11.4" +version = "0.11.5" description = "Build and publish crates with pyo3, rust-cpython and cffi bindings as well as rust binaries as python packages" category = "dev" optional = false @@ -17,32 +72,136 @@ category = "main" optional = false python-versions = ">=3.7,<3.11" +[[package]] +name = "packaging" +version = "21.0" +description = "Core utilities for Python packages" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +pyparsing = ">=2.0.2" + +[[package]] +name = "pluggy" +version = "1.0.0" +description = "plugin and hook calling mechanisms for python" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "py" +version = "1.10.0" +description = "library with cross-python path, ini-parsing, io, code, log facilities" +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[[package]] +name = "pyparsing" +version = "2.4.7" +description = "Python parsing module" +category = "main" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" + +[[package]] +name = "pytest" +version = "6.2.5" +description = "pytest: simple powerful testing with Python" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} +attrs = ">=19.2.0" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +py = ">=1.8.2" +toml = "*" + +[package.extras] +testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] + [[package]] name = "toml" version = "0.10.2" description = "Python Library for Tom's Obvious, Minimal Language" -category = "dev" +category = "main" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +[[package]] +name = "typing-extensions" +version = "3.10.0.2" +description = "Backported and Experimental Type Hints for Python 3.5+" +category = "main" +optional = false +python-versions = "*" + +[[package]] +name = "zipp" +version = "3.6.0" +description = "Backport of pathlib-compatible object wrapper for zip files" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.extras] +docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] +testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"] + [metadata] lock-version = "1.1" -python-versions = ">=3.7, <3.11" -content-hash = "24c14f3c9fb1f39c8d6b1bd5a5afb69d34d70215e54d53584d36b9568836f934" +python-versions = ">= 3.7" +content-hash = "55ca09031bcaaf5fad587a9c023bb3883ecd0fcf5ae5dab917d69d6db4a07287" [metadata.files] +atomicwrites = [ + {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, + {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, +] +attrs = [ + {file = "attrs-21.2.0-py2.py3-none-any.whl", hash = "sha256:149e90d6d8ac20db7a955ad60cf0e6881a3f20d37096140088356da6c716b0b1"}, + {file = "attrs-21.2.0.tar.gz", hash = "sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb"}, +] +colorama = [ + {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, + {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, +] +importlib-metadata = [ + {file = "importlib_metadata-4.8.1-py3-none-any.whl", hash = "sha256:b618b6d2d5ffa2f16add5697cf57a46c76a56229b0ed1c438322e4e95645bd15"}, + {file = "importlib_metadata-4.8.1.tar.gz", hash = "sha256:f284b3e11256ad1e5d03ab86bb2ccd6f5339688ff17a4d797a0fe7df326f23b1"}, +] +iniconfig = [ + {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, + {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, +] maturin = [ - {file = "maturin-0.11.4-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:9c4eb3b7a6399cdb3b098095037eb26da0242c1bd48b5067ef48e1204e7ba7ba"}, - {file = "maturin-0.11.4-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:e3e5bfff69074ee307df661678ef94cfe3048afe4cf34bea19446701f67af4eb"}, - {file = "maturin-0.11.4-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f9fbc687bf363717d09d9dfbc4c57f627ef0b74273089ac10ed2418a98d0681e"}, - {file = "maturin-0.11.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:580f1e3a05f66b5eb4b73615a40563237f0b611884a75fc6fa98b423ccd46a4c"}, - {file = "maturin-0.11.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c8b1899391bcec7b47bd4423231d4df214538d3f1f1ad709af958751b6480afd"}, - {file = "maturin-0.11.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80d728c9741f012c2afd1b516309eaca986bd6734b8e711d3f560b8c1d02d8c3"}, - {file = "maturin-0.11.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1e1c8db8eb1e9edf6adfb21e8fd0a0080ae770ef3ac3e71cd961caf970f043c7"}, - {file = "maturin-0.11.4-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b606da2839a2095c425086a2fc32bcdf108fb8491ef7cce998a8599f3d6ce057"}, - {file = "maturin-0.11.4-py3-none-win32.whl", hash = "sha256:04fcb8de4a7667500aa40fd97ac853776e8f5872d0776ab50cfa499b53e32bd6"}, - {file = "maturin-0.11.4-py3-none-win_amd64.whl", hash = "sha256:e7897fcbafdc90a356814a54833a3ca81fd27e97a02c763addf122a5c6dcbc75"}, - {file = "maturin-0.11.4.tar.gz", hash = "sha256:b1cd9c35c911a11532c1182d35041cda0602f470cab8d76124ca049c8bba896a"}, + {file = "maturin-0.11.5-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:d78f24561a5e02f7d119b348b26e5772ad5698a43ca49e8facb9ce77cf273714"}, + {file = "maturin-0.11.5-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:c2ded8b4ef9210d627bb966bc67661b7db259535f6062afe1ce5605406b50f3f"}, + {file = "maturin-0.11.5-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1ce666c386ff9c3c2b5d7d3ca4b1f9f675c38d7540ffbda0d5d5bc7d6ddde49a"}, + {file = "maturin-0.11.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0ac45879a7d624b47d72b093ae3370270894c19779f42aad7568a92951c5d47"}, + {file = "maturin-0.11.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4191b0b7362b3025096faf126ff15cb682fbff324ac4a6ca18d55bb16e2b759b"}, + {file = "maturin-0.11.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7bf96e7586bfdb5b0fadc6d662534b8a41123b33dff084fa383a81ded0ce5334"}, + {file = "maturin-0.11.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ab2b3ccf66f5e0f9c3904d215835337b1bd305e79e3bf53b65bbc80a5755e01b"}, + {file = "maturin-0.11.5-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:3354d030b88c938a33bf407a6c0f79ccdd2cce3e1e3e4a2d0c92dc2e063adc6e"}, + {file = "maturin-0.11.5-py3-none-win32.whl", hash = "sha256:20f9c30701c9932ed8026ceaf896fc77ecc76cebd6a182668dbc10ed597f8789"}, + {file = "maturin-0.11.5-py3-none-win_amd64.whl", hash = "sha256:70381be1585cb9fa5c02b83af80ae661aaad959e8aa0fddcfe195b004054bd69"}, + {file = "maturin-0.11.5.tar.gz", hash = "sha256:07074778b063a439fdfd5501bd1d1823a216ec5b657d3ecde78fd7f2c4782422"}, ] numpy = [ {file = "numpy-1.21.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:52a664323273c08f3b473548bf87c8145b7513afd63e4ebba8496ecd3853df13"}, @@ -76,7 +235,36 @@ numpy = [ {file = "numpy-1.21.2-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d96a6a7d74af56feb11e9a443150216578ea07b7450f7c05df40eec90af7f4a7"}, {file = "numpy-1.21.2.zip", hash = "sha256:423216d8afc5923b15df86037c6053bf030d15cc9e3224206ef868c2d63dd6dc"}, ] +packaging = [ + {file = "packaging-21.0-py3-none-any.whl", hash = "sha256:c86254f9220d55e31cc94d69bade760f0847da8000def4dfe1c6b872fd14ff14"}, + {file = "packaging-21.0.tar.gz", hash = "sha256:7dc96269f53a4ccec5c0670940a4281106dd0bb343f47b7471f779df49c2fbe7"}, +] +pluggy = [ + {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, + {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, +] +py = [ + {file = "py-1.10.0-py2.py3-none-any.whl", hash = "sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a"}, + {file = "py-1.10.0.tar.gz", hash = "sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3"}, +] +pyparsing = [ + {file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, + {file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, +] +pytest = [ + {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, + {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, +] toml = [ {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, ] +typing-extensions = [ + {file = "typing_extensions-3.10.0.2-py2-none-any.whl", hash = "sha256:d8226d10bc02a29bcc81df19a26e56a9647f8b0a6d4a83924139f4a8b01f17b7"}, + {file = "typing_extensions-3.10.0.2-py3-none-any.whl", hash = "sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34"}, + {file = "typing_extensions-3.10.0.2.tar.gz", hash = "sha256:49f75d16ff11f1cd258e1b988ccff82a3ca5570217d7ad8c5f48205dd99a677e"}, +] +zipp = [ + {file = "zipp-3.6.0-py3-none-any.whl", hash = "sha256:9fe5ea21568a0a70e50f273397638d39b03353731e6cbbb3fd8502a33fec40bc"}, + {file = "zipp-3.6.0.tar.gz", hash = "sha256:71c644c5369f4a6e07636f0aa966270449561fcea2e3d6747b8d23efaa9d7832"}, +] diff --git a/pyproject.toml b/pyproject.toml index 43a2296..72c7a6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,9 @@ [project] name = "rustfrc" -version = "1.0.2" -description = "Package with some fast Rust functions for use with FRC resolution determination for microscopy. TU Delft BEP 2021-2022." +version = "1.0.3" +description = "Fast utility functions useful for Fourier Ring/Shell Correlation: binomial splitting of arrrays" readme = "README.md" -requires-python = ">=3.7, <3.11" +requires-python = ">= 3.7" license = {text = "Apache-2.0"} authors = [ {email = "T.M.tenBrink@student.tudelft.nl"}, @@ -16,35 +16,23 @@ classifiers = [ "Topic :: Scientific/Engineering :: Image Processing" ] dependencies = [ - 'numpy >= 1.18, < 2' + "numpy >= 1.18, < 2" ] -[project.urls] -repository = "https://github.com/tmtenbrink/rustfrc" [tool.poetry] name = "rustfrc" -version = "1.0.2" -description = "Package with some fast Rust functions for use with FRC resolution determination for microscopy. TU Delft BEP 2021-2022." +version = "1.0.3" +description = "Fast utility functions useful for Fourier Ring/Shell Correlation: binomial splitting of arrays" authors = ["Tip ten Brink "] -readme = "README.md" -classifiers = [ - "Intended Audience :: Science/Research", - "Programming Language :: Python", - "Programming Language :: Rust", - "Topic :: Scientific/Engineering :: Image Processing" -] -license = "Apache-2.0" - -[tool.poetry.urls] -"Source Code" = "https://github.com/tmtenbrink/rustfrc" [tool.poetry.dependencies] -python = ">=3.7, <3.11" -numpy = ">=1.18, < 2" +python = ">= 3.7" +numpy = ">= 1.18, < 2" [tool.poetry.dev-dependencies] maturin = "^0.11.3" +pytest = "^6.2.5" [build-system] -requires = ["poetry-core>=1.0.0"] +requires = ["maturin >= 0.11, < 0.12"] build-backend ="maturin" diff --git a/rustfrc/split.py b/rustfrc/split.py index 7525493..9016b4a 100644 --- a/rustfrc/split.py +++ b/rustfrc/split.py @@ -7,4 +7,4 @@ def binom_split(a: np.ndarray) -> np.ndarray: binomial distribution (n, p) with n = pixel value and p = 0.5. Returns a single image. This conserves shot noise. """ - return _internal.binom_split(a.astype(np.int32)) + return _internal.binom_split_py(a.astype(np.int32)) diff --git a/src/lib.rs b/src/lib.rs index 492ab02..f9bb25f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,19 @@ use ndarray_rand::rand_distr::Binomial; -use ndarray_rand::rand::prelude::Distribution; -use ndarray_rand::rand::thread_rng; +use ndarray_rand::rand::prelude::{Distribution, thread_rng}; +use ndarray::{Array, Dimension}; use numpy::{IntoPyArray, PyReadonlyArrayDyn, PyArrayDyn}; use pyo3::prelude::{pymodule, pyfunction, wrap_pyfunction, PyModule, PyResult, Python}; use pyo3::exceptions::PyValueError; use std::convert::TryFrom; -use std::sync::atomic::{AtomicI32, Ordering}; +use std::sync::atomic::{Ordering, AtomicBool}; +use std::error::Error; +use std::fmt::{Display, Formatter}; + #[pymodule] fn rustfrc(py: Python<'_>, m: &PyModule) -> PyResult<()> { let internal = PyModule::new(py, "_internal")?; - internal.add_function(wrap_pyfunction!(binom_split, internal)?)?; + internal.add_function(wrap_pyfunction!(binom_split_py, internal)?)?; m.add_submodule(internal)?; Ok(()) @@ -22,32 +25,103 @@ fn rustfrc(py: Python<'_>, m: &PyModule) -> PyResult<()> { /// Takes an image (np.ndarray with dtype i32) and splits every pixel value according to the /// binomial distribution (n, p) with n = pixel value and p = 0.5. Returns a single image. #[pyfunction] -fn binom_split<'py>(py: Python<'py>, a: PyReadonlyArrayDyn<'py, i32>) -> PyResult<&'py PyArrayDyn> { - let mut a = a.as_array().to_owned(); +fn binom_split_py<'py>(py: Python<'py>, a: PyReadonlyArrayDyn<'py, i32>) -> PyResult<&'py PyArrayDyn> { + let a = a.to_owned_array(); + + binom_split(a) + .map_err(|e| PyValueError::new_err(format!("{}", e.to_string()))) + .map(|a| a.into_pyarray(py)) +} + +#[derive(Debug)] +struct ToUsizeError {} + +impl Display for ToUsizeError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("Value in array a cannot be cast to u64. All array \ + values must be non-negative.")) + } +} + +impl Error for ToUsizeError {} + +fn binom_split(mut a: Array) -> Result, ToUsizeError> { + // AtomicBool is thread-safe, and allows for communicating an error state occurred across threads + // We initialize it with the value false, since no error occurred + let to_unsized_failed = AtomicBool::new(false); + // We map all values in parallel, since they do not depend on each other - let error_i = AtomicI32::new(0); a.par_mapv_inplace(|i| { - if error_i.load(Ordering::Relaxed) == 0 { + // If no failure has occurred, we continue + // We use Relaxed Ordering because the order in which stores and loads occur does not matter + // Once it is set to true, it will stay true, when exactly that happens does not matter + if !to_unsized_failed.load(Ordering::Relaxed) { + // We use thread rng, which is fast let mut rng = thread_rng(); + // We try to convert i32 to u64 (which is required for Binomial) + // If it fails, we indicate a failure has occurred + // Unfortunately it is not possible to escape from the loop immediately let n = u64::try_from(i).unwrap_or_else(|_| { - // Since this is a parallel function, a special AtomicI32 is necessary to communicate - // if there is a failure. - error_i.store(i, Ordering::Relaxed); + to_unsized_failed.store(true, Ordering::Relaxed); 0 }); Binomial::new(n, 0.5).unwrap().sample(&mut rng) as i32 } + // We just keep the rest unchanged if a failure occurred else { - 0 + i } }); - let error_i = error_i.into_inner(); - if error_i != 0 { - Err(PyValueError::new_err( - format!("{i} in array a cannot be cast to u64. All array values must be non-negative.", i=error_i))) + let to_unsized_failed = to_unsized_failed.into_inner(); + if to_unsized_failed { + Err(ToUsizeError {}) } else { - Ok(a.into_pyarray(py)) + Ok(a) } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn binom_split_2d() { + let a: ndarray::Array2 = ndarray::arr2(&[[9, 9, 2, 3], + [4, 5, 6, 7]]); + let b = binom_split(a.clone()); + let b = b.unwrap(); + assert!(b.iter().clone().max() <= a.iter().max()); + assert!(*(b.iter().min().unwrap()) >= 0); + } + + #[test] + fn binom_split_negative() { + let a: ndarray::Array2 = ndarray::arr2(&[[9, -9, 2, 3], + [4, 5, 6, 7]]); + let b = binom_split(a); + + assert!(b.is_err()); + } + + #[test] + fn binom_split_1_element() { + let a: ndarray::Array1 = ndarray::arr1(&[0]); + let b = binom_split(a); + assert_eq!(b.unwrap(), ndarray::arr1(&[0])); + } + + #[test] + fn binom_split_large_d() { + let a1 = ndarray::arr3(&[ [[2, 3], [4, 3]], + [[2, 9], [4, 5]], [[9, 7], [2, 3]] ]); + let a2 = ndarray::stack(ndarray::Axis(0), &[a1.view(), a1.view()]).unwrap(); + let a3 = ndarray::stack(ndarray::Axis(0), &[a2.view(), a2.view()]).unwrap(); + let a4 = ndarray::stack(ndarray::Axis(0), &[a3.view(), a3.view()]).unwrap(); + let a5 = ndarray::stack(ndarray::Axis(0), &[a4.view(), a4.view()]).unwrap(); + let a6 = ndarray::stack(ndarray::Axis(0), &[a5.view(), a5.view()]).unwrap(); + + let b = binom_split(a6); + assert!(b.is_ok()); + } +} \ No newline at end of file diff --git a/test.py b/test.py index 69b62ab..720ef70 100644 --- a/test.py +++ b/test.py @@ -2,19 +2,19 @@ import numpy as np import time -rng = np.random.default_rng() -x = np.ones((388, 388))*20 -x = np.repeat(x[:, :, np.newaxis], 1000, axis=2).astype(np.int64) -# half_x = np.rint(x/2).astype(int) -# print(half_x) -start = time.time_ns() -# a = rng.binomial(x, 0.5) -# a = r.binom_split(x) -b = x - a -# for i in range(100): -# a = rng.binomial(half_x, 0.5) -# b = half_x - a -end_time = time.time_ns() -print(str(float(end_time - start)/1e9) + " s") +init_x = (np.ones((900, 700, 100))*30).astype(np.int32) +init_y = (np.ones((900, 700, 100))*30).astype(np.int32) +init_z = np.ones((900, 700, 50))*30 + + +start = time.perf_counter() +x = r.binom_split(init_x) +end = time.perf_counter() +print(str(end - start) + " s") -# + +start2 = time.perf_counter() +rng = np.random.default_rng() +y = rng.binomial(init_y, 0.5) +end2 = time.perf_counter() +print(str(end2 - start2) + " s") \ No newline at end of file