forked from metoppv/improver
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add apply_mask cli and function * remove print statement * Updates to make sure plugin doesn't require cubelist input * Update Docstring --------- Co-authored-by: Marcus Spelman <[email protected]>
- Loading branch information
Showing
5 changed files
with
226 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#!/usr/bin/env python | ||
# (C) Crown copyright, Met Office. All rights reserved. | ||
# | ||
# This file is part of IMPROVER and is released under a BSD 3-Clause license. | ||
# See LICENSE in the root of the repository for full licensing details. | ||
"""Script to apply provided mask to cube data.""" | ||
|
||
from improver import cli | ||
|
||
|
||
@cli.clizefy | ||
@cli.with_output | ||
def process(*cubes: cli.inputcube, mask_name: str, invert_mask: bool = "False"): | ||
""" | ||
Applies provided mask to cube data. The mask_name is used to extract the mask cube | ||
from the input cubelist. The other cube in the cubelist is then masked using the | ||
mask data. If invert_mask is True, the mask will be inverted before it is applied. | ||
Args: | ||
cubes (iris.cube.CubeList): | ||
A list of iris cubes that should contain exactly two cubes: a mask to be applied | ||
and a cube to apply the mask to. The cubes should have the same dimensions. | ||
mask_name (str): | ||
The name of the cube containing the mask data. This should match with exactly one | ||
of the cubes in the input cubelist. | ||
invert_mask (bool): | ||
Use to select whether the mask should be inverted before being applied to the data. | ||
Returns: | ||
A cube with the mask applied to the data. The metadata will exactly match the input cube. | ||
""" | ||
from improver.utilities.mask import apply_mask | ||
|
||
return apply_mask(*cubes, mask_name=mask_name, invert_mask=invert_mask) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# (C) Crown copyright, Met Office. All rights reserved. | ||
# | ||
# This file is part of IMPROVER and is released under a BSD 3-Clause license. | ||
# See LICENSE in the root of the repository for full licensing details. | ||
"""Module for applying mask to a cube.""" | ||
|
||
from typing import Union | ||
|
||
import iris | ||
import numpy as np | ||
|
||
from improver.utilities.common_input_handle import as_cubelist | ||
from improver.utilities.cube_checker import find_dimension_coordinate_mismatch | ||
from improver.utilities.cube_manipulation import ( | ||
enforce_coordinate_ordering, | ||
get_coord_names, | ||
) | ||
|
||
|
||
def apply_mask( | ||
*cubes: Union[iris.cube.CubeList, iris.cube.Cube], | ||
mask_name: str, | ||
invert_mask: bool = False, | ||
) -> iris.cube.Cube: | ||
""" | ||
Apply a provided mask to a cube. If invert_mask is True, the mask will be inverted. | ||
Args: | ||
cubes: | ||
A list of iris cubes that should contain exactly two cubes: a mask and a cube | ||
to apply the mask to. The cubes should have the same dimensions. | ||
mask_name: | ||
The name of the mask cube. It should match with exactly one of the cubes in | ||
the input cubelist. | ||
invert_mask: | ||
If True, the mask will be inverted before it is applied. | ||
Returns: | ||
A cube with a mask applied to the data. | ||
Raises: | ||
ValueError: If the number of cubes provided is not equal to 2. | ||
ValueError: If the input cube and mask cube have different dimensions. | ||
""" | ||
cubes = as_cubelist(*cubes) | ||
cube_names = [cube.name() for cube in cubes] | ||
if len(cubes) != 2: | ||
raise ValueError( | ||
f"""Two cubes are required for masking, a mask and the cube to be masked. | ||
Provided cubes are {cube_names}""" | ||
) | ||
|
||
mask = cubes.extract_cube(mask_name) | ||
cubes.remove(mask) | ||
cube = cubes[0] | ||
|
||
# Ensure mask is in a boolean form and invert if requested | ||
mask.data = mask.data.astype(np.bool) | ||
if invert_mask: | ||
mask.data = ~mask.data | ||
|
||
coord_list = get_coord_names(cube) | ||
enforce_coordinate_ordering(mask, coord_list) | ||
|
||
# This error is required to stop the mask from being broadcasted to a new shape by numpy. When | ||
# the mask and cube have different shapes numpy will try to broadcast the mask to be the same | ||
# shape as the cube data. This might succeed but masks unexpected data points. | ||
if find_dimension_coordinate_mismatch(cube, mask): | ||
raise ValueError("Input cube and mask cube must have the same dimensions") | ||
|
||
cube.data = np.ma.array(cube.data, mask=mask.data) | ||
return cube |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# (C) Crown copyright, Met Office. All rights reserved. | ||
# | ||
# This file is part of IMPROVER and is released under a BSD 3-Clause license. | ||
# See LICENSE in the root of the repository for full licensing details. | ||
""" | ||
Tests for the apply-mask CLI | ||
""" | ||
|
||
import pytest | ||
|
||
from . import acceptance as acc | ||
|
||
pytestmark = [pytest.mark.acc, acc.skip_if_kgo_missing] | ||
CLI = acc.cli_name_with_dashes(__file__) | ||
run_cli = acc.run_cli(CLI) | ||
|
||
|
||
@pytest.mark.parametrize("invert", [True, False]) | ||
def test_apply_mask(tmp_path, invert): | ||
"""Test apply-mask CLI.""" | ||
kgo_dir = acc.kgo_root() / "apply-mask/" | ||
kgo_path = kgo_dir / "kgo.nc" | ||
if invert: | ||
kgo_path = kgo_dir / "kgo_inverted.nc" | ||
|
||
output_path = tmp_path / "output.nc" | ||
args = [ | ||
kgo_dir / "wind_speed.nc", | ||
kgo_dir / "mask.nc", | ||
"--mask-name", | ||
"land_binary_mask", | ||
"--invert-mask", | ||
f"{invert}", | ||
"--output", | ||
output_path, | ||
] | ||
run_cli(args) | ||
acc.compare(output_path, kgo_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# (C) Crown copyright, Met Office. All rights reserved. | ||
# | ||
# This file is part of IMPROVER and is released under a BSD 3-Clause license. | ||
# See LICENSE in the root of the repository for full licensing details. | ||
""" | ||
Unit tests for the function apply_mask. | ||
""" | ||
|
||
import iris | ||
import numpy as np | ||
import pytest | ||
|
||
from improver.synthetic_data.set_up_test_cubes import set_up_variable_cube | ||
from improver.utilities.cube_manipulation import enforce_coordinate_ordering | ||
from improver.utilities.mask import apply_mask | ||
|
||
|
||
@pytest.fixture | ||
def wind_gust_cube(): | ||
data = np.full((2, 3), 10) | ||
return set_up_variable_cube( | ||
data=data, attributes={"wind_gust_type": "10m_ratio"}, name="wind_gust" | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def mask(): | ||
data = np.array([[0, 0, 1], [1, 1, 0]]) | ||
return set_up_variable_cube(data=data, name="land_sea_mask") | ||
|
||
|
||
@pytest.mark.parametrize("invert_mask", [True, False]) | ||
@pytest.mark.parametrize("switch_coord_order", [True, False]) | ||
def test_basic(wind_gust_cube, mask, switch_coord_order, invert_mask): | ||
""" | ||
Test the basic functionality of the apply_mask plugin. Checks that the | ||
mask is correctly applied and inverted if requested. Also checks plugin | ||
can cope with different orderings of coordinates on the input cubes.""" | ||
|
||
expected_data = np.full((2, 3), 10) | ||
expected_mask = np.array([[False, False, True], [True, True, False]]) | ||
if switch_coord_order: | ||
enforce_coordinate_ordering(wind_gust_cube, ["longitude", "latitude"]) | ||
expected_data = expected_data.transpose() | ||
expected_mask = expected_mask.transpose() | ||
if invert_mask: | ||
expected_mask = np.invert(expected_mask) | ||
|
||
input_list = [wind_gust_cube, mask] | ||
|
||
result = apply_mask( | ||
iris.cube.CubeList(input_list), | ||
mask_name="land_sea_mask", | ||
invert_mask=invert_mask, | ||
) | ||
|
||
assert np.allclose(result.data, expected_data) | ||
assert np.allclose(result.data.mask, expected_mask) | ||
|
||
|
||
def test_different_dimensions(wind_gust_cube, mask): | ||
""" Test that the function will raise an error if the mask cube has different | ||
dimensions to other cube.""" | ||
mask = mask[0] | ||
input_list = [wind_gust_cube, mask] | ||
with pytest.raises( | ||
ValueError, match="Input cube and mask cube must have the same dimensions" | ||
): | ||
apply_mask(iris.cube.CubeList(input_list), mask_name="land_sea_mask") | ||
|
||
|
||
def test_too_many_cubes(wind_gust_cube, mask): | ||
""" | ||
Test that the function will raise an error if more than two cubes are provided. | ||
""" | ||
input_list = [wind_gust_cube, wind_gust_cube, wind_gust_cube] | ||
with pytest.raises(ValueError, match="Two cubes are required for masking"): | ||
apply_mask(iris.cube.CubeList(input_list), mask_name="land_sea_mask") |