Skip to content

Commit

Permalink
Fix the Cython-/Rust switch, update readmes
Browse files Browse the repository at this point in the history
  • Loading branch information
LSchueler committed Jul 5, 2024
1 parent 51ab732 commit 8fbe2ec
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 46 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ running. Install the package by typing the following command in a command termin

To install the latest development version via pip, see the
[documentation][doc_install_link].
One thing to point out is that this way, the non-parallel version of GSTools
is installed. In case you want the parallel version, follow these easy
[steps][doc_install_link].


## Citation
Expand Down
34 changes: 19 additions & 15 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,24 @@ If something went wrong during installation, try the :code:`-I` `flag from pip <

**Speeding up GSTools by parallelization**

To enable the OpenMP support, you have to provide a C compiler and OpenMP.
Parallel support is controlled by an environment variable ``GSTOOLS_BUILD_PARALLEL``,
that can be ``0`` or ``1`` (interpreted as ``0`` if not present).
GSTools then needs to be installed from source:
We provide two possibilities to run GSTools in parallel, often causing a
massive improvement in runtime. In either case, the number of parallel
threads can be set with the global variable `config.NUM_THREADS`.

***Parallelizing Cython***

To enable the OpenMP support in Cython, you have to provide a C compiler and
OpenMP. Parallel support is controlled by an environment variable
``GSTOOLS_BUILD_PARALLEL``, that can be ``0`` or ``1`` (interpreted as ``0``
if not present). GSTools then needs to be installed from source:

.. code-block:: none
export GSTOOLS_BUILD_PARALLEL=1
pip install --no-binary=gstools gstools
Note, that the ``--no-binary=gstools`` option forces pip to not use a wheel for GSTools.
Note, that the ``--no-binary=gstools`` option forces pip to not use a wheel
for GSTools.

For the development version, you can do almost the same:

Expand All @@ -98,19 +105,18 @@ For the development version, you can do almost the same:
export GSTOOLS_BUILD_PARALLEL=1
pip install git+git://github.com/GeoStat-Framework/GSTools.git@main
The number of parallel threads can be set with the global variable `config.NUM_THREADS`.
**Using experimental GSTools-Core for even more speed**
***Using GSTools-Core for parallelization and even more speed***

You can install the optional dependency `GSTools-Core <https://github.com/GeoStat-Framework/GSTools-Core>`_,
which is a re-implementation of the main algorithms used in GSTools. The new
which is a re-implementation of the algorithms used in GSTools. The new
package uses the language Rust and it should be faster (in some cases by orders
of magnitude), safer, and it will potentially completely replace the current
standard implementation in Cython. Once the package GSTools-Core is available
on your machine, it will be used by default. In case you want to switch back to
the Cython implementation, you can set :code:`gstools.config.USE_RUST=False` in
your code. This also works at runtime. You can install the optional dependency
e.g. by
the Cython implementation, you can set
:code:`gstools.config.USE_GSTOOLS_CORE=False` in your code. This also works at
runtime. You can install the optional dependency e.g. by

.. code-block:: none
Expand All @@ -122,10 +128,8 @@ or by manually installing the package
pip install gstools-core
GSTools-Core will automatically use all your cores in parallel, without having
to use OpenMP or a local C compiler.
In case you want to restrict the number of threads used, you can use the
global variable `config.NUM_THREADS` to the desired number.
GSTools-Core will automatically run in parallel, without having to use provide
OpenMP or a local C compiler.


Citation
Expand Down
6 changes: 4 additions & 2 deletions src/gstools/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
try: # pragma: no cover
import gstools_core

USE_RUST = True
_GSTOOLS_CORE_AVAIL = True
USE_GSTOOLS_CORE = True
except ImportError:
USE_RUST = False
_GSTOOLS_CORE_AVAIL = False
USE_GSTOOLS_CORE = False
26 changes: 19 additions & 7 deletions src/gstools/field/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@
from gstools.covmodel.base import CovModel
from gstools.random.rng import RNG

if config.USE_RUST: # pragma: no cover
if config._GSTOOLS_CORE_AVAIL: # pragma: no cover
# pylint: disable=E0401
from gstools_core import summate, summate_incompr
else:
from gstools.field.summator import summate, summate_incompr
from gstools_core import (
summate as summate_gsc,
summate_incompr as summate_incompr_gsc,
)

from gstools.field.summator import summate as summate_c
from gstools.field.summator import summate_incompr as summate_incompr_c

__all__ = ["Generator", "RandMeth", "IncomprRandMeth"]

Expand Down Expand Up @@ -194,8 +198,8 @@ def __init__(
def __call__(self, pos, add_nugget=True):
"""Calculate the random modes for the randomization method.
This method calls the `summate_*` Cython methods, which are the
heart of the randomization method.
This method calls the `summate_*` Rust or Cython methods, which are
the heart of the randomization method.
Parameters
----------
Expand All @@ -209,6 +213,10 @@ def __call__(self, pos, add_nugget=True):
:class:`numpy.ndarray`
the random modes
"""
if config.USE_GSTOOLS_CORE and config._GSTOOLS_CORE_AVAIL:
summate = summate_gsc
else:
summate = summate_c
pos = np.asarray(pos, dtype=np.double)
summed_modes = summate(
self._cov_sample, self._z_1, self._z_2, pos, config.NUM_THREADS
Expand Down Expand Up @@ -473,7 +481,7 @@ def __init__(
def __call__(self, pos, add_nugget=True):
"""Calculate the random modes for the randomization method.
This method calls the `summate_incompr_*` Cython methods,
This method calls the `summate_incompr_*` Rust or Cython methods,
which are the heart of the randomization method.
In this class the method contains a projector to
ensure the incompressibility of the vector field.
Expand All @@ -490,6 +498,10 @@ def __call__(self, pos, add_nugget=True):
:class:`numpy.ndarray`
the random modes
"""
if config.USE_GSTOOLS_CORE and config._GSTOOLS_CORE_AVAIL:
summate_incompr = summate_incompr_gsc
else:
summate_incompr = summate_incompr_c
pos = np.asarray(pos, dtype=np.double)
summed_modes = summate_incompr(
self._cov_sample,
Expand Down
33 changes: 24 additions & 9 deletions src/gstools/krige/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,18 @@
from gstools.tools.misc import eval_func
from gstools.variogram import vario_estimate

if config.USE_RUST: # pragma: no cover
if config._GSTOOLS_CORE_AVAIL: # pragma: no cover
# pylint: disable=E0401
from gstools_core import calc_field_krige, calc_field_krige_and_variance
else:
from gstools.krige.krigesum import (
calc_field_krige,
calc_field_krige_and_variance,
from gstools_core import (
calc_field_krige as calc_field_krige_gsc,
calc_field_krige_and_variance as calc_field_krige_and_variance_gsc,
)

from gstools.krige.krigesum import calc_field_krige as calc_field_krige_c
from gstools.krige.krigesum import (
calc_field_krige_and_variance as calc_field_krige_and_variance_c,
)

__all__ = ["Krige"]


Expand Down Expand Up @@ -237,6 +240,16 @@ def __call__(
the kriging error variance
(if return_var is True and only_mean is False)
"""
if config.USE_GSTOOLS_CORE and config._GSTOOLS_CORE_AVAIL:
self._calc_field_krige = calc_field_krige_gsc
self._calc_field_krige_and_variance = (
calc_field_krige_and_variance_gsc
)
else:
self._calc_field_krige = calc_field_krige_c
self._calc_field_krige_and_variance = (
calc_field_krige_and_variance_c
)
return_var &= not only_mean # don't return variance when calc. mean
fld_cnt = 2 if return_var else 1
default = self.default_field_names[2] if only_mean else None
Expand Down Expand Up @@ -284,11 +297,13 @@ def __call__(

def _summate(self, field, krige_var, c_slice, k_vec, return_var):
if return_var: # estimate error variance
field[c_slice], krige_var[c_slice] = calc_field_krige_and_variance(
self._krige_mat, k_vec, self._krige_cond
field[c_slice], krige_var[c_slice] = (
self._calc_field_krige_and_variance(
self._krige_mat, k_vec, self._krige_cond
)
)
else: # solely calculate the interpolated field
field[c_slice] = calc_field_krige(
field[c_slice] = self._calc_field_krige(
self._krige_mat, k_vec, self._krige_cond
)

Expand Down
36 changes: 23 additions & 13 deletions src/gstools/variogram/variogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,17 @@
)
from gstools.variogram.binning import standard_bins

if config.USE_RUST: # pragma: no cover
if config._GSTOOLS_CORE_AVAIL: # pragma: no cover
# pylint: disable=E0401
from gstools_core import variogram_directional as directional
from gstools_core import variogram_ma_structured as ma_structured
from gstools_core import variogram_structured as structured
from gstools_core import variogram_unstructured as unstructured
else:
from gstools.variogram.estimator import (
directional,
ma_structured,
structured,
unstructured,
)
from gstools_core import variogram_directional as directional_gsc
from gstools_core import variogram_ma_structured as ma_structured_gsc
from gstools_core import variogram_structured as structured_gsc
from gstools_core import variogram_unstructured as unstructured_gsc

from gstools.variogram.estimator import directional as directional_c
from gstools.variogram.estimator import ma_structured as ma_structured_c
from gstools.variogram.estimator import structured as structured_c
from gstools.variogram.estimator import unstructured as unstructured_c

__all__ = [
"vario_estimate",
Expand Down Expand Up @@ -366,6 +364,12 @@ def vario_estimate(
# select variogram estimator
cython_estimator = _set_estimator(estimator)
# run
if config.USE_GSTOOLS_CORE and config._GSTOOLS_CORE_AVAIL:
unstructured = unstructured_gsc
directional = directional_gsc
else:
unstructured = unstructured_c
directional = directional_c
if dir_no == 0:
# "h"aversine or "e"uclidean distance type
distance_type = "h" if latlon else "e"
Expand Down Expand Up @@ -471,7 +475,7 @@ def vario_estimate_axis(
if missing:
field.mask = np.logical_or(field.mask, missing_mask)
mask = np.ma.getmaskarray(field)
if not config.USE_RUST:
if not config.USE_GSTOOLS_CORE and config._GSTOOLS_CORE_AVAIL:
mask = np.asarray(mask, dtype=np.int32)
else:
field = np.atleast_1d(np.asarray(field, dtype=np.double))
Expand All @@ -487,6 +491,12 @@ def vario_estimate_axis(

cython_estimator = _set_estimator(estimator)

if config.USE_GSTOOLS_CORE and config._GSTOOLS_CORE_AVAIL:
ma_structured = ma_structured_gsc
structured = structured_gsc
else:
ma_structured = ma_structured_c
structured = structured_c
if masked:
return ma_structured(
field, mask, cython_estimator, num_threads=config.NUM_THREADS
Expand Down

0 comments on commit 8fbe2ec

Please sign in to comment.