From 6dc8b60849fab48f24494859c15a42f078025be6 Mon Sep 17 00:00:00 2001 From: Zach Griffith Date: Sun, 26 May 2019 19:20:54 -0500 Subject: [PATCH] Add fill_value for concat and auto_combine (#2964) * add fill_value option for concat and auto_combine * add tests for fill_value in concat and auto_combine * remove errant whitespace * add fill_value description to doc-string * add missing assert --- xarray/core/combine.py | 55 +++++++++++++++++++++++------------- xarray/tests/test_combine.py | 42 +++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 20 deletions(-) diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 1abd14cd20b..6d922064f6f 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -4,7 +4,7 @@ import pandas as pd -from . import utils +from . import utils, dtypes from .alignment import align from .merge import merge from .variable import IndexVariable, Variable, as_variable @@ -14,7 +14,7 @@ def concat(objs, dim=None, data_vars='all', coords='different', compat='equals', positions=None, indexers=None, mode=None, - concat_over=None): + concat_over=None, fill_value=dtypes.NA): """Concatenate xarray objects along a new or existing dimension. Parameters @@ -66,6 +66,8 @@ def concat(objs, dim=None, data_vars='all', coords='different', List of integer arrays which specifies the integer positions to which to assign each dataset along the concatenated dimension. If not supplied, objects are concatenated in the provided order. + fill_value : scalar, optional + Value to use for newly missing values indexers, mode, concat_over : deprecated Returns @@ -117,7 +119,7 @@ def concat(objs, dim=None, data_vars='all', coords='different', else: raise TypeError('can only concatenate xarray Dataset and DataArray ' 'objects, got %s' % type(first_obj)) - return f(objs, dim, data_vars, coords, compat, positions) + return f(objs, dim, data_vars, coords, compat, positions, fill_value) def _calc_concat_dim_coord(dim): @@ -212,7 +214,8 @@ def process_subset_opt(opt, subset): return concat_over, equals -def _dataset_concat(datasets, dim, data_vars, coords, compat, positions): +def _dataset_concat(datasets, dim, data_vars, coords, compat, positions, + fill_value=dtypes.NA): """ Concatenate a sequence of datasets along a new or existing dimension """ @@ -225,7 +228,8 @@ def _dataset_concat(datasets, dim, data_vars, coords, compat, positions): dim, coord = _calc_concat_dim_coord(dim) # Make sure we're working on a copy (we'll be loading variables) datasets = [ds.copy() for ds in datasets] - datasets = align(*datasets, join='outer', copy=False, exclude=[dim]) + datasets = align(*datasets, join='outer', copy=False, exclude=[dim], + fill_value=fill_value) concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords) @@ -317,7 +321,7 @@ def ensure_common_dims(vars): def _dataarray_concat(arrays, dim, data_vars, coords, compat, - positions): + positions, fill_value=dtypes.NA): arrays = list(arrays) if data_vars != 'all': @@ -336,14 +340,15 @@ def _dataarray_concat(arrays, dim, data_vars, coords, compat, datasets.append(arr._to_temp_dataset()) ds = _dataset_concat(datasets, dim, data_vars, coords, compat, - positions) + positions, fill_value) result = arrays[0]._from_temp_dataset(ds, name) result.name = result_name(arrays) return result -def _auto_concat(datasets, dim=None, data_vars='all', coords='different'): +def _auto_concat(datasets, dim=None, data_vars='all', coords='different', + fill_value=dtypes.NA): if len(datasets) == 1 and dim is None: # There is nothing more to combine, so kick out early. return datasets[0] @@ -366,7 +371,8 @@ def _auto_concat(datasets, dim=None, data_vars='all', coords='different'): 'supply the ``concat_dim`` argument ' 'explicitly') dim, = concat_dims - return concat(datasets, dim=dim, data_vars=data_vars, coords=coords) + return concat(datasets, dim=dim, data_vars=data_vars, + coords=coords, fill_value=fill_value) _CONCAT_DIM_DEFAULT = utils.ReprObject('') @@ -442,7 +448,8 @@ def _check_shape_tile_ids(combined_tile_ids): def _combine_nd(combined_ids, concat_dims, data_vars='all', - coords='different', compat='no_conflicts'): + coords='different', compat='no_conflicts', + fill_value=dtypes.NA): """ Concatenates and merges an N-dimensional structure of datasets. @@ -472,13 +479,14 @@ def _combine_nd(combined_ids, concat_dims, data_vars='all', dim=concat_dim, data_vars=data_vars, coords=coords, - compat=compat) + compat=compat, + fill_value=fill_value) combined_ds = list(combined_ids.values())[0] return combined_ds def _auto_combine_all_along_first_dim(combined_ids, dim, data_vars, - coords, compat): + coords, compat, fill_value=dtypes.NA): # Group into lines of datasets which must be combined along dim # need to sort by _new_tile_id first for groupby to work # TODO remove all these sorted OrderedDicts once python >= 3.6 only @@ -490,7 +498,8 @@ def _auto_combine_all_along_first_dim(combined_ids, dim, data_vars, combined_ids = OrderedDict(sorted(group)) datasets = combined_ids.values() new_combined_ids[new_id] = _auto_combine_1d(datasets, dim, compat, - data_vars, coords) + data_vars, coords, + fill_value) return new_combined_ids @@ -500,18 +509,20 @@ def vars_as_keys(ds): def _auto_combine_1d(datasets, concat_dim=_CONCAT_DIM_DEFAULT, compat='no_conflicts', - data_vars='all', coords='different'): + data_vars='all', coords='different', + fill_value=dtypes.NA): # This is just the old auto_combine function (which only worked along 1D) if concat_dim is not None: dim = None if concat_dim is _CONCAT_DIM_DEFAULT else concat_dim sorted_datasets = sorted(datasets, key=vars_as_keys) grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys) concatenated = [_auto_concat(list(ds_group), dim=dim, - data_vars=data_vars, coords=coords) + data_vars=data_vars, coords=coords, + fill_value=fill_value) for id, ds_group in grouped_by_vars] else: concatenated = datasets - merged = merge(concatenated, compat=compat) + merged = merge(concatenated, compat=compat, fill_value=fill_value) return merged @@ -521,7 +532,7 @@ def _new_tile_id(single_id_ds_pair): def _auto_combine(datasets, concat_dims, compat, data_vars, coords, - infer_order_from_coords, ids): + infer_order_from_coords, ids, fill_value=dtypes.NA): """ Calls logic to decide concatenation order before concatenating. """ @@ -550,12 +561,14 @@ def _auto_combine(datasets, concat_dims, compat, data_vars, coords, # Repeatedly concatenate then merge along each dimension combined = _combine_nd(combined_ids, concat_dims, compat=compat, - data_vars=data_vars, coords=coords) + data_vars=data_vars, coords=coords, + fill_value=fill_value) return combined def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT, - compat='no_conflicts', data_vars='all', coords='different'): + compat='no_conflicts', data_vars='all', coords='different', + fill_value=dtypes.NA): """Attempt to auto-magically combine the given datasets into one. This method attempts to combine a list of datasets into a single entity by inspecting metadata and using a combination of concat and merge. @@ -596,6 +609,8 @@ def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT, Details are in the documentation of concat coords : {'minimal', 'different', 'all' or list of str}, optional Details are in the documentation of conca + fill_value : scalar, optional + Value to use for newly missing values Returns ------- @@ -622,4 +637,4 @@ def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT, return _auto_combine(datasets, concat_dims=concat_dims, compat=compat, data_vars=data_vars, coords=coords, infer_order_from_coords=infer_order_from_coords, - ids=False) + ids=False, fill_value=fill_value) diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 1d8ed169d29..a477df0b0d4 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -7,6 +7,7 @@ import pytest from xarray import DataArray, Dataset, Variable, auto_combine, concat +from xarray.core import dtypes from xarray.core.combine import ( _auto_combine, _auto_combine_1d, _auto_combine_all_along_first_dim, _check_shape_tile_ids, _combine_nd, _infer_concat_order_from_positions, @@ -237,6 +238,20 @@ def test_concat_multiindex(self): assert expected.equals(actual) assert isinstance(actual.x.to_index(), pd.MultiIndex) + @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + def test_concat_fill_value(self, fill_value): + datasets = [Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}), + Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})] + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value = np.nan + expected = Dataset({'a': (('t', 'x'), + [[fill_value, 2, 3], [1, 2, fill_value]])}, + {'x': [0, 1, 2]}) + actual = concat(datasets, dim='t', fill_value=fill_value) + assert_identical(actual, expected) + class TestConcatDataArray: def test_concat(self): @@ -306,6 +321,19 @@ def test_concat_lazy(self): assert combined.shape == (2, 3, 3) assert combined.dims == ('z', 'x', 'y') + @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + def test_concat_fill_value(self, fill_value): + foo = DataArray([1, 2], coords=[('x', [1, 2])]) + bar = DataArray([1, 2], coords=[('x', [1, 3])]) + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value = np.nan + expected = DataArray([[1, 2, fill_value], [1, fill_value, 2]], + dims=['y', 'x'], coords={'x': [1, 2, 3]}) + actual = concat((foo, bar), dim='y', fill_value=fill_value) + assert_identical(actual, expected) + class TestAutoCombine: @@ -417,6 +445,20 @@ def test_auto_combine_no_concat(self): {'baz': [100]}) assert_identical(expected, actual) + @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + def test_auto_combine_fill_value(self, fill_value): + datasets = [Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}), + Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})] + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value = np.nan + expected = Dataset({'a': (('t', 'x'), + [[fill_value, 2, 3], [1, 2, fill_value]])}, + {'x': [0, 1, 2]}) + actual = auto_combine(datasets, concat_dim='t', fill_value=fill_value) + assert_identical(expected, actual) + def assert_combined_tile_ids_equal(dict1, dict2): assert len(dict1) == len(dict2)