-
Notifications
You must be signed in to change notification settings - Fork 416
/
xarray.py
412 lines (351 loc) · 16.6 KB
/
xarray.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
# Copyright (c) 2018 MetPy Developers.
# Distributed under the terms of the BSD 3-Clause License.
# SPDX-License-Identifier: BSD-3-Clause
"""Provide accessors to enhance interoperability between XArray and MetPy."""
from __future__ import absolute_import
import functools
import logging
import re
import warnings
import xarray as xr
from xarray.core.accessors import DatetimeAccessor
from .units import DimensionalityError, units
__all__ = []

# Human-readable axis names mapped to the letters used in the CF ``axis`` attribute.
readable_to_cf_axes = {'time': 'T', 'vertical': 'Z', 'y': 'Y', 'x': 'X'}
# Inverse lookup: CF axis letter -> human-readable axis name.
cf_to_readable_axes = {cf_letter: readable for readable, cf_letter
                       in readable_to_cf_axes.items()}

log = logging.getLogger(__name__)
@xr.register_dataarray_accessor('metpy')
class MetPyAccessor(object):
    """Provide custom attributes and methods on XArray DataArray for MetPy functionality."""

    def __init__(self, data_array):
        """Initialize accessor with a DataArray."""
        self._data_array = data_array
        # A missing ``units`` attribute is treated as dimensionless.
        self._units = self._data_array.attrs.get('units', 'dimensionless')

    @property
    def units(self):
        """Return the stored unit string parsed by MetPy's unit registry."""
        return units(self._units)

    @property
    def unit_array(self):
        """Return data values as a `pint.Quantity`."""
        return self._data_array.values * self.units

    @unit_array.setter
    def unit_array(self, values):
        """Set data values as a `pint.Quantity`.

        The quantity's magnitude replaces the DataArray's values, and the string form of
        its units is written to both the ``units`` attribute and this accessor.
        """
        # Assign the bare magnitude explicitly. Handing the Quantity itself to xarray relies
        # on implicit array conversion, which drops the units (pint may warn about this).
        self._data_array.values = values.magnitude
        self._units = self._data_array.attrs['units'] = str(values.units)

    def convert_units(self, units):
        """Convert the data values to different units in-place.

        Parameters
        ----------
        units
            Target units, in any form accepted by `pint.Quantity.to`.
        """
        self.unit_array = self.unit_array.to(units)

    @property
    def crs(self):
        """Provide easy access to the `crs` coordinate.

        Raises
        ------
        AttributeError
            If the DataArray has no ``crs`` coordinate.
        """
        if 'crs' in self._data_array.coords:
            return self._data_array.coords['crs'].item()
        raise AttributeError('crs attribute is not available.')

    @property
    def cartopy_crs(self):
        """Return the coordinate reference system (CRS) as a cartopy object."""
        return self.crs.to_cartopy()

    @property
    def cartopy_globe(self):
        """Return the globe belonging to the coordinate reference system (CRS)."""
        return self.crs.cartopy_globe

    def _axis(self, axis):
        """Return the coordinate variable corresponding to the given individual axis type.

        ``axis`` must be one of the readable axis names ('time', 'vertical', 'y', 'x').
        The first coordinate whose ``axis`` attribute holds the matching CF letter is
        returned.

        Raises
        ------
        AttributeError
            If ``axis`` is not a readable axis name, or no coordinate is labeled with it.
        """
        if axis not in readable_to_cf_axes:
            # format() instead of '+' so a non-string axis still yields this message.
            raise AttributeError("'{}' is not an interpretable axis.".format(axis))
        cf_axis = readable_to_cf_axes[axis]
        for coord_var in self._data_array.coords.values():
            if coord_var.attrs.get('axis') == cf_axis:
                return coord_var
        raise AttributeError('{} attribute is not available.'.format(axis))

    def coordinates(self, *args):
        """Return the coordinate variables corresponding to the given axes types.

        This is a generator: coordinates are yielded in the order of ``args``, and any
        AttributeError from an unavailable axis is raised only when it is reached.
        """
        for arg in args:
            yield self._axis(arg)

    @property
    def time(self):
        """Return the coordinate labeled as the time axis."""
        return self._axis('time')

    @property
    def vertical(self):
        """Return the coordinate labeled as the vertical axis."""
        return self._axis('vertical')

    @property
    def y(self):
        """Return the coordinate labeled as the y axis."""
        return self._axis('y')

    @property
    def x(self):
        """Return the coordinate labeled as the x axis."""
        return self._axis('x')

    def coordinates_identical(self, other):
        """Return whether or not the coordinates of other match this DataArray's."""
        # If the number of coordinates do not match, we know they can't match.
        if len(self._data_array.coords) != len(other.coords):
            return False

        # If same length, iterate over all of them and check
        for coord_name, coord_var in self._data_array.coords.items():
            if coord_name not in other.coords or not other[coord_name].identical(coord_var):
                return False

        # Otherwise, they match.
        return True

    def as_timestamp(self):
        """Return the data as unix timestamp (for easier time derivatives).

        Values are converted to whole seconds since the unix epoch as 64-bit integers.
        Any ``standard_name``, ``long_name`` and ``axis`` attributes are carried over, and
        ``units`` is set to 'seconds'.
        """
        attrs = {key: self._data_array.attrs[key] for key in
                 {'standard_name', 'long_name', 'axis'} & set(self._data_array.attrs)}
        attrs['units'] = 'seconds'
        # Use an explicit 64-bit integer. The platform 'int' can be 32 bits (e.g. on
        # Windows with NumPy < 2), which overflows for dates after January 2038.
        seconds = self._data_array.values.astype('datetime64[s]').astype('int64')
        return xr.DataArray(seconds,
                            name=self._data_array.name,
                            coords=self._data_array.coords,
                            dims=self._data_array.dims,
                            attrs=attrs)

    def find_axis_name(self, axis):
        """Return the name of the axis corresponding to the given identifier.

        The given identifier can be an axis number (integer), dimension coordinate name
        (string) or a standard axis type (string).

        Raises
        ------
        ValueError
            If ``axis`` is none of the accepted identifier forms.
        """
        if isinstance(axis, int):
            # If an integer, use the corresponding dimension
            return self._data_array.dims[axis]
        elif axis not in self._data_array.dims and axis in readable_to_cf_axes:
            # If not a dimension name itself, but a valid axis type, get the name of the
            # coordinate corresponding to that axis type
            return self._axis(axis).name
        elif axis in self._data_array.dims and axis in self._data_array.coords:
            # If this is a dimension coordinate name, use it directly
            return axis
        else:
            # Otherwise, not valid
            raise ValueError('Given axis is not valid. Must be an axis number, a dimension '
                             'coordinate name, or a standard axis type.')
@xr.register_dataset_accessor('metpy')
class CFConventionHandler(object):
    """Provide custom attributes and methods on XArray Dataset for MetPy functionality."""

    def __init__(self, dataset):
        """Initialize accessor with a Dataset."""
        self._dataset = dataset

    def parse_cf(self, varname=None, coordinates=None):
        """Parse Climate and Forecasting (CF) convention metadata.

        Parameters
        ----------
        varname : str, optional
            Name of the variable to parse. If omitted, every data variable in the dataset
            is parsed and the resulting Dataset is returned.
        coordinates : dict, optional
            Mapping of CF axis letter ('T', 'Z', 'Y', 'X') to a coordinate variable or
            coordinate name. If omitted, the mapping is inferred from coordinate metadata.
            The caller's mapping is not modified.

        Returns
        -------
        The parsed `xarray.DataArray` (or `xarray.Dataset` when ``varname`` is None).
        """
        from .plots.mapping import CFProjection

        # If no varname is given, parse the entire dataset
        if varname is None:
            return self._dataset.apply(lambda da: self.parse_cf(da.name,
                                                                coordinates=coordinates))

        var = self._dataset[varname]
        if 'grid_mapping' in var.attrs:
            proj_name = var.attrs['grid_mapping']
            try:
                proj_var = self._dataset.variables[proj_name]
            except KeyError:
                log.warning(
                    'Could not find variable corresponding to the value of '
                    'grid_mapping: {}'.format(proj_name))
            else:
                var.coords['crs'] = CFProjection(proj_var.attrs)

        self._fixup_coords(var)

        # Trying to guess whether we should be adding a crs to this variable's coordinates
        # First make sure it's missing CRS but isn't lat/lon itself
        if not self.check_axis(var, 'lat', 'lon') and 'crs' not in var.coords:
            # Look for both lat/lon in the coordinates
            has_lat = has_lon = False
            for coord_var in var.coords.values():
                has_lat = has_lat or self.check_axis(coord_var, 'lat')
                has_lon = has_lon or self.check_axis(coord_var, 'lon')

            # If we found them, create a lat/lon projection as the crs coord
            if has_lat and has_lon:
                var.coords['crs'] = CFProjection({'grid_mapping_name': 'latitude_longitude'})

        # Obtain a map of axis types to coordinate variables
        if coordinates is None:
            # Generate the map from the supplied coordinates
            coordinates = self._generate_coordinate_map(var.coords.values())
        else:
            # Work on a copy. The fixup below replaces names with this variable's
            # coordinate objects. Done in place, that would leak into the caller's dict
            # and into every later variable when a whole Dataset is parsed.
            coordinates = dict(coordinates)
            # Verify that coordinates maps to coordinate variables, not coordinate names
            self._fixup_coordinate_map(coordinates, var)

        # Overwrite previous axis attributes, and use the coordinates to label anew
        self._assign_axes(coordinates, var)

        return var

    # Define the criteria for coordinate matches
    criteria = {
        'standard_name': {
            'time': 'time',
            'vertical': {'air_pressure', 'height', 'geopotential_height', 'altitude',
                         'model_level_number', 'atmosphere_ln_pressure_coordinate',
                         'atmosphere_sigma_coordinate',
                         'atmosphere_hybrid_sigma_pressure_coordinate',
                         'atmosphere_hybrid_height_coordinate', 'atmosphere_sleve_coordinate',
                         'height_above_geopotential_datum', 'height_above_reference_ellipsoid',
                         'height_above_mean_sea_level'},
            'y': 'projection_y_coordinate',
            'lat': 'latitude',
            'x': 'projection_x_coordinate',
            'lon': 'longitude'
        },
        '_CoordinateAxisType': {
            'time': 'Time',
            'vertical': {'GeoZ', 'Height', 'Pressure'},
            'y': 'GeoY',
            'lat': 'Lat',
            'x': 'GeoX',
            'lon': 'Lon'
        },
        'axis': readable_to_cf_axes,
        'positive': {
            'vertical': {'up', 'down'}
        },
        'units': {
            'vertical': {
                'match': 'dimensionality',
                'units': 'Pa'
            },
            'lat': {
                'match': 'name',
                'units': {'degree_north', 'degree_N', 'degreeN', 'degrees_north', 'degrees_N',
                          'degreesN'}
            },
            'lon': {
                'match': 'name',
                'units': {'degree_east', 'degree_E', 'degreeE', 'degrees_east', 'degrees_E',
                          'degreesE'}
            },
        },
        'regular_expression': {
            'time': r'time[0-9]*',
            'vertical': (r'(bottom_top|sigma|h(ei)?ght|altitude|depth|isobaric|pres|'
                         r'isotherm)[a-z_]*[0-9]*'),
            'y': r'y',
            'lat': r'x?lat[a-z0-9]*',
            'x': r'x',
            'lon': r'x?lon[a-z0-9]*'
        }
    }

    @classmethod
    def check_axis(cls, var, *axes):
        """Check if var satisfies the criteria for any of the given axes.

        Returns
        -------
        bool
            True if any criterion matches for any of ``axes``, otherwise False.
        """
        for axis in axes:
            # Check for
            #   - standard name (CF option)
            #   - _CoordinateAxisType (from THREDDS)
            #   - axis (CF option)
            #   - positive (CF standard for non-pressure vertical coordinate)
            for criterion in ('standard_name', '_CoordinateAxisType', 'axis', 'positive'):
                allowed = cls.criteria[criterion].get(axis, set())
                # Several criteria are a single string. Wrap them so ``in`` is an exact
                # match rather than a substring test (e.g. '' or 'y' inside
                # 'projection_y_coordinate').
                if isinstance(allowed, str):
                    allowed = {allowed}
                if var.attrs.get(criterion, 'absent') in allowed:
                    return True

            # Check for units, either by dimensionality or name
            unit_criterion = cls.criteria['units'].get(axis)
            if unit_criterion is not None:
                if unit_criterion['match'] == 'dimensionality':
                    if (units.get_dimensionality(var.attrs.get('units'))
                            == units.get_dimensionality(unit_criterion['units'])):
                        return True
                elif unit_criterion['match'] == 'name':
                    if var.attrs.get('units') in unit_criterion['units']:
                        return True

            # Check if name matches regular expression (non-CF failsafe). re.match anchors
            # only at the start of the name. Unnamed arrays are skipped instead of crashing
            # on None.lower().
            if var.name is not None and re.match(cls.criteria['regular_expression'][axis],
                                                 var.name.lower()):
                return True

        return False

    def _fixup_coords(self, var):
        """Clean up the units on the coordinate variables.

        Projection x/y coordinates (that are not also lat/lon) are converted to meters in
        place. Coordinates whose units cannot be converted to meters (e.g. angles in
        radians) are scaled by the CRS ``perspective_point_height``, but only when a
        ``crs`` coordinate exists.
        """
        for coord_name, data_array in var.coords.items():
            if (self.check_axis(data_array, 'x', 'y')
                    and not self.check_axis(data_array, 'lon', 'lat')):
                try:
                    var.coords[coord_name].metpy.convert_units('meters')
                except DimensionalityError:  # Radians!
                    if 'crs' in var.coords:
                        new_data_array = data_array.copy()
                        height = var.coords['crs'].item()['perspective_point_height']
                        scaled_vals = new_data_array.metpy.unit_array * (height * units.meters)
                        new_data_array.metpy.unit_array = scaled_vals.to('meters')
                        var.coords[coord_name] = new_data_array

    def _generate_coordinate_map(self, coords):
        """Generate a coordinate map via CF conventions and other methods.

        Returns
        -------
        dict
            Keys 'T', 'Z', 'Y', 'X'. Each value is the identified coordinate variable, or
            None when no unique coordinate was found for that axis.
        """
        # Readable axis types identifying each CF axis; horizontal axes accept either
        # projection coordinates or lat/lon. Loop-invariant, so built once.
        axes_to_check = {
            'T': ('time',),
            'Z': ('vertical',),
            'Y': ('y', 'lat'),
            'X': ('x', 'lon')
        }

        # Parse all the coordinates, attempting to identify x, y, vertical, time
        coord_lists = {'T': [], 'Z': [], 'Y': [], 'X': []}
        for coord_var in coords:
            for axis_cf, axes_readable in axes_to_check.items():
                if self.check_axis(coord_var, *axes_readable):
                    coord_lists[axis_cf].append(coord_var)

        # Resolve any coordinate conflicts
        axis_conflicts = [axis for axis in coord_lists if len(coord_lists[axis]) > 1]
        for axis in axis_conflicts:
            self._resolve_axis_conflict(axis, coord_lists)

        # Collapse the coord_lists to a coord_map
        return {axis: (coord_lists[axis][0] if coord_lists[axis] else None)
                for axis in coord_lists}

    @staticmethod
    def _fixup_coordinate_map(coord_map, var):
        """Ensure we have coordinate variables in map, not coordinate names.

        ``coord_map`` is modified in place. Entries that are None are left alone, since
        ``_assign_axes`` skips them.
        """
        for axis in coord_map:
            if coord_map[axis] is not None and not isinstance(coord_map[axis], xr.DataArray):
                coord_map[axis] = var[coord_map[axis]]

    @staticmethod
    def _assign_axes(coord_map, var):
        """Assign axis attribute to coordinates in var according to coord_map.

        Every existing ``axis`` attribute on var's coordinates is removed first. Each
        non-None entry of ``coord_map`` then gets its key as the new ``axis``.
        """
        for coord_var in var.coords.values():
            if 'axis' in coord_var.attrs:
                del coord_var.attrs['axis']

        for axis in coord_map:
            if coord_map[axis] is not None:
                coord_map[axis].attrs['axis'] = axis

    def _resolve_axis_conflict(self, axis, coord_lists):
        """Handle axis conflicts if they arise.

        ``coord_lists[axis]`` is narrowed in place to a single coordinate when one can be
        chosen unambiguously. Otherwise it is emptied and a warning is issued.
        """
        if axis in ('Y', 'X'):
            # Horizontal coordinate, can be projection x/y or lon/lat. So, check for
            # existence of unique projection x/y (preferred over lon/lat) and use that if
            # it exists uniquely
            projection_coords = [coord_var for coord_var in coord_lists[axis] if
                                 self.check_axis(coord_var, 'x', 'y')]
            if len(projection_coords) == 1:
                coord_lists[axis] = projection_coords
                return

        # If one and only one of the possible axes is a dimension, use it
        dimension_coords = [coord_var for coord_var in coord_lists[axis] if
                            coord_var.name in coord_var.dims]
        if len(dimension_coords) == 1:
            coord_lists[axis] = dimension_coords
            return

        # Ambiguous axis, raise warning and do not parse
        warnings.warn('DataArray of requested variable has more than one '
                      + cf_to_readable_axes[axis]
                      + ' coordinate. Specify the unique axes using the coordinates argument.')
        coord_lists[axis] = []
def preprocess_xarray(func):
    """Decorate a function to convert all DataArray arguments to pint.Quantities.

    Positional and keyword arguments are both inspected. Each `xarray.DataArray` is
    replaced by its ``.metpy.unit_array``, and every other argument is forwarded
    unchanged.
    """
    def _as_quantity(value):
        # Only DataArrays are converted; anything else passes straight through.
        if isinstance(value, xr.DataArray):
            return value.metpy.unit_array
        return value

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        quantity_args = tuple(map(_as_quantity, args))
        quantity_kwargs = {key: _as_quantity(val) for key, val in kwargs.items()}
        return func(*quantity_args, **quantity_kwargs)
    return wrapper
def check_matching_coordinates(func):
    """Decorate a function to make sure all given DataArrays have matching coordinates.

    Raises
    ------
    ValueError
        At call time, if any DataArray argument (positional or keyword) has coordinates
        that differ from those of the first DataArray argument.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        candidates = list(args) + list(kwargs.values())
        data_arrays = [item for item in candidates if isinstance(item, xr.DataArray)]
        # Every DataArray after the first is checked against the first one.
        for other in data_arrays[1:]:
            if not data_arrays[0].metpy.coordinates_identical(other):
                raise ValueError('Input DataArray arguments must be on same coordinates.')
        return func(*args, **kwargs)
    return wrapper
# Older xarray releases lack DatetimeAccessor.strftime; install a pandas-backed fallback.
if not hasattr(DatetimeAccessor, 'strftime'):
    def strftime(self, date_format):
        """Format time as a string.

        Formatting is delegated to pandas ``Series.dt.strftime`` on the flattened data,
        and the result is reshaped back to the shape of the underlying array.
        """
        import pandas as pd
        raw = self._obj.data
        return pd.Series(raw.ravel()).dt.strftime(date_format).values.reshape(raw.shape)

    DatetimeAccessor.strftime = strftime