Skip to content

Commit

Permalink
Merge pull request #2681 from pnuu/resampler-warnings
Browse files Browse the repository at this point in the history
Get rid of warnings in resampler tests
  • Loading branch information
pnuu authored Dec 14, 2023
2 parents a78b36f + 1bb5655 commit 4076e99
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 119 deletions.
53 changes: 7 additions & 46 deletions satpy/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"bucket_sum", "Sum Bucket Resampling", :class:`~satpy.resample.BucketSum`
"bucket_count", "Count Bucket Resampling", :class:`~satpy.resample.BucketCount`
"bucket_fraction", "Fraction Bucket Resampling", :class:`~satpy.resample.BucketFraction`
"gradient_search", "Gradient Search Resampling", :class:`~pyresample.gradient.GradientSearchResampler`
"gradient_search", "Gradient Search Resampling", :meth:`~pyresample.gradient.create_gradient_search_resampler`
The resampling algorithm used can be specified with the ``resampler`` keyword
argument and defaults to ``nearest``:
Expand Down Expand Up @@ -148,13 +148,11 @@

import dask.array as da
import numpy as np
import pyresample
import xarray as xr
import zarr
from packaging import version
from pyresample.ewa import DaskEWAResampler, LegacyDaskEWAResampler
from pyresample.geometry import SwathDefinition
from pyresample.gradient import GradientSearchResampler
from pyresample.gradient import create_gradient_search_resampler
from pyresample.resampler import BaseResampler as PRBaseResampler

from satpy._config import config_search_paths, get_config_path
Expand All @@ -177,8 +175,6 @@

resamplers_cache: "WeakValueDictionary[tuple, object]" = WeakValueDictionary()

PR_USE_SKIPNA = version.parse(pyresample.__version__) > version.parse("1.17.0")


def hash_dict(the_dict, the_hash=None):
"""Calculate a hash for a dictionary."""
Expand Down Expand Up @@ -773,33 +769,6 @@ def _get_replicated_chunk_sizes(d_arr, repeats):
return tuple(repeated_chunks)


def _get_arg_to_pass_for_skipna_handling(**kwargs):
"""Determine if skipna can be passed to the compute functions for the average and sum bucket resampler."""
# FIXME this can be removed once Pyresample 1.18.0 is a Satpy requirement

if PR_USE_SKIPNA:
if "mask_all_nan" in kwargs:
warnings.warn(
"Argument mask_all_nan is deprecated. Please use skipna for missing values handling. "
"Continuing with default skipna=True, if not provided differently.",
DeprecationWarning,
stacklevel=3
)
kwargs.pop("mask_all_nan")
else:
if "mask_all_nan" in kwargs:
warnings.warn(
"Argument mask_all_nan is deprecated."
"Please update Pyresample and use skipna for missing values handling.",
DeprecationWarning,
stacklevel=3
)
kwargs.setdefault("mask_all_nan", False)
kwargs.pop("skipna")

return kwargs


class BucketResamplerBase(PRBaseResampler):
"""Base class for bucket resampling which implements averaging."""

Expand Down Expand Up @@ -832,11 +801,6 @@ def resample(self, data, **kwargs): # noqa: D417
Returns (xarray.DataArray): Data resampled to the target area
"""
if not PR_USE_SKIPNA and "skipna" in kwargs:
raise ValueError("You are trying to set the skipna argument but you are using an old version of"
" Pyresample that does not support it."
"Please update Pyresample to 1.18.0 or higher to be able to use this argument.")

self.precompute(**kwargs)
attrs = data.attrs.copy()
data_arr = data.data
Expand Down Expand Up @@ -910,17 +874,16 @@ def compute(self, data, fill_value=np.nan, skipna=True, **kwargs): # noqa: D417
Returns:
dask.Array
"""
kwargs = _get_arg_to_pass_for_skipna_handling(skipna=skipna, **kwargs)

results = []
if data.ndim == 3:
for i in range(data.shape[0]):
res = self.resampler.get_average(data[i, :, :],
fill_value=fill_value,
skipna=skipna,
**kwargs)
results.append(res)
else:
res = self.resampler.get_average(data, fill_value=fill_value,
res = self.resampler.get_average(data, fill_value=fill_value, skipna=skipna,
**kwargs)
results.append(res)

Expand Down Expand Up @@ -948,16 +911,14 @@ class BucketSum(BucketResamplerBase):

def compute(self, data, skipna=True, **kwargs):
"""Call the resampling."""
kwargs = _get_arg_to_pass_for_skipna_handling(skipna=skipna, **kwargs)

results = []
if data.ndim == 3:
for i in range(data.shape[0]):
res = self.resampler.get_sum(data[i, :, :],
res = self.resampler.get_sum(data[i, :, :], skipna=skipna,
**kwargs)
results.append(res)
else:
res = self.resampler.get_sum(data, **kwargs)
res = self.resampler.get_sum(data, skipna=skipna, **kwargs)
results.append(res)

return da.stack(results)
Expand Down Expand Up @@ -1009,7 +970,7 @@ def compute(self, data, fill_value=np.nan, categories=None, **kwargs):
"nearest": KDTreeResampler,
"bilinear": BilinearResampler,
"native": NativeResampler,
"gradient_search": GradientSearchResampler,
"gradient_search": create_gradient_search_resampler,
"bucket_avg": BucketAvg,
"bucket_sum": BucketSum,
"bucket_count": BucketCount,
Expand Down
81 changes: 8 additions & 73 deletions satpy/tests/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def get_test_data(input_shape=(100, 50), output_shape=(200, 100), output_proj=No
"""
import dask.array as da
from pyresample.geometry import AreaDefinition, SwathDefinition
from pyresample.utils import proj4_str_to_dict
from xarray import DataArray
ds1 = DataArray(da.zeros(input_shape, chunks=85),
dims=input_dims,
Expand All @@ -62,16 +61,16 @@ def get_test_data(input_shape=(100, 50), output_shape=(200, 100), output_proj=No

input_proj_str = ("+proj=geos +lon_0=-95.0 +h=35786023.0 +a=6378137.0 "
"+b=6356752.31414 +sweep=x +units=m +no_defs")
crs = CRS(input_proj_str)
source = AreaDefinition(
"test_target",
"test_target",
"test_target",
proj4_str_to_dict(input_proj_str),
crs,
input_shape[1], # width
input_shape[0], # height
(-1000., -1500., 1000., 1500.))
ds1.attrs["area"] = source
crs = CRS.from_string(input_proj_str)
ds1 = ds1.assign_coords(crs=crs)

ds2 = ds1.copy()
Expand All @@ -95,7 +94,7 @@ def get_test_data(input_shape=(100, 50), output_shape=(200, 100), output_proj=No
"test_target",
"test_target",
"test_target",
proj4_str_to_dict(output_proj_str),
CRS(output_proj_str),
output_shape[1], # width
output_shape[0], # height
(-1000., -1500., 1000., 1500.),
Expand Down Expand Up @@ -248,8 +247,12 @@ def test_expand_reduce_agg_rechunk(self):
into that chunk size.
"""
from satpy.utils import PerformanceWarning

d_arr = da.zeros((6, 20), chunks=3)
new_data = NativeResampler._expand_reduce(d_arr, {0: 0.5, 1: 0.5})
text = "Array chunk size is not divisible by aggregation factor. Re-chunking to continue native resampling."
with pytest.warns(PerformanceWarning, match=text):
new_data = NativeResampler._expand_reduce(d_arr, {0: 0.5, 1: 0.5})
assert new_data.shape == (3, 10)

def test_expand_reduce_numpy(self):
Expand Down Expand Up @@ -582,17 +585,10 @@ def test_compute(self):
res = self._compute_mocked_bucket_avg(data, return_data=data[0, :, :], fill_value=2)
assert res.shape == (3, 5, 5)

@mock.patch("satpy.resample.PR_USE_SKIPNA", True)
def test_compute_and_use_skipna_handling(self):
"""Test bucket resampler computation and use skipna handling."""
data = da.ones((5,))

self._compute_mocked_bucket_avg(data, fill_value=2, mask_all_nan=True)
self.bucket.resampler.get_average.assert_called_once_with(
data,
fill_value=2,
skipna=True)

self._compute_mocked_bucket_avg(data, fill_value=2, skipna=False)
self.bucket.resampler.get_average.assert_called_once_with(
data,
Expand All @@ -605,35 +601,6 @@ def test_compute_and_use_skipna_handling(self):
fill_value=2,
skipna=True)

@mock.patch("satpy.resample.PR_USE_SKIPNA", False)
def test_compute_and_not_use_skipna_handling(self):
"""Test bucket resampler computation and not use skipna handling."""
data = da.ones((5,))

self._compute_mocked_bucket_avg(data, fill_value=2, mask_all_nan=True)
self.bucket.resampler.get_average.assert_called_once_with(
data,
fill_value=2,
mask_all_nan=True)

self._compute_mocked_bucket_avg(data, fill_value=2, mask_all_nan=False)
self.bucket.resampler.get_average.assert_called_once_with(
data,
fill_value=2,
mask_all_nan=False)

self._compute_mocked_bucket_avg(data, fill_value=2)
self.bucket.resampler.get_average.assert_called_once_with(
data,
fill_value=2,
mask_all_nan=False)

self._compute_mocked_bucket_avg(data, fill_value=2, skipna=True)
self.bucket.resampler.get_average.assert_called_once_with(
data,
fill_value=2,
mask_all_nan=False)

@mock.patch("pyresample.bucket.BucketResampler")
def test_resample(self, pyresample_bucket):
"""Test bucket resamplers resample method."""
Expand Down Expand Up @@ -713,16 +680,10 @@ def test_compute(self):
res = self._compute_mocked_bucket_sum(data, return_data=data[0, :, :])
assert res.shape == (3, 5, 5)

@mock.patch("satpy.resample.PR_USE_SKIPNA", True)
def test_compute_and_use_skipna_handling(self):
"""Test bucket resampler computation and use skipna handling."""
data = da.ones((5,))

self._compute_mocked_bucket_sum(data, mask_all_nan=True)
self.bucket.resampler.get_sum.assert_called_once_with(
data,
skipna=True)

self._compute_mocked_bucket_sum(data, skipna=False)
self.bucket.resampler.get_sum.assert_called_once_with(
data,
Expand All @@ -733,32 +694,6 @@ def test_compute_and_use_skipna_handling(self):
data,
skipna=True)

@mock.patch("satpy.resample.PR_USE_SKIPNA", False)
def test_compute_and_not_use_skipna_handling(self):
"""Test bucket resampler computation and not use skipna handling."""
data = da.ones((5,))

self._compute_mocked_bucket_sum(data, mask_all_nan=True)
self.bucket.resampler.get_sum.assert_called_once_with(
data,
mask_all_nan=True)

self._compute_mocked_bucket_sum(data, mask_all_nan=False)
self.bucket.resampler.get_sum.assert_called_once_with(
data,
mask_all_nan=False)

self._compute_mocked_bucket_sum(data)
self.bucket.resampler.get_sum.assert_called_once_with(
data,
mask_all_nan=False)

self._compute_mocked_bucket_sum(data, fill_value=2, skipna=True)
self.bucket.resampler.get_sum.assert_called_once_with(
data,
fill_value=2,
mask_all_nan=False)


class TestBucketCount(unittest.TestCase):
"""Test the count bucket resampler."""
Expand Down

0 comments on commit 4076e99

Please sign in to comment.