Skip to content

Commit

Permalink
Ensure return type of wrapped lazy function; add test for new function.
Browse files Browse the repository at this point in the history
  • Loading branch information
pp-mo committed Mar 6, 2018
1 parent 6fd0f2b commit e0097b9
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 8 deletions.
24 changes: 16 additions & 8 deletions lib/iris/_lazy_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,20 @@ def wrap_lazy_elementwise(lazy_array, elementwise_op):
* elementwise_op:
The elementwise (numpy) array operation to apply.
.. note:
A single-point "dummy" call is made to the operation function, to
determine dtype of the result.
This return dtype should be stable in actual operation (!)
"""
# This is just an Iris wrapper for the Dask operation.
# For now, we support only a single argument array, and assume that the
# output dtype is the same as the input. This scope is intentionally
# minimal : we can extend it later as needed.

# Note: pass dtype, to prevent Dask making a test call to work it out.
return da.map_blocks(elementwise_op,
lazy_array, dtype=lazy_array.dtype)
# This is just a wrapper to provide an Iris-specific abstraction for a
# lazy operation in Dask (map_blocks).

# Explicitly determine the return type with a dummy call.
# This makes good practical sense for unit conversions, as a Unit.convert
# call may cast to float, or not, depending on unit equality : Thus, it's
# much safer to get cf_units to decide that for us.
dtype = elementwise_op(np.zeros(1, lazy_array.dtype)).dtype

return da.map_blocks(elementwise_op, lazy_array, dtype=dtype)
68 changes: 68 additions & 0 deletions lib/iris/tests/unit/lazy_data/test_wrap_lazy_elementwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# (C) British Crown Copyright 2018, Met Office
#
# This file is part of Iris.
#
# Iris is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the
# Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Iris is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Iris. If not, see <http://www.gnu.org/licenses/>.
"""Test function :func:`iris._lazy data.wrap_lazy_elementwise`."""

from __future__ import (absolute_import, division, print_function)
from six.moves import (filter, input, map, range, zip) # noqa

# Import iris.tests first so that some things can be initialised before
# importing anything else.
import iris.tests as tests

import numpy as np

from iris._lazy_data import as_lazy_data, is_lazy_data

from iris._lazy_data import wrap_lazy_elementwise


def _test_elementwise_op(array):
# Promotes the type of a bool argument, but not a float.
return array + 1


class Test_is_lazy_data(tests.IrisTest):
def test_basic(self):
concrete_array = np.arange(30).reshape((2, 5, 3))
lazy_array = as_lazy_data(concrete_array)
wrapped = wrap_lazy_elementwise(lazy_array,
_test_elementwise_op)
self.assertTrue(is_lazy_data(wrapped))
self.assertArrayAllClose(wrapped.compute(),
_test_elementwise_op(concrete_array))

def test_dtype_same(self):
concrete_array = np.array([3.], dtype=np.float16)
lazy_array = as_lazy_data(concrete_array)
wrapped = wrap_lazy_elementwise(lazy_array,
_test_elementwise_op)
self.assertTrue(is_lazy_data(wrapped))
self.assertEqual(wrapped.dtype, np.float16)
self.assertEqual(wrapped.compute().dtype, np.float16)

def test_dtype_change(self):
concrete_array = np.array([True, False])
lazy_array = as_lazy_data(concrete_array)
wrapped = wrap_lazy_elementwise(lazy_array,
_test_elementwise_op)
self.assertTrue(is_lazy_data(wrapped))
self.assertEqual(wrapped.dtype, np.int)
self.assertEqual(wrapped.compute().dtype, wrapped.dtype)


if __name__ == '__main__':
tests.main()

0 comments on commit e0097b9

Please sign in to comment.