Skip to content

Commit

Permalink
avoid index dataarrays for simplicity.
Browse files Browse the repository at this point in the history
need a solution to preserve index attrs
  • Loading branch information
dcherian committed May 6, 2020
1 parent a0e699f commit 4d40a25
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@
T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)


def get_index_vars(obj: Union[DataArray, Dataset]) -> dict:
return {dim: obj[dim] for dim in obj.indexes}


def to_object_array(iterable):
npargs = np.empty((len(iterable),), dtype=np.object)
for idx, item in enumerate(iterable):
Expand Down Expand Up @@ -247,7 +243,7 @@ def _wrapper(
raise ValueError(f"Dimensions {missing_dimensions} missing on returned object.")

# check that index lengths and values are as expected
for name, index in get_index_vars(result).items():
for name, index in result.indexes.items():
if name in check_shapes:
if len(index) != check_shapes[name]:
raise ValueError(
Expand Down Expand Up @@ -412,11 +408,11 @@ def map_blocks(

# check that chunk sizes are compatible
input_chunks = dict(npargs[0].chunks)
input_indexes = get_index_vars(npargs[0])
input_indexes = dict(npargs[0].indexes)
for arg in npargs[1:][is_xarray[1:]]:
assert_chunks_compatible(npargs[0], arg)
input_chunks.update(arg.chunks)
input_indexes.update(get_index_vars(arg))
input_indexes.update(arg.indexes)

if template is None:
# infer template by providing zero-shaped arrays
Expand All @@ -425,15 +421,15 @@ def map_blocks(
preserved_indexes = template_indexes & set(input_indexes)
new_indexes = template_indexes - set(input_indexes)
indexes = {dim: input_indexes[dim] for dim in preserved_indexes}
indexes.update({k: template[k] for k in new_indexes})
indexes.update({k: template.indexes[k] for k in new_indexes})
output_chunks = {
dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
}

else:
# template xarray object has been provided with proper sizes and chunk shapes
indexes = input_indexes
indexes.update(get_index_vars(template))
indexes.update(template.indexes)
if isinstance(template, DataArray):
output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore
else:
Expand Down

0 comments on commit 4d40a25

Please sign in to comment.