diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 5e9aa06f507..d1c79953a9b 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,4 +1,3 @@ - [ ] Closes #xxxx (remove if there is no corresponding issue, which should only be the case for minor changes) - [ ] Tests added (for all bug fixes or enhancements) - - [ ] Tests passed (for all non-documentation changes) - [ ] Fully documented, including `whats-new.rst` for all changes and `api.rst` for new API (remove if this change should not be visible to users, e.g., if it is an internal clean-up, or if this is part of a larger project that will be documented later) diff --git a/.pep8speaks.yml b/.pep8speaks.yml new file mode 100644 index 00000000000..aedce6e44eb --- /dev/null +++ b/.pep8speaks.yml @@ -0,0 +1,12 @@ +# File : .pep8speaks.yml + +scanner: + diff_only: True # If True, errors caused by only the patch are shown + +pycodestyle: + max-line-length: 79 + ignore: # Errors and warnings to ignore + - E402, # module level import not at top of file + - E731, # do not assign a lambda expression, use a def + - W503 # line break before binary operator + - W504 # line break after binary operator diff --git a/.stickler.yml b/.stickler.yml deleted file mode 100644 index 79d8b7fb717..00000000000 --- a/.stickler.yml +++ /dev/null @@ -1,11 +0,0 @@ -linters: - flake8: - max-line-length: 79 - fixer: false - ignore: I002 - # stickler doesn't support 'exclude' for flake8 properly, so we disable it - # below with files.ignore: - # https://github.com/markstory/lint-review/issues/184 -files: - ignore: - - doc/**/*.py diff --git a/.travis.yml b/.travis.yml index 6df70e92954..defb37ec8aa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,5 @@ # Based on http://conda.pydata.org/docs/travis.html -language: python +language: minimal sudo: false # use container based build notifications: email: false @@ -10,76 +10,48 @@ branches: matrix: fast_finish: true include: - - python: 2.7 - env: CONDA_ENV=py27-min - - python: 2.7 - env: CONDA_ENV=py27-cdat+iris+pynio - - python: 3.5 - env: CONDA_ENV=py35 - - python: 3.6 - env: CONDA_ENV=py36 - - python: 3.6 - env: + - env: CONDA_ENV=py27-min + - env: CONDA_ENV=py27-cdat+iris+pynio + - env: CONDA_ENV=py35 + - env: CONDA_ENV=py36 + - env: CONDA_ENV=py37 + - env: - CONDA_ENV=py36 - EXTRA_FLAGS="--run-flaky --run-network-tests" - - python: 3.6 - env: CONDA_ENV=py36-netcdf4-dev + - env: CONDA_ENV=py36-netcdf4-dev addons: apt_packages: - libhdf5-serial-dev - netcdf-bin - libnetcdf-dev - - python: 3.6 - env: CONDA_ENV=py36-dask-dev - - python: 3.6 - env: CONDA_ENV=py36-pandas-dev - - python: 3.6 - env: CONDA_ENV=py36-bottleneck-dev - - python: 3.6 - env: CONDA_ENV=py36-condaforge-rc - - python: 3.6 - env: CONDA_ENV=py36-pynio-dev - - python: 3.6 - env: CONDA_ENV=py36-rasterio1.0alpha - - python: 3.6 - env: CONDA_ENV=py36-zarr-dev - - python: 3.5 - env: CONDA_ENV=docs - - python: 3.6 - env: CONDA_ENV=py36-hypothesis + - env: CONDA_ENV=py36-dask-dev + - env: CONDA_ENV=py36-pandas-dev + - env: CONDA_ENV=py36-bottleneck-dev + - env: CONDA_ENV=py36-condaforge-rc + - env: CONDA_ENV=py36-pynio-dev + - env: CONDA_ENV=py36-rasterio-0.36 + - env: CONDA_ENV=py36-zarr-dev + - env: CONDA_ENV=docs + - env: CONDA_ENV=py36-hypothesis + allow_failures: - - python: 3.6 - env: + - env: - CONDA_ENV=py36 - EXTRA_FLAGS="--run-flaky --run-network-tests" - - python: 3.6 - env: CONDA_ENV=py36-netcdf4-dev + - env: CONDA_ENV=py36-netcdf4-dev addons: apt_packages: - libhdf5-serial-dev 
- netcdf-bin - libnetcdf-dev - - python: 3.6 - env: CONDA_ENV=py36-dask-dev - - python: 3.6 - env: CONDA_ENV=py36-pandas-dev - - python: 3.6 - env: CONDA_ENV=py36-bottleneck-dev - - python: 3.6 - env: CONDA_ENV=py36-condaforge-rc - - python: 3.6 - env: CONDA_ENV=py36-pynio-dev - - python: 3.6 - env: CONDA_ENV=py36-rasterio1.0alpha - - python: 3.6 - env: CONDA_ENV=py36-zarr-dev + - env: CONDA_ENV=py36-pandas-dev + - env: CONDA_ENV=py36-bottleneck-dev + - env: CONDA_ENV=py36-condaforge-rc + - env: CONDA_ENV=py36-pynio-dev + - env: CONDA_ENV=py36-zarr-dev before_install: - - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then - wget http://repo.continuum.io/miniconda/Miniconda-3.16.0-Linux-x86_64.sh -O miniconda.sh; - else - wget http://repo.continuum.io/miniconda/Miniconda3-3.16.0-Linux-x86_64.sh -O miniconda.sh; - fi + - wget http://repo.continuum.io/miniconda/Miniconda3-3.16.0-Linux-x86_64.sh -O miniconda.sh; - bash miniconda.sh -b -p $HOME/miniconda - export PATH="$HOME/miniconda/bin:$PATH" - hash -r @@ -99,9 +71,9 @@ install: - python xarray/util/print_versions.py script: - # TODO: restore this check once the upstream pandas issue is fixed: - # https://github.com/pandas-dev/pandas/issues/21071 - # - python -OO -c "import xarray" + - which python + - python --version + - python -OO -c "import xarray" - if [[ "$CONDA_ENV" == "docs" ]]; then conda install -c conda-forge sphinx sphinx_rtd_theme sphinx-gallery numpydoc; sphinx-build -n -j auto -b html -d _build/doctrees doc _build/html; diff --git a/HOW_TO_RELEASE b/HOW_TO_RELEASE index cdfcace809a..80f37e672a5 100644 --- a/HOW_TO_RELEASE +++ b/HOW_TO_RELEASE @@ -14,6 +14,7 @@ Time required: about an hour. 5. Tag the release: git tag -a v0.X.Y -m 'v0.X.Y' 6. Build source and binary wheels for pypi: + git clean -xdf # this deletes all uncommited changes! python setup.py bdist_wheel sdist 7. Use twine to register and upload the release on pypi. Be careful, you can't take this back! @@ -37,16 +38,12 @@ Time required: about an hour. git push upstream master You're done pushing to master! 12. Issue the release on GitHub. Click on "Draft a new release" at - https://github.com/pydata/xarray/releases and paste in the latest from - whats-new.rst. + https://github.com/pydata/xarray/releases. Type in the version number, but + don't bother to describe it -- we maintain that on the docs instead. 13. Update the docs. Login to https://readthedocs.org/projects/xray/versions/ and switch your new release tag (at the bottom) from "Inactive" to "Active". It should now build automatically. -14. Update conda-forge. Clone https://github.com/conda-forge/xarray-feedstock - and update the version number and sha256 in meta.yaml. (On OS X, you can - calculate sha256 with `shasum -a 256 xarray-0.X.Y.tar.gz`). Submit a pull - request (and merge it, once CI passes). -15. Issue the release announcement! For bug fix releases, I usually only email +14. Issue the release announcement! For bug fix releases, I usually only email xarray@googlegroups.com. For major/feature releases, I will email a broader list (no more than once every 3-6 months): pydata@googlegroups.com, xarray@googlegroups.com, diff --git a/README.rst b/README.rst index 94beea1dba4..0ac71d33954 100644 --- a/README.rst +++ b/README.rst @@ -15,6 +15,8 @@ xarray: N-D labeled arrays and datasets :target: https://zenodo.org/badge/latestdoi/13221727 .. image:: http://img.shields.io/badge/benchmarked%20by-asv-green.svg?style=flat :target: http://pandas.pydata.org/speed/xarray/ +.. 
image:: https://img.shields.io/badge/powered%20by-NumFOCUS-orange.svg?style=flat&colorA=E1523D&colorB=007D8A + :target: http://numfocus.org **xarray** (formerly **xray**) is an open source project and Python package that aims to bring the labeled data power of pandas_ to the physical sciences, by providing @@ -103,20 +105,36 @@ Get in touch .. _mailing list: https://groups.google.com/forum/#!forum/xarray .. _on GitHub: http://github.com/pydata/xarray +NumFOCUS +-------- + +.. image:: https://numfocus.org/wp-content/uploads/2017/07/NumFocus_LRG.png + :scale: 25 % + :target: https://numfocus.org/ + +Xarray is a fiscally sponsored project of NumFOCUS_, a nonprofit dedicated +to supporting the open source scientific computing community. If you like +Xarray and want to support our mission, please consider making a donation_ +to support our efforts. + +.. _donation: https://www.flipcause.com/secure/cause_pdetails/NDE2NTU= + History ------- xarray is an evolution of an internal tool developed at `The Climate Corporation`__. It was originally written by Climate Corp researchers Stephan Hoyer, Alex Kleeman and Eugene Brevdo and was released as open source in -May 2014. The project was renamed from "xray" in January 2016. +May 2014. The project was renamed from "xray" in January 2016. Xarray became a +fiscally sponsored project of NumFOCUS_ in August 2018. __ http://climate.com/ +.. _NumFOCUS: https://numfocus.org License ------- -Copyright 2014-2017, xarray Developers +Copyright 2014-2018, xarray Developers Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index b5953436387..e3933b400e6 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -64,6 +64,7 @@ "scipy": [""], "bottleneck": ["", null], "dask": [""], + "distributed": [""], }, diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index 54ed9ac9fa2..3e070e1355b 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -1,11 +1,13 @@ from __future__ import absolute_import, division, print_function +import os + import numpy as np import pandas as pd import xarray as xr -from . import randn, randint, requires_dask +from . 
import randint, randn, requires_dask try: import dask @@ -14,6 +16,9 @@ pass +os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' + + class IOSingleNetCDF(object): """ A few examples that benchmark reading/writing a single netCDF file with @@ -163,7 +168,7 @@ def time_load_dataset_netcdf4_with_block_chunks_vindexing(self): ds = ds.isel(**self.vinds).load() def time_load_dataset_netcdf4_with_block_chunks_multiprocessing(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_dataset(self.filepath, engine='netcdf4', chunks=self.block_chunks).load() @@ -172,7 +177,7 @@ def time_load_dataset_netcdf4_with_time_chunks(self): chunks=self.time_chunks).load() def time_load_dataset_netcdf4_with_time_chunks_multiprocessing(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_dataset(self.filepath, engine='netcdf4', chunks=self.time_chunks).load() @@ -189,7 +194,7 @@ def setup(self): self.ds.to_netcdf(self.filepath, format=self.format) def time_load_dataset_scipy_with_block_chunks(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_dataset(self.filepath, engine='scipy', chunks=self.block_chunks).load() @@ -204,7 +209,7 @@ def time_load_dataset_scipy_with_block_chunks_vindexing(self): ds = ds.isel(**self.vinds).load() def time_load_dataset_scipy_with_time_chunks(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_dataset(self.filepath, engine='scipy', chunks=self.time_chunks).load() @@ -344,7 +349,7 @@ def time_load_dataset_netcdf4_with_block_chunks(self): chunks=self.block_chunks).load() def time_load_dataset_netcdf4_with_block_chunks_multiprocessing(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='netcdf4', chunks=self.block_chunks).load() @@ -353,7 +358,7 @@ def time_load_dataset_netcdf4_with_time_chunks(self): chunks=self.time_chunks).load() def time_load_dataset_netcdf4_with_time_chunks_multiprocessing(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='netcdf4', chunks=self.time_chunks).load() @@ -362,7 +367,7 @@ def time_open_dataset_netcdf4_with_block_chunks(self): chunks=self.block_chunks) def time_open_dataset_netcdf4_with_block_chunks_multiprocessing(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='netcdf4', chunks=self.block_chunks) @@ -371,7 +376,7 @@ def time_open_dataset_netcdf4_with_time_chunks(self): chunks=self.time_chunks) def time_open_dataset_netcdf4_with_time_chunks_multiprocessing(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='netcdf4', chunks=self.time_chunks) @@ -387,21 +392,57 @@ def setup(self): format=self.format) def time_load_dataset_scipy_with_block_chunks(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='scipy', chunks=self.block_chunks).load() def time_load_dataset_scipy_with_time_chunks(self): - with dask.set_options(get=dask.multiprocessing.get): + 
with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='scipy', chunks=self.time_chunks).load() def time_open_dataset_scipy_with_block_chunks(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='scipy', chunks=self.block_chunks) def time_open_dataset_scipy_with_time_chunks(self): - with dask.set_options(get=dask.multiprocessing.get): + with dask.config.set(scheduler="multiprocessing"): xr.open_mfdataset(self.filenames_list, engine='scipy', chunks=self.time_chunks) + + +def create_delayed_write(): + import dask.array as da + vals = da.random.random(300, chunks=(1,)) + ds = xr.Dataset({'vals': (['a'], vals)}) + return ds.to_netcdf('file.nc', engine='netcdf4', compute=False) + + +class IOWriteNetCDFDask(object): + timeout = 60 + repeat = 1 + number = 5 + + def setup(self): + requires_dask() + self.write = create_delayed_write() + + def time_write(self): + self.write.compute() + + +class IOWriteNetCDFDaskDistributed(object): + def setup(self): + try: + import distributed + except ImportError: + raise NotImplementedError + self.client = distributed.Client() + self.write = create_delayed_write() + + def cleanup(self): + self.client.shutdown() + + def time_write(self): + self.write.compute() diff --git a/asv_bench/benchmarks/unstacking.py b/asv_bench/benchmarks/unstacking.py new file mode 100644 index 00000000000..54436b422e9 --- /dev/null +++ b/asv_bench/benchmarks/unstacking.py @@ -0,0 +1,26 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + +import xarray as xr + +from . import requires_dask + + +class Unstacking(object): + def setup(self): + data = np.random.RandomState(0).randn(1, 1000, 500) + self.ds = xr.DataArray(data).stack(flat_dim=['dim_1', 'dim_2']) + + def time_unstack_fast(self): + self.ds.unstack('flat_dim') + + def time_unstack_slow(self): + self.ds[:, ::-1].unstack('flat_dim') + + +class UnstackingDask(Unstacking): + def setup(self, *args, **kwargs): + requires_dask() + super(UnstackingDask, self).setup(**kwargs) + self.ds = self.ds.chunk({'flat_dim': 50}) diff --git a/ci/requirements-py36-dask-dev.yml b/ci/requirements-py36-dask-dev.yml index 54cdb54e8fc..e580aaf3889 100644 --- a/ci/requirements-py36-dask-dev.yml +++ b/ci/requirements-py36-dask-dev.yml @@ -12,9 +12,13 @@ dependencies: - flake8 - numpy - pandas - - seaborn - scipy + - seaborn - toolz + - rasterio + - bottleneck + - zarr + - pseudonetcdf>=3.0.1 - pip: - coveralls - pytest-cov diff --git a/ci/requirements-py36-rasterio1.0alpha.yml b/ci/requirements-py36-rasterio-0.36.yml similarity index 86% rename from ci/requirements-py36-rasterio1.0alpha.yml rename to ci/requirements-py36-rasterio-0.36.yml index 15ba13e753b..5c724e1b981 100644 --- a/ci/requirements-py36-rasterio1.0alpha.yml +++ b/ci/requirements-py36-rasterio-0.36.yml @@ -1,7 +1,6 @@ name: test_env channels: - conda-forge - - conda-forge/label/dev dependencies: - python=3.6 - cftime @@ -17,7 +16,7 @@ dependencies: - scipy - seaborn - toolz - - rasterio>=1.* + - rasterio=0.36.0 - bottleneck - pip: - coveralls diff --git a/ci/requirements-py36.yml b/ci/requirements-py36.yml index fd63fe26130..321f3087ea2 100644 --- a/ci/requirements-py36.yml +++ b/ci/requirements-py36.yml @@ -21,8 +21,10 @@ dependencies: - bottleneck - zarr - pseudonetcdf>=3.0.1 + - eccodes - pip: - coveralls - pytest-cov - pydap - lxml + - cfgrib>=0.9.2 diff --git a/ci/requirements-py37.yml 
b/ci/requirements-py37.yml new file mode 100644 index 00000000000..6292c4c5eb6 --- /dev/null +++ b/ci/requirements-py37.yml @@ -0,0 +1,30 @@ +name: test_env +channels: + - conda-forge +dependencies: + - python=3.7 + - cftime + - dask + - distributed + - h5py + - h5netcdf + - matplotlib + - netcdf4 + - pytest + - flake8 + - numpy + - pandas + - scipy + - seaborn + - toolz + - rasterio + - bottleneck + - zarr + - pseudonetcdf>=3.0.1 + - eccodes + - pip: + - coveralls + - pytest-cov + - pydap + - lxml + - cfgrib>=0.9.2 \ No newline at end of file diff --git a/doc/_static/dataset-diagram-square-logo.png b/doc/_static/dataset-diagram-square-logo.png new file mode 100644 index 00000000000..d1eeda092c4 Binary files /dev/null and b/doc/_static/dataset-diagram-square-logo.png differ diff --git a/doc/_static/dataset-diagram-square-logo.tex b/doc/_static/dataset-diagram-square-logo.tex new file mode 100644 index 00000000000..0a784770b50 --- /dev/null +++ b/doc/_static/dataset-diagram-square-logo.tex @@ -0,0 +1,277 @@ +\documentclass[class=minimal,border=0pt,convert={size=600,outext=.png}]{standalone} +% \documentclass[class=minimal,border=0pt]{standalone} +\usepackage[scaled]{helvet} +\renewcommand*\familydefault{\sfdefault} + +% =========================================================================== +% The code below (used to define the \tikzcuboid command) is copied, +% unmodified, from a tex.stackexchange.com answer by the user "Tom Bombadil": +% http://tex.stackexchange.com/a/29882/8335 +% +% It is licensed under the Creative Commons Attribution-ShareAlike 3.0 +% Unported license: http://creativecommons.org/licenses/by-sa/3.0/ +% =========================================================================== + +\usepackage[usenames,dvipsnames]{color} +\usepackage{tikz} +\usepackage{keyval} +\usepackage{ifthen} + +%==================================== +%emphasize vertices --> switch and emph style (e.g. 
thick,black) +%==================================== +\makeatletter +% Standard Values for Parameters +\newcommand{\tikzcuboid@shiftx}{0} +\newcommand{\tikzcuboid@shifty}{0} +\newcommand{\tikzcuboid@dimx}{3} +\newcommand{\tikzcuboid@dimy}{3} +\newcommand{\tikzcuboid@dimz}{3} +\newcommand{\tikzcuboid@scale}{1} +\newcommand{\tikzcuboid@densityx}{1} +\newcommand{\tikzcuboid@densityy}{1} +\newcommand{\tikzcuboid@densityz}{1} +\newcommand{\tikzcuboid@rotation}{0} +\newcommand{\tikzcuboid@anglex}{0} +\newcommand{\tikzcuboid@angley}{90} +\newcommand{\tikzcuboid@anglez}{225} +\newcommand{\tikzcuboid@scalex}{1} +\newcommand{\tikzcuboid@scaley}{1} +\newcommand{\tikzcuboid@scalez}{sqrt(0.5)} +\newcommand{\tikzcuboid@linefront}{black} +\newcommand{\tikzcuboid@linetop}{black} +\newcommand{\tikzcuboid@lineright}{black} +\newcommand{\tikzcuboid@fillfront}{white} +\newcommand{\tikzcuboid@filltop}{white} +\newcommand{\tikzcuboid@fillright}{white} +\newcommand{\tikzcuboid@shaded}{N} +\newcommand{\tikzcuboid@shadecolor}{black} +\newcommand{\tikzcuboid@shadeperc}{25} +\newcommand{\tikzcuboid@emphedge}{N} +\newcommand{\tikzcuboid@emphstyle}{thick} + +% Definition of Keys +\define@key{tikzcuboid}{shiftx}[\tikzcuboid@shiftx]{\renewcommand{\tikzcuboid@shiftx}{#1}} +\define@key{tikzcuboid}{shifty}[\tikzcuboid@shifty]{\renewcommand{\tikzcuboid@shifty}{#1}} +\define@key{tikzcuboid}{dimx}[\tikzcuboid@dimx]{\renewcommand{\tikzcuboid@dimx}{#1}} +\define@key{tikzcuboid}{dimy}[\tikzcuboid@dimy]{\renewcommand{\tikzcuboid@dimy}{#1}} +\define@key{tikzcuboid}{dimz}[\tikzcuboid@dimz]{\renewcommand{\tikzcuboid@dimz}{#1}} +\define@key{tikzcuboid}{scale}[\tikzcuboid@scale]{\renewcommand{\tikzcuboid@scale}{#1}} +\define@key{tikzcuboid}{densityx}[\tikzcuboid@densityx]{\renewcommand{\tikzcuboid@densityx}{#1}} +\define@key{tikzcuboid}{densityy}[\tikzcuboid@densityy]{\renewcommand{\tikzcuboid@densityy}{#1}} +\define@key{tikzcuboid}{densityz}[\tikzcuboid@densityz]{\renewcommand{\tikzcuboid@densityz}{#1}} +\define@key{tikzcuboid}{rotation}[\tikzcuboid@rotation]{\renewcommand{\tikzcuboid@rotation}{#1}} +\define@key{tikzcuboid}{anglex}[\tikzcuboid@anglex]{\renewcommand{\tikzcuboid@anglex}{#1}} +\define@key{tikzcuboid}{angley}[\tikzcuboid@angley]{\renewcommand{\tikzcuboid@angley}{#1}} +\define@key{tikzcuboid}{anglez}[\tikzcuboid@anglez]{\renewcommand{\tikzcuboid@anglez}{#1}} +\define@key{tikzcuboid}{scalex}[\tikzcuboid@scalex]{\renewcommand{\tikzcuboid@scalex}{#1}} +\define@key{tikzcuboid}{scaley}[\tikzcuboid@scaley]{\renewcommand{\tikzcuboid@scaley}{#1}} +\define@key{tikzcuboid}{scalez}[\tikzcuboid@scalez]{\renewcommand{\tikzcuboid@scalez}{#1}} +\define@key{tikzcuboid}{linefront}[\tikzcuboid@linefront]{\renewcommand{\tikzcuboid@linefront}{#1}} +\define@key{tikzcuboid}{linetop}[\tikzcuboid@linetop]{\renewcommand{\tikzcuboid@linetop}{#1}} +\define@key{tikzcuboid}{lineright}[\tikzcuboid@lineright]{\renewcommand{\tikzcuboid@lineright}{#1}} +\define@key{tikzcuboid}{fillfront}[\tikzcuboid@fillfront]{\renewcommand{\tikzcuboid@fillfront}{#1}} +\define@key{tikzcuboid}{filltop}[\tikzcuboid@filltop]{\renewcommand{\tikzcuboid@filltop}{#1}} +\define@key{tikzcuboid}{fillright}[\tikzcuboid@fillright]{\renewcommand{\tikzcuboid@fillright}{#1}} +\define@key{tikzcuboid}{shaded}[\tikzcuboid@shaded]{\renewcommand{\tikzcuboid@shaded}{#1}} +\define@key{tikzcuboid}{shadecolor}[\tikzcuboid@shadecolor]{\renewcommand{\tikzcuboid@shadecolor}{#1}} +\define@key{tikzcuboid}{shadeperc}[\tikzcuboid@shadeperc]{\renewcommand{\tikzcuboid@shadeperc}{#1}} 
+\define@key{tikzcuboid}{emphedge}[\tikzcuboid@emphedge]{\renewcommand{\tikzcuboid@emphedge}{#1}} +\define@key{tikzcuboid}{emphstyle}[\tikzcuboid@emphstyle]{\renewcommand{\tikzcuboid@emphstyle}{#1}} +% Commands +\newcommand{\tikzcuboid}[1]{ + \setkeys{tikzcuboid}{#1} % Process Keys passed to command + \pgfmathsetmacro{\vectorxx}{\tikzcuboid@scalex*cos(\tikzcuboid@anglex)} + \pgfmathsetmacro{\vectorxy}{\tikzcuboid@scalex*sin(\tikzcuboid@anglex)} + \pgfmathsetmacro{\vectoryx}{\tikzcuboid@scaley*cos(\tikzcuboid@angley)} + \pgfmathsetmacro{\vectoryy}{\tikzcuboid@scaley*sin(\tikzcuboid@angley)} + \pgfmathsetmacro{\vectorzx}{\tikzcuboid@scalez*cos(\tikzcuboid@anglez)} + \pgfmathsetmacro{\vectorzy}{\tikzcuboid@scalez*sin(\tikzcuboid@anglez)} + \begin{scope}[xshift=\tikzcuboid@shiftx, yshift=\tikzcuboid@shifty, scale=\tikzcuboid@scale, rotate=\tikzcuboid@rotation, x={(\vectorxx,\vectorxy)}, y={(\vectoryx,\vectoryy)}, z={(\vectorzx,\vectorzy)}] + \pgfmathsetmacro{\steppingx}{1/\tikzcuboid@densityx} + \pgfmathsetmacro{\steppingy}{1/\tikzcuboid@densityy} + \pgfmathsetmacro{\steppingz}{1/\tikzcuboid@densityz} + \newcommand{\dimx}{\tikzcuboid@dimx} + \newcommand{\dimy}{\tikzcuboid@dimy} + \newcommand{\dimz}{\tikzcuboid@dimz} + \pgfmathsetmacro{\secondx}{2*\steppingx} + \pgfmathsetmacro{\secondy}{2*\steppingy} + \pgfmathsetmacro{\secondz}{2*\steppingz} + \foreach \x in {\steppingx,\secondx,...,\dimx} + { \foreach \y in {\steppingy,\secondy,...,\dimy} + { \pgfmathsetmacro{\lowx}{(\x-\steppingx)} + \pgfmathsetmacro{\lowy}{(\y-\steppingy)} + \filldraw[fill=\tikzcuboid@fillfront,draw=\tikzcuboid@linefront] (\lowx,\lowy,\dimz) -- (\lowx,\y,\dimz) -- (\x,\y,\dimz) -- (\x,\lowy,\dimz) -- cycle; + + } + } + \foreach \x in {\steppingx,\secondx,...,\dimx} + { \foreach \z in {\steppingz,\secondz,...,\dimz} + { \pgfmathsetmacro{\lowx}{(\x-\steppingx)} + \pgfmathsetmacro{\lowz}{(\z-\steppingz)} + \filldraw[fill=\tikzcuboid@filltop,draw=\tikzcuboid@linetop] (\lowx,\dimy,\lowz) -- (\lowx,\dimy,\z) -- (\x,\dimy,\z) -- (\x,\dimy,\lowz) -- cycle; + } + } + \foreach \y in {\steppingy,\secondy,...,\dimy} + { \foreach \z in {\steppingz,\secondz,...,\dimz} + { \pgfmathsetmacro{\lowy}{(\y-\steppingy)} + \pgfmathsetmacro{\lowz}{(\z-\steppingz)} + \filldraw[fill=\tikzcuboid@fillright,draw=\tikzcuboid@lineright] (\dimx,\lowy,\lowz) -- (\dimx,\lowy,\z) -- (\dimx,\y,\z) -- (\dimx,\y,\lowz) -- cycle; + } + } + \ifthenelse{\equal{\tikzcuboid@emphedge}{Y}}% + {\draw[\tikzcuboid@emphstyle](0,\dimy,0) -- (\dimx,\dimy,0) -- (\dimx,\dimy,\dimz) -- (0,\dimy,\dimz) -- cycle;% + \draw[\tikzcuboid@emphstyle] (0,0,\dimz) -- (0,\dimy,\dimz) -- (\dimx,\dimy,\dimz) -- (\dimx,0,\dimz) -- cycle;% + \draw[\tikzcuboid@emphstyle](\dimx,0,0) -- (\dimx,\dimy,0) -- (\dimx,\dimy,\dimz) -- (\dimx,0,\dimz) -- cycle;% + }% + {} + \end{scope} +} + +\makeatother + +\begin{document} + +\begin{tikzpicture} + \tikzcuboid{% + shiftx=21cm,% + shifty=8cm,% + scale=1.00,% + rotation=0,% + densityx=2,% + densityy=2,% + densityz=2,% + dimx=4,% + dimy=3,% + dimz=3,% + linefront=purple!75!black,% + linetop=purple!50!black,% + lineright=purple!25!black,% + fillfront=purple!25!white,% + filltop=purple!50!white,% + fillright=purple!75!white,% + emphedge=Y,% + emphstyle=ultra thick, + } + \tikzcuboid{% + shiftx=21cm,% + shifty=11.6cm,% + scale=1.00,% + rotation=0,% + densityx=2,% + densityy=2,% + densityz=2,% + dimx=4,% + dimy=3,% + dimz=3,% + linefront=teal!75!black,% + linetop=teal!50!black,% + lineright=teal!25!black,% + fillfront=teal!25!white,% + 
filltop=teal!50!white,% + fillright=teal!75!white,% + emphedge=Y,% + emphstyle=ultra thick, + } + \tikzcuboid{% + shiftx=26.8cm,% + shifty=8cm,% + scale=1.00,% + rotation=0,% + densityx=10000,% + densityy=2,% + densityz=2,% + dimx=0,% + dimy=3,% + dimz=3,% + linefront=orange!75!black,% + linetop=orange!50!black,% + lineright=orange!25!black,% + fillfront=orange!25!white,% + filltop=orange!50!white,% + fillright=orange!100!white,% + emphedge=Y,% + emphstyle=ultra thick, + } + \tikzcuboid{% + shiftx=28.6cm,% + shifty=8cm,% + scale=1.00,% + rotation=0,% + densityx=10000,% + densityy=2,% + densityz=2,% + dimx=0,% + dimy=3,% + dimz=3,% + linefront=purple!75!black,% + linetop=purple!50!black,% + lineright=purple!25!black,% + fillfront=purple!25!white,% + filltop=purple!50!white,% + fillright=red!75!white,% + emphedge=Y,% + emphstyle=ultra thick, + } + % \tikzcuboid{% + % shiftx=27.1cm,% + % shifty=10.1cm,% + % scale=1.00,% + % rotation=0,% + % densityx=100,% + % densityy=2,% + % densityz=100,% + % dimx=0,% + % dimy=3,% + % dimz=0,% + % emphedge=Y,% + % emphstyle=ultra thick, + % } + % \tikzcuboid{% + % shiftx=27.1cm,% + % shifty=10.1cm,% + % scale=1.00,% + % rotation=180,% + % densityx=100,% + % densityy=100,% + % densityz=2,% + % dimx=0,% + % dimy=0,% + % dimz=3,% + % emphedge=Y,% + % emphstyle=ultra thick, + % } + \tikzcuboid{% + shiftx=26.8cm,% + shifty=11.4cm,% + scale=1.00,% + rotation=0,% + densityx=100,% + densityy=2,% + densityz=100,% + dimx=0,% + dimy=3,% + dimz=0,% + emphedge=Y,% + emphstyle=ultra thick, + } + \tikzcuboid{% + shiftx=25.3cm,% + shifty=12.9cm,% + scale=1.00,% + rotation=180,% + densityx=100,% + densityy=100,% + densityz=2,% + dimx=0,% + dimy=0,% + dimz=3,% + emphedge=Y,% + emphstyle=ultra thick, + } + % \fill (27.1,10.1) circle[radius=2pt]; + \node [font=\fontsize{130}{100}\fontfamily{phv}\selectfont, anchor=east, text width=2cm, align=right, color=white!50!black] at (19.8,4.4) {\textbf{\emph{x}}}; + \node [font=\fontsize{130}{100}\fontfamily{phv}\selectfont, anchor=west, text width=10cm, align=left] at (20.3,4) {{array}}; +\end{tikzpicture} + +\end{document} diff --git a/doc/_static/numfocus_logo.png b/doc/_static/numfocus_logo.png new file mode 100644 index 00000000000..af3c84209e0 Binary files /dev/null and b/doc/_static/numfocus_logo.png differ diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 1826cc86892..4b2fed8be37 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -39,7 +39,6 @@ Dataset.imag Dataset.round Dataset.real - Dataset.T Dataset.cumsum Dataset.cumprod Dataset.rank @@ -151,3 +150,6 @@ plot.FacetGrid.set_titles plot.FacetGrid.set_ticks plot.FacetGrid.map + + CFTimeIndex.shift + CFTimeIndex.to_datetimeindex diff --git a/doc/api.rst b/doc/api.rst index 927c0aa072c..662ef567710 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -150,6 +150,7 @@ Computation Dataset.resample Dataset.diff Dataset.quantile + Dataset.differentiate **Aggregation**: :py:attr:`~Dataset.all` @@ -317,6 +318,7 @@ Computation DataArray.diff DataArray.dot DataArray.quantile + DataArray.differentiate **Aggregation**: :py:attr:`~DataArray.all` @@ -555,6 +557,13 @@ Custom Indexes CFTimeIndex +Creating custom indexes +----------------------- +.. 
autosummary:: + :toctree: generated/ + + cftime_range + Plotting ======== @@ -615,3 +624,6 @@ arguments for the ``from_store`` and ``dump_to_store`` Dataset methods: backends.H5NetCDFStore backends.PydapDataStore backends.ScipyDataStore + backends.FileManager + backends.CachingFileManager + backends.DummyFileManager diff --git a/doc/computation.rst b/doc/computation.rst index 6793e667e06..759c87a6cc7 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -200,6 +200,31 @@ You can also use ``construct`` to compute a weighted rolling sum: To avoid this, use ``skipna=False`` as the above example. +Computation using Coordinates +============================= + +Xarray objects have some handy methods for the computation with their +coordinates. :py:meth:`~xarray.DataArray.differentiate` computes derivatives by +central finite differences using their coordinates, + +.. ipython:: python + + a = xr.DataArray([0, 1, 2, 3], dims=['x'], coords=[[0.1, 0.11, 0.2, 0.3]]) + a + a.differentiate('x') + +This method can be used also for multidimensional arrays, + +.. ipython:: python + + a = xr.DataArray(np.arange(8).reshape(4, 2), dims=['x', 'y'], + coords={'x': [0.1, 0.11, 0.2, 0.3]}) + a.differentiate('x') + +.. note:: + This method is limited to simple cartesian geometry. Differentiation along + multidimensional coordinate is not supported. + .. _compute.broadcasting: Broadcasting by dimension name diff --git a/doc/conf.py b/doc/conf.py index 5fd3bece3bd..897c0443054 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -25,7 +25,8 @@ print("python exec:", sys.executable) print("sys.path:", sys.path) for name in ('numpy scipy pandas matplotlib dask IPython seaborn ' - 'cartopy netCDF4 rasterio zarr').split(): + 'cartopy netCDF4 rasterio zarr iris flake8 ' + 'sphinx_gallery cftime').split(): try: module = importlib.import_module(name) if name == 'matplotlib': diff --git a/doc/dask.rst b/doc/dask.rst index 2d4beea4f70..672450065cb 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -13,7 +13,7 @@ dependency in a future version of xarray. For a full example of how to use xarray's dask integration, read the `blog post introducing xarray and dask`_. -.. _blog post introducing xarray and dask: https://www.anaconda.com/blog/developer-blog/xray-dask-out-core-labeled-arrays-python/ +.. _blog post introducing xarray and dask: http://stephanhoyer.com/2015/06/11/xray-dask-out-of-core-labeled-arrays/ What is a dask array? --------------------- diff --git a/doc/data-structures.rst b/doc/data-structures.rst index 10d83ca448f..618ccccff3e 100644 --- a/doc/data-structures.rst +++ b/doc/data-structures.rst @@ -408,13 +408,6 @@ operations keep around coordinates: list(ds[['x']]) list(ds.drop('temperature')) -If a dimension name is given as an argument to ``drop``, it also drops all -variables that use that dimension: - -.. ipython:: python - - list(ds.drop('time')) - As an alternate to dictionary-like modifications, you can use :py:meth:`~xarray.Dataset.assign` and :py:meth:`~xarray.Dataset.assign_coords`. 
These methods return a new dataset with additional (or replaced) or values: diff --git a/doc/environment.yml b/doc/environment.yml index a7683ff1824..bd134a7656f 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -1,24 +1,23 @@ name: xarray-docs channels: - conda-forge - - defaults dependencies: - python=3.6 - - numpy=1.13 - - pandas=0.21.0 - - scipy=1.0 - - bottleneck - - numpydoc=0.7.0 - - matplotlib=2.1.2 - - seaborn=0.8 - - dask=0.16.0 - - ipython=6.2.1 - - sphinx=1.5 - - netCDF4=1.3.1 - - cartopy=0.15.1 - - rasterio=0.36.0 - - sphinx-gallery - - zarr - - iris - - flake8 - - cftime + - numpy=1.14.5 + - pandas=0.23.3 + - scipy=1.1.0 + - matplotlib=2.2.2 + - seaborn=0.9.0 + - dask=0.18.2 + - ipython=6.4.0 + - netCDF4=1.4.0 + - cartopy=0.16.0 + - rasterio=1.0.1 + - zarr=2.2.0 + - iris=2.1.0 + - flake8=3.5.0 + - cftime=1.0.0 + - bottleneck=1.2 + - sphinx=1.7.6 + - numpydoc=0.8.0 + - sphinx-gallery=0.2.0 diff --git a/doc/faq.rst b/doc/faq.rst index 170a1e17bdc..44bc021024b 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -119,7 +119,8 @@ conventions`_. (An exception is serialization to and from netCDF files.) An implication of this choice is that we do not propagate ``attrs`` through most operations unless explicitly flagged (some methods have a ``keep_attrs`` -option). Similarly, xarray does not check for conflicts between ``attrs`` when +option, and there is a global flag for setting this to be always True or +False). Similarly, xarray does not check for conflicts between ``attrs`` when combining arrays and datasets, unless explicitly requested with the option ``compat='identical'``. The guiding principle is that metadata should not be allowed to get in the way. @@ -160,70 +161,10 @@ methods for converting back and forth between xarray and these libraries. See :py:meth:`~xarray.DataArray.to_iris` and :py:meth:`~xarray.DataArray.to_cdms2` for more details. -.. _faq.other_projects: - What other projects leverage xarray? ------------------------------------ -Here are several existing libraries that build functionality upon xarray. - -Geosciences -~~~~~~~~~~~ - -- `aospy `_: Automated analysis and management of gridded climate data. -- `infinite-diff `_: xarray-based finite-differencing, focused on gridded climate/meterology data -- `marc_analysis `_: Analysis package for CESM/MARC experiments and output. -- `MPAS-Analysis `_: Analysis for simulations produced with Model for Prediction Across Scales (MPAS) components and the Accelerated Climate Model for Energy (ACME). -- `OGGM `_: Open Global Glacier Model -- `Oocgcm `_: Analysis of large gridded geophysical datasets -- `Open Data Cube `_: Analysis toolkit of continental scale Earth Observation data from satellites. -- `Pangaea: `_: xarray extension for gridded land surface & weather model output). -- `Pangeo `_: A community effort for big data geoscience in the cloud. -- `PyGDX `_: Python 3 package for - accessing data stored in GAMS Data eXchange (GDX) files. Also uses a custom - subclass. -- `Regionmask `_: plotting and creation of masks of spatial regions -- `salem `_: Adds geolocalised subsetting, masking, and plotting operations to xarray's data structures via accessors. -- `Spyfit `_: FTIR spectroscopy of the atmosphere -- `windspharm `_: Spherical - harmonic wind analysis in Python. -- `wrf-python `_: A collection of diagnostic and interpolation routines for use with output of the Weather Research and Forecasting (WRF-ARW) Model. -- `xarray-simlab `_: xarray extension for computer model simulations. 
-- `xarray-topo `_: xarray extension for topographic analysis and modelling. -- `xbpch `_: xarray interface for bpch files. -- `xESMF `_: Universal Regridder for Geospatial Data. -- `xgcm `_: Extends the xarray data model to understand finite volume grid cells (common in General Circulation Models) and provides interpolation and difference operations for such grids. -- `xmitgcm `_: a python package for reading `MITgcm `_ binary MDS files into xarray data structures. -- `xshape `_: Tools for working with shapefiles, topographies, and polygons in xarray. - -Machine Learning -~~~~~~~~~~~~~~~~ -- `cesium `_: machine learning for time series analysis -- `Elm `_: Parallel machine learning on xarray data structures -- `sklearn-xarray (1) `_: Combines scikit-learn and xarray (1). -- `sklearn-xarray (2) `_: Combines scikit-learn and xarray (2). - -Extend xarray capabilities -~~~~~~~~~~~~~~~~~~~~~~~~~~ -- `Collocate `_: Collocate xarray trajectories in arbitrary physical dimensions -- `eofs `_: EOF analysis in Python. -- `xarray_extras `_: Advanced algorithms for xarray objects (e.g. intergrations/interpolations). -- `xrft `_: Fourier transforms for xarray data. -- `xr-scipy `_: A lightweight scipy wrapper for xarray. -- `X-regression `_: Multiple linear regression from Statsmodels library coupled with Xarray library. -- `xyzpy `_: Easily generate high dimensional data, including parallelization. - -Visualization -~~~~~~~~~~~~~ -- `Datashader `_, `geoviews `_, `holoviews `_, : visualization packages for large data -- `psyplot `_: Interactive data visualization with python. - -Other -~~~~~ -- `ptsa `_: EEG Time Series Analysis -- `pycalphad `_: Computational Thermodynamics in Python - -More projects can be found at the `"xarray" Github topic `_. +See section :ref:`related-projects`. How should I cite xarray? ------------------------- diff --git a/doc/gallery/README.txt b/doc/gallery/README.txt index 242c4f7dc91..b17f803696b 100644 --- a/doc/gallery/README.txt +++ b/doc/gallery/README.txt @@ -1,5 +1,5 @@ .. _recipes: -Recipes +Gallery ======= diff --git a/doc/groupby.rst b/doc/groupby.rst index 4851cbe5dcc..6e42dbbc9f0 100644 --- a/doc/groupby.rst +++ b/doc/groupby.rst @@ -207,3 +207,12 @@ may be desirable: .. ipython:: python da.groupby_bins('lon', [0,45,50]).sum() + +These methods group by `lon` values. It is also possible to groupby each +cell in a grid, regardless of value, by stacking multiple dimensions, +applying your function, and then unstacking the result: + +.. ipython:: python + + stacked = da.stack(gridcell=['ny', 'nx']) + stacked.groupby('gridcell').sum().unstack('gridcell') diff --git a/doc/index.rst b/doc/index.rst index 7528f3cb1fa..45897f4bccb 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -74,7 +74,9 @@ Documentation * :doc:`whats-new` * :doc:`api` * :doc:`internals` +* :doc:`roadmap` * :doc:`contributing` +* :doc:`related-projects` .. toctree:: :maxdepth: 1 @@ -84,7 +86,9 @@ Documentation whats-new api internals + roadmap contributing + related-projects See also -------- @@ -116,12 +120,20 @@ Get in touch .. _mailing list: https://groups.google.com/forum/#!forum/xarray .. _on GitHub: http://github.com/pydata/xarray -License -------- +NumFOCUS +-------- -xarray is available under the open source `Apache License`__. +.. image:: _static/numfocus_logo.png + :scale: 50 % + :target: https://numfocus.org/ + +Xarray is a fiscally sponsored project of NumFOCUS_, a nonprofit dedicated +to supporting the open source scientific computing community. 
If you like +Xarray and want to support our mission, please consider making a donation_ +to support our efforts. + +.. _donation: https://www.flipcause.com/secure/cause_pdetails/NDE2NTU= -__ http://www.apache.org/licenses/LICENSE-2.0.html History ------- @@ -129,6 +141,15 @@ History xarray is an evolution of an internal tool developed at `The Climate Corporation`__. It was originally written by Climate Corp researchers Stephan Hoyer, Alex Kleeman and Eugene Brevdo and was released as open source in -May 2014. The project was renamed from "xray" in January 2016. +May 2014. The project was renamed from "xray" in January 2016. Xarray became a +fiscally sponsored project of NumFOCUS_ in August 2018. __ http://climate.com/ +.. _NumFOCUS: https://numfocus.org + +License +------- + +xarray is available under the open source `Apache License`__. + +__ http://www.apache.org/licenses/LICENSE-2.0.html diff --git a/doc/indexing.rst b/doc/indexing.rst index c05bf9994fc..3878d983cf6 100644 --- a/doc/indexing.rst +++ b/doc/indexing.rst @@ -411,7 +411,7 @@ can use indexing with ``.loc`` : .. ipython:: python - ds = xr.tutorial.load_dataset('air_temperature') + ds = xr.tutorial.open_dataset('air_temperature') #add an empty 2D dataarray ds['empty']= xr.full_like(ds.air.mean('time'),fill_value=0) diff --git a/doc/installing.rst b/doc/installing.rst index b3154c3d8bb..64751eea637 100644 --- a/doc/installing.rst +++ b/doc/installing.rst @@ -6,7 +6,7 @@ Installation Required dependencies --------------------- -- Python 2.7 [1]_, 3.5, or 3.6 +- Python 2.7 [1]_, 3.5, 3.6, or 3.7 - `numpy `__ (1.12 or later) - `pandas `__ (0.19.2 or later) @@ -31,6 +31,12 @@ For netCDF and IO - `PseudoNetCDF `__: recommended for accessing CAMx, GEOS-Chem (bpch), NOAA ARL files, ICARTT files (ffi1001) and many other. +- `rasterio `__: for reading GeoTiffs and + other gridded raster datasets. +- `iris `__: for conversion to and from iris' + Cube objects +- `cfgrib `__: for reading GRIB files via the + *ECMWF ecCodes* library. For accelerating xarray ~~~~~~~~~~~~~~~~~~~~~~~ @@ -101,6 +107,7 @@ A fixed-point performance monitoring of (a part of) our codes can be seen on `this page `__. To run these benchmark tests in a local machine, first install + - `airspeed-velocity `__: a tool for benchmarking Python packages over their lifetime. and run diff --git a/doc/interpolation.rst b/doc/interpolation.rst index cd1c078fb2d..71e88079676 100644 --- a/doc/interpolation.rst +++ b/doc/interpolation.rst @@ -48,6 +48,24 @@ array-like, which gives the interpolated result as an array. # interpolation da.interp(time=[2.5, 3.5]) +To interpolate data with a :py:func:`numpy.datetime64` coordinate you can pass a string. + +.. ipython:: python + + da_dt64 = xr.DataArray([1, 3], + [('time', pd.date_range('1/1/2000', '1/3/2000', periods=2))]) + da_dt64.interp(time='2000-01-02') + +The interpolated data can be merged into the original :py:class:`~xarray.DataArray` +by specifing the time periods required. + +.. ipython:: python + + da_dt64.interp(time=pd.date_range('1/1/2000', '1/3/2000', periods=3)) + +Interpolation of data indexed by a :py:class:`~xarray.CFTimeIndex` is also +allowed. See :ref:`CFTimeIndex` for examples. + .. note:: Currently, our interpolation only works for regular grids. @@ -244,7 +262,7 @@ Let's see how :py:meth:`~xarray.DataArray.interp` works on real data. .. 
ipython:: python # Raw data - ds = xr.tutorial.load_dataset('air_temperature').isel(time=0) + ds = xr.tutorial.open_dataset('air_temperature').isel(time=0) fig, axes = plt.subplots(ncols=2, figsize=(10, 4)) ds.air.plot(ax=axes[0]) axes[0].set_title('Raw data') diff --git a/doc/io.rst b/doc/io.rst index 093ee773e15..e841e665308 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -635,6 +635,28 @@ For example: Not all native zarr compression and filtering options have been tested with xarray. +.. _io.cfgrib: + +GRIB format via cfgrib +---------------------- + +xarray supports reading GRIB files via ECMWF cfgrib_ python driver and ecCodes_ +C-library, if they are installed. To open a GRIB file supply ``engine='cfgrib'`` +to :py:func:`~xarray.open_dataset`: + +.. ipython:: + :verbatim: + + In [1]: ds_grib = xr.open_dataset('example.grib', engine='cfgrib') + +We recommend installing ecCodes via conda:: + + conda install -c conda-forge eccodes + pip install cfgrib + +.. _cfgrib: https://github.com/ecmwf/cfgrib +.. _ecCodes: https://confluence.ecmwf.int/display/ECC/ecCodes+Home + .. _io.pynio: Formats supported by PyNIO diff --git a/doc/plotting.rst b/doc/plotting.rst index 54fa2f57ac8..f8ba82febb0 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -60,7 +60,7 @@ For these examples we'll use the North American air temperature dataset. .. ipython:: python - airtemps = xr.tutorial.load_dataset('air_temperature') + airtemps = xr.tutorial.open_dataset('air_temperature') airtemps # Convert to celsius @@ -212,8 +212,6 @@ If required, the automatic legend can be turned off using ``add_legend=False``. ``hue`` can be passed directly to :py:func:`xarray.plot` as `air.isel(lon=10, lat=[19,21,22]).plot(hue='lat')`. - - Dimension along y-axis ~~~~~~~~~~~~~~~~~~~~~~ @@ -224,8 +222,40 @@ It is also possible to make line plots such that the data are on the x-axis and @savefig plotting_example_xy_kwarg.png air.isel(time=10, lon=[10, 11]).plot(y='lat', hue='lon') -Changing Axes Direction ------------------------ +Step plots +~~~~~~~~~~ + +As an alternative, also a step plot similar to matplotlib's ``plt.step`` can be +made using 1D data. + +.. ipython:: python + + @savefig plotting_example_step.png width=4in + air1d[:20].plot.step(where='mid') + +The argument ``where`` defines where the steps should be placed, options are +``'pre'`` (default), ``'post'``, and ``'mid'``. This is particularly handy +when plotting data grouped with :py:func:`xarray.Dataset.groupby_bins`. + +.. ipython:: python + + air_grp = air.mean(['time','lon']).groupby_bins('lat',[0,23.5,66.5,90]) + air_mean = air_grp.mean() + air_std = air_grp.std() + air_mean.plot.step() + (air_mean + air_std).plot.step(ls=':') + (air_mean - air_std).plot.step(ls=':') + plt.ylim(-20,30) + @savefig plotting_example_step_groupby.png width=4in + plt.title('Zonal mean temperature') + +In this case, the actual boundaries of the bins are used and the ``where`` argument +is ignored. + + +Other axes kwargs +----------------- + The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes direction. @@ -234,6 +264,9 @@ The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes d @savefig plotting_example_xincrease_yincrease_kwarg.png air.isel(time=10, lon=[10, 11]).plot.line(y='lat', hue='lon', xincrease=False, yincrease=False) +In addition, one can use ``xscale, yscale`` to set axes scaling; ``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits. 
These accept the same values as the matplotlib methods ``Axes.set_(x,y)scale()``, ``Axes.set_(x,y)ticks()``, ``Axes.set_(x,y)lim()`` respectively. + + Two Dimensions -------------- @@ -494,7 +527,8 @@ Faceted plotting supports other arguments common to xarray 2d plots. @savefig plot_facet_robust.png g = hasoutliers.plot.pcolormesh('lon', 'lat', col='time', col_wrap=3, - robust=True, cmap='viridis') + robust=True, cmap='viridis', + cbar_kwargs={'label': 'this has outliers'}) FacetGrid Objects ~~~~~~~~~~~~~~~~~ @@ -551,7 +585,7 @@ This script will plot the air temperature on a map. .. ipython:: python import cartopy.crs as ccrs - air = xr.tutorial.load_dataset('air_temperature').air + air = xr.tutorial.open_dataset('air_temperature').air ax = plt.axes(projection=ccrs.Orthographic(-80, 35)) air.isel(time=0).plot.contourf(ax=ax, transform=ccrs.PlateCarree()); @savefig plotting_maps_cartopy.png width=100% @@ -701,3 +735,12 @@ You can however decide to infer the cell boundaries and use the outside the xarray framework. .. _cell boundaries: http://cfconventions.org/cf-conventions/v1.6.0/cf-conventions.html#cell-boundaries + +One can also make line plots with multidimensional coordinates. In this case, ``hue`` must be a dimension name, not a coordinate name. + +.. ipython:: python + + f, ax = plt.subplots(2, 1) + da.plot.line(x='lon', hue='y', ax=ax[0]); + @savefig plotting_example_2d_hue_xy.png + da.plot.line(x='lon', hue='x', ax=ax[1]); diff --git a/doc/related-projects.rst b/doc/related-projects.rst new file mode 100644 index 00000000000..cf89c715bc7 --- /dev/null +++ b/doc/related-projects.rst @@ -0,0 +1,69 @@ +.. _related-projects: + +Xarray related projects +----------------------- + +Here below is a list of several existing libraries that build +functionality upon xarray. See also section :ref:`internals` for more +details on how to build xarray extensions. + +Geosciences +~~~~~~~~~~~ + +- `aospy `_: Automated analysis and management of gridded climate data. +- `infinite-diff `_: xarray-based finite-differencing, focused on gridded climate/meterology data +- `marc_analysis `_: Analysis package for CESM/MARC experiments and output. +- `MPAS-Analysis `_: Analysis for simulations produced with Model for Prediction Across Scales (MPAS) components and the Accelerated Climate Model for Energy (ACME). +- `OGGM `_: Open Global Glacier Model +- `Oocgcm `_: Analysis of large gridded geophysical datasets +- `Open Data Cube `_: Analysis toolkit of continental scale Earth Observation data from satellites. +- `Pangaea: `_: xarray extension for gridded land surface & weather model output). +- `Pangeo `_: A community effort for big data geoscience in the cloud. +- `PyGDX `_: Python 3 package for + accessing data stored in GAMS Data eXchange (GDX) files. Also uses a custom + subclass. +- `Regionmask `_: plotting and creation of masks of spatial regions +- `salem `_: Adds geolocalised subsetting, masking, and plotting operations to xarray's data structures via accessors. +- `SatPy `_ : Library for reading and manipulating meteorological remote sensing data and writing it to various image and data file formats. +- `Spyfit `_: FTIR spectroscopy of the atmosphere +- `windspharm `_: Spherical + harmonic wind analysis in Python. +- `wrf-python `_: A collection of diagnostic and interpolation routines for use with output of the Weather Research and Forecasting (WRF-ARW) Model. +- `xarray-simlab `_: xarray extension for computer model simulations. 
+- `xarray-topo `_: xarray extension for topographic analysis and modelling. +- `xbpch `_: xarray interface for bpch files. +- `xESMF `_: Universal Regridder for Geospatial Data. +- `xgcm `_: Extends the xarray data model to understand finite volume grid cells (common in General Circulation Models) and provides interpolation and difference operations for such grids. +- `xmitgcm `_: a python package for reading `MITgcm `_ binary MDS files into xarray data structures. +- `xshape `_: Tools for working with shapefiles, topographies, and polygons in xarray. + +Machine Learning +~~~~~~~~~~~~~~~~ +- `cesium `_: machine learning for time series analysis +- `Elm `_: Parallel machine learning on xarray data structures +- `sklearn-xarray (1) `_: Combines scikit-learn and xarray (1). +- `sklearn-xarray (2) `_: Combines scikit-learn and xarray (2). + +Extend xarray capabilities +~~~~~~~~~~~~~~~~~~~~~~~~~~ +- `Collocate `_: Collocate xarray trajectories in arbitrary physical dimensions +- `eofs `_: EOF analysis in Python. +- `xarray_extras `_: Advanced algorithms for xarray objects (e.g. integrations/interpolations). +- `xrft `_: Fourier transforms for xarray data. +- `xr-scipy `_: A lightweight scipy wrapper for xarray. +- `X-regression `_: Multiple linear regression from Statsmodels library coupled with Xarray library. +- `xskillscore `_: Metrics for verifying forecasts. +- `xyzpy `_: Easily generate high dimensional data, including parallelization. + +Visualization +~~~~~~~~~~~~~ +- `Datashader `_, `geoviews `_, `holoviews `_, : visualization packages for large data. +- `hvplot `_ : A high-level plotting API for the PyData ecosystem built on HoloViews. +- `psyplot `_: Interactive data visualization with python. + +Other +~~~~~ +- `ptsa `_: EEG Time Series Analysis +- `pycalphad `_: Computational Thermodynamics in Python + +More projects can be found at the `"xarray" Github topic `_. diff --git a/doc/roadmap.rst b/doc/roadmap.rst new file mode 100644 index 00000000000..34d203c3f48 --- /dev/null +++ b/doc/roadmap.rst @@ -0,0 +1,227 @@ +.. _roadmap: + +Development roadmap +=================== + +Authors: Stephan Hoyer, Joe Hamman and xarray developers + +Date: July 24, 2018 + +Xarray is an open source Python library for labeled multidimensional +arrays and datasets. + +Our philosophy +-------------- + +Why has xarray been successful? In our opinion: + +- Xarray does a great job of solving **specific use-cases** for + multidimensional data analysis: + + - The dominant use-case for xarray is for analysis of gridded + dataset in the geosciences, e.g., as part of the + `Pangeo `__ project. + - Xarray is also used more broadly in the physical sciences, where + we've found the needs for analyzing multidimensional datasets are + remarkably consistent (e.g., see + `SunPy `__ and + `PlasmaPy `__). + - Finally, xarray is used in a variety of other domains, including + finance, `probabilistic + programming `__ and + genomics. + +- Xarray is also a **domain agnostic** solution: + + - We focus on providing a flexible set of functionality related + labeled multidimensional arrays, rather than solving particular + problems. + - This facilitates collaboration between users with different needs, + and helps us attract a broad community of contributers. + - Importantly, this retains flexibility, for use cases that don't + fit particularly well into existing frameworks. + +- Xarray **integrates well** with other libraries in the scientific + Python stack. 
+ + - We leverage first-class external libraries for core features of + xarray (e.g., NumPy for ndarrays, pandas for indexing, dask for + parallel computing) + - We expose our internal abstractions to users (e.g., + ``apply_ufunc()``), which facilitates extending xarray in various + ways. + +Together, these features have made xarray a first-class choice for +labeled multidimensional arrays in Python. + +We want to double-down on xarray's strengths by making it an even more +flexible and powerful tool for multidimensional data analysis. We want +to continue to engage xarray's core geoscience users, and to also reach +out to new domains to learn from other successful data models like those +of `yt `__ or the `OLAP +cube `__. + +Specific needs +-------------- + +The user community has voiced a number specific needs related to how +xarray interfaces with domain specific problems. Xarray may not solve +all of these issues directly, but these areas provide opportunities for +xarray to provide better, more extensible, interfaces. Some examples of +these common needs are: + +- Non-regular grids (e.g., staggered and unstructured meshes). +- Physical units. +- Lazily computed arrays (e.g., for coordinate systems). +- New file-formats. + +Technical vision +---------------- + +We think the right approach to extending xarray's user community and the +usefulness of the project is to focus on improving key interfaces that +can be used externally to meet domain-specific needs. + +We can generalize the community's needs into three main catagories: + +- More flexible grids/indexing. +- More flexible arrays/computing. +- More flexible storage backends. + +Each of these are detailed further in the subsections below. + +Flexible indexes +~~~~~~~~~~~~~~~~ + +Xarray currently keeps track of indexes associated with coordinates by +storing them in the form of a ``pandas.Index`` in special +``xarray.IndexVariable`` objects. + +The limitations of this model became clear with the addition of +``pandas.MultiIndex`` support in xarray 0.9, where a single index +corresponds to multiple xarray variables. MultiIndex support is highly +useful, but xarray now has numerous special cases to check for +MultiIndex levels. + +A cleaner model would be to elevate ``indexes`` to an explicit part of +xarray's data model, e.g., as attributes on the ``Dataset`` and +``DataArray`` classes. Indexes would need to be propagated along with +coordinates in xarray operations, but will no longer would need to have +a one-to-one correspondance with coordinate variables. Instead, an index +should be able to refer to multiple (possibly multidimensional) +coordinates that define it. See `GH +1603 `__ for full details + +Specific tasks: + +- Add an ``indexes`` attribute to ``xarray.Dataset`` and + ``xarray.Dataset``, as dictionaries that map from coordinate names to + xarray index objects. +- Use the new index interface to write wrappers for ``pandas.Index``, + ``pandas.MultiIndex`` and ``scipy.spatial.KDTree``. +- Expose the interface externally to allow third-party libraries to + implement custom indexing routines, e.g., for geospatial look-ups on + the surface of the Earth. + +In addition to the new features it directly enables, this clean up will +allow xarray to more easily implement some long-awaited features that +build upon indexing, such as groupby operations with multiple variables. + +Flexible arrays +~~~~~~~~~~~~~~~ + +Xarray currently supports wrapping multidimensional arrays defined by +NumPy, dask and to a limited-extent pandas. 
It would be nice to have +interfaces that allow xarray to wrap alternative N-D array +implementations, e.g.: + +- Arrays holding physical units. +- Lazily computed arrays. +- Other ndarray objects, e.g., sparse, xnd, xtensor. + +Our strategy has been to pursue upstream improvements in NumPy (see +`NEP-22 `__) +for supporting a complete duck-typing interface using with NumPy's +higher level array API. Improvements in NumPy's support for custom data +types would also be highly useful for xarray users. + +By pursuing these improvements in NumPy we hope to extend the benefits +to the full scientific Python community, and avoid tight coupling +between xarray and specific third-party libraries (e.g., for +implementing untis). This will allow xarray to maintain its domain +agnostic strengths. + +We expect that we may eventually add some minimal interfaces in xarray +for features that we delegate to external array libraries (e.g., for +getting units and changing units). If we do add these features, we +expect them to be thin wrappers, with core functionality implemented by +third-party libraries. + +Flexible storage +~~~~~~~~~~~~~~~~ + +The xarray backends module has grown in size and complexity. Much of +this growth has been "organic" and mostly to support incremental +additions to the supported backends. This has left us with a fragile +internal API that is difficult for even experienced xarray developers to +use. Moreover, the lack of a public facing API for building xarray +backends means that users can not easily build backend interface for +xarray in third-party libraries. + +The idea of refactoring the backends API and exposing it to users was +originally proposed in `GH +1970 `__. The idea would +be to develop a well tested and generic backend base class and +associated utilities for external use. Specific tasks for this +development would include: + +- Exposing an abstract backend for writing new storage systems. +- Exposing utilities for features like automatic closing of files, + LRU-caching and explicit/lazy indexing. +- Possibly moving some infrequently used backends to third-party + packages. + +Engaging more users +------------------- + +Like many open-source projects, the documentation of xarray has grown +together with the library's features. While we think that the xarray +documentation is comprehensive already, we aknowledge that the adoption +of xarray might be slowed down because of the substantial time +investment required to learn its working principles. In particular, +non-computer scientists or users less familiar with the pydata ecosystem +might find it difficult to learn xarray and realize how xarray can help +them in their daily work. + +In order to lower this adoption barrier, we propose to: + +- Develop entry-level tutorials for users with different backgrounds. For + example, we would like to develop tutorials for users with or without + previous knowledge of pandas, numpy, netCDF, etc. These tutorials may be + built as part of xarray's documentation or included in a seperate repository + to enable interactive use (e.g. mybinder.org). +- Document typical user workflows in a dedicated website, following the example + of `dask-stories + `__. +- Write a basic glossary that defines terms that might not be familiar to all + (e.g. "lazy", "labeled", "serialization", "indexing", "backend"). 
+ +Administrative +-------------- + +Current core developers +~~~~~~~~~~~~~~~~~~~~~~~ + +- Stephan Hoyer +- Ryan Abernathey +- Joe Hamman +- Benoit Bovy +- Fabien Maussion +- Keisuke Fujii +- Maximilian Roos + +NumFOCUS +~~~~~~~~ + +On July 16, 2018, Joe and Stephan submitted xarray's fiscal sponsorship +application to NumFOCUS. diff --git a/doc/time-series.rst b/doc/time-series.rst index a7ce9226d4d..c225c246a8c 100644 --- a/doc/time-series.rst +++ b/doc/time-series.rst @@ -70,11 +70,12 @@ You can manual decode arrays in this form by passing a dataset to One unfortunate limitation of using ``datetime64[ns]`` is that it limits the native representation of dates to those that fall between the years 1678 and 2262. When a netCDF file contains dates outside of these bounds, dates will be -returned as arrays of ``cftime.datetime`` objects and a ``CFTimeIndex`` -can be used for indexing. The ``CFTimeIndex`` enables only a subset of -the indexing functionality of a ``pandas.DatetimeIndex`` and is only enabled -when using the standalone version of ``cftime`` (not the version packaged with -earlier versions ``netCDF4``). See :ref:`CFTimeIndex` for more information. +returned as arrays of :py:class:`cftime.datetime` objects and a :py:class:`~xarray.CFTimeIndex` +will be used for indexing. :py:class:`~xarray.CFTimeIndex` enables a subset of +the indexing functionality of a :py:class:`pandas.DatetimeIndex` and is only +fully compatible with the standalone version of ``cftime`` (not the version +packaged with earlier versions ``netCDF4``). See :ref:`CFTimeIndex` for more +information. Datetime indexing ----------------- @@ -198,17 +199,6 @@ and ``interpolate``. ``interpolate`` extends ``scipy.interpolate.interp1d`` and supports all of its schemes. All of these resampling operations work on both Dataset and DataArray objects with an arbitrary number of dimensions. -.. note:: - - The ``resample`` api was updated in version 0.10.0 to reflect similar - updates in pandas ``resample`` api to be more groupby-like. Older style - calls to ``resample`` will still be supported for a short period: - - .. ipython:: python - - ds.resample('6H', dim='time', how='mean') - - For more examples of using grouped operations on a time dimension, see :ref:`toy weather data`. @@ -219,20 +209,30 @@ Non-standard calendars and dates outside the Timestamp-valid range ------------------------------------------------------------------ Through the standalone ``cftime`` library and a custom subclass of -``pandas.Index``, xarray supports a subset of the indexing functionality enabled -through the standard ``pandas.DatetimeIndex`` for dates from non-standard -calendars or dates using a standard calendar, but outside the -`Timestamp-valid range`_ (approximately between years 1678 and 2262). This -behavior has not yet been turned on by default; to take advantage of this -functionality, you must have the ``enable_cftimeindex`` option set to -``True`` within your context (see :py:func:`~xarray.set_options` for more -information). It is expected that this will become the default behavior in -xarray version 0.11. 
- -For instance, you can create a DataArray indexed by a time -coordinate with a no-leap calendar within a context manager setting the -``enable_cftimeindex`` option, and the time index will be cast to a -``CFTimeIndex``: +:py:class:`pandas.Index`, xarray supports a subset of the indexing +functionality enabled through the standard :py:class:`pandas.DatetimeIndex` for +dates from non-standard calendars commonly used in climate science or dates +using a standard calendar, but outside the `Timestamp-valid range`_ +(approximately between years 1678 and 2262). + +.. note:: + + As of xarray version 0.11, by default, :py:class:`cftime.datetime` objects + will be used to represent times (either in indexes, as a + :py:class:`~xarray.CFTimeIndex`, or in data arrays with dtype object) if + any of the following are true: + + - The dates are from a non-standard calendar + - Any dates are outside the Timestamp-valid range. + + Otherwise pandas-compatible dates from a standard calendar will be + represented with the ``np.datetime64[ns]`` data type, enabling the use of a + :py:class:`pandas.DatetimeIndex` or arrays with dtype ``np.datetime64[ns]`` + and their full set of associated features. + +For example, you can create a DataArray indexed by a time +coordinate with dates from a no-leap calendar and a +:py:class:`~xarray.CFTimeIndex` will automatically be used: .. ipython:: python @@ -241,25 +241,18 @@ coordinate with a no-leap calendar within a context manager setting the dates = [DatetimeNoLeap(year, month, 1) for year, month in product(range(1, 3), range(1, 13))] - with xr.set_options(enable_cftimeindex=True): - da = xr.DataArray(np.arange(24), coords=[dates], dims=['time'], - name='foo') + da = xr.DataArray(np.arange(24), coords=[dates], dims=['time'], name='foo') -.. note:: +xarray also includes a :py:func:`~xarray.cftime_range` function, which enables +creating a :py:class:`~xarray.CFTimeIndex` with regularly-spaced dates. For +instance, we can create the same dates and DataArray we created above using: - With the ``enable_cftimeindex`` option activated, a ``CFTimeIndex`` - will be used for time indexing if any of the following are true: - - - The dates are from a non-standard calendar - - Any dates are outside the Timestamp-valid range +.. ipython:: python - Otherwise a ``pandas.DatetimeIndex`` will be used. In addition, if any - variable (not just an index variable) is encoded using a non-standard - calendar, its times will be decoded into ``cftime.datetime`` objects, - regardless of whether or not they can be represented using - ``np.datetime64[ns]`` objects. - -For data indexed by a ``CFTimeIndex`` xarray currently supports: + dates = xr.cftime_range(start='0001', periods=24, freq='MS', calendar='noleap') + da = xr.DataArray(np.arange(24), coords=[dates], dims=['time'], name='foo') + +For data indexed by a :py:class:`~xarray.CFTimeIndex` xarray currently supports: - `Partial datetime string indexing`_ using strictly `ISO 8601-format`_ partial datetime strings: @@ -285,18 +278,65 @@ For data indexed by a ``CFTimeIndex`` xarray currently supports: .. ipython:: python da.groupby('time.month').sum() - + +- Interpolation using :py:class:`cftime.datetime` objects: + +.. ipython:: python + + da.interp(time=[DatetimeNoLeap(1, 1, 15), DatetimeNoLeap(1, 2, 15)]) + +- Interpolation using datetime strings: + +.. ipython:: python + + da.interp(time=['0001-01-15', '0001-02-15']) + +- Differentiation: + +.. ipython:: python + + da.differentiate('time') + - And serialization: .. 
ipython:: python - da.to_netcdf('example.nc') - xr.open_dataset('example.nc') + da.to_netcdf('example-no-leap.nc') + xr.open_dataset('example-no-leap.nc') .. note:: - Currently resampling along the time dimension for data indexed by a - ``CFTimeIndex`` is not supported. + While much of the time series functionality that is possible for standard + dates has been implemented for dates from non-standard calendars, there are + still some remaining important features that have yet to be implemented, + for example: + + - Resampling along the time dimension for data indexed by a + :py:class:`~xarray.CFTimeIndex` (:issue:`2191`, :issue:`2458`) + - Built-in plotting of data with :py:class:`cftime.datetime` coordinate axes + (:issue:`2164`). + + For some use-cases it may still be useful to convert from + a :py:class:`~xarray.CFTimeIndex` to a :py:class:`pandas.DatetimeIndex`, + despite the difference in calendar types (e.g. to allow the use of some + forms of resample with non-standard calendars). The recommended way of + doing this is to use the built-in + :py:meth:`~xarray.CFTimeIndex.to_datetimeindex` method: + + .. ipython:: python + + modern_times = xr.cftime_range('2000', periods=24, freq='MS', calendar='noleap') + da = xr.DataArray(range(24), [('time', modern_times)]) + da + datetimeindex = da.indexes['time'].to_datetimeindex() + da['time'] = datetimeindex + da.resample(time='Y').mean('time') + + However in this case one should use caution to only perform operations which + do not depend on differences between dates (e.g. differentiation, + interpolation, or upsampling with resample), as these could introduce subtle + and silent errors due to the difference in calendar types between the dates + encoded in your data and the dates stored in memory. .. _Timestamp-valid range: https://pandas.pydata.org/pandas-docs/stable/timeseries.html#timestamp-limitations .. _ISO 8601-format: https://en.wikipedia.org/wiki/ISO_8601 diff --git a/doc/whats-new.rst b/doc/whats-new.rst index af90ea7f9d3..1da1da700e7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,13 +25,302 @@ What's New - `Python 3 Statement `__ - `Tips on porting to Python 3 `__ -.. _whats-new.0.10.8: +.. _whats-new.0.11.1: -v0.10.8 (unreleased) +v0.11.1 (unreleased) -------------------- -Documentation -~~~~~~~~~~~~~ +Breaking changes +~~~~~~~~~~~~~~~~ + +Enhancements +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + +.. _whats-new.0.11.0: + +v0.11.0 (7 November 2018) +------------------------- + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Finished deprecations (changed behavior with this release): + + - ``Dataset.T`` has been removed as a shortcut for :py:meth:`Dataset.transpose`. + Call :py:meth:`Dataset.transpose` directly instead. + - Iterating over a ``Dataset`` now includes only data variables, not coordinates. + Similarily, calling ``len`` and ``bool`` on a ``Dataset`` now + includes only data variables. + - ``DataArray.__contains__`` (used by Python's ``in`` operator) now checks + array data, not coordinates. + - The old resample syntax from before xarray 0.10, e.g., + ``data.resample('1D', dim='time', how='mean')``, is no longer supported will + raise an error in most cases. You need to use the new resample syntax + instead, e.g., ``data.resample(time='1D').mean()`` or + ``data.resample({'time': '1D'}).mean()``. + + +- New deprecations (behavior will be changed in xarray 0.12): + + - Reduction of :py:meth:`DataArray.groupby` and :py:meth:`DataArray.resample` + without dimension argument will change in the next release. 
+ Now we warn a FutureWarning. + By `Keisuke Fujii `_. + - The ``inplace`` kwarg of a number of `DataArray` and `Dataset` methods is being + deprecated and will be removed in the next release. + By `Deepak Cherian `_. + + +- Refactored storage backends: + + - Xarray's storage backends now automatically open and close files when + necessary, rather than requiring opening a file with ``autoclose=True``. A + global least-recently-used cache is used to store open files; the default + limit of 128 open files should suffice in most cases, but can be adjusted if + necessary with + ``xarray.set_options(file_cache_maxsize=...)``. The ``autoclose`` argument + to ``open_dataset`` and related functions has been deprecated and is now a + no-op. + + This change, along with an internal refactor of xarray's storage backends, + should significantly improve performance when reading and writing + netCDF files with Dask, especially when working with many files or using + Dask Distributed. By `Stephan Hoyer `_ + + +- Support for non-standard calendars used in climate science: + + - Xarray will now always use :py:class:`cftime.datetime` objects, rather + than by default trying to coerce them into ``np.datetime64[ns]`` objects. + A :py:class:`~xarray.CFTimeIndex` will be used for indexing along time + coordinates in these cases. + - A new method :py:meth:`~xarray.CFTimeIndex.to_datetimeindex` has been added + to aid in converting from a :py:class:`~xarray.CFTimeIndex` to a + :py:class:`pandas.DatetimeIndex` for the remaining use-cases where + using a :py:class:`~xarray.CFTimeIndex` is still a limitation (e.g. for + resample or plotting). + - Setting the ``enable_cftimeindex`` option is now a no-op and emits a + ``FutureWarning``. + +Enhancements +~~~~~~~~~~~~ + +- :py:meth:`xarray.DataArray.plot.line` can now accept multidimensional + coordinate variables as input. `hue` must be a dimension name in this case. + (:issue:`2407`) + By `Deepak Cherian `_. +- Added support for Python 3.7. (:issue:`2271`). + By `Joe Hamman `_. +- Added support for plotting data with `pandas.Interval` coordinates, such as those + created by :py:meth:`~xarray.DataArray.groupby_bins` + By `Maximilian Maahn `_. +- Added :py:meth:`~xarray.CFTimeIndex.shift` for shifting the values of a + CFTimeIndex by a specified frequency. (:issue:`2244`). + By `Spencer Clark `_. +- Added support for using ``cftime.datetime`` coordinates with + :py:meth:`~xarray.DataArray.differentiate`, + :py:meth:`~xarray.Dataset.differentiate`, + :py:meth:`~xarray.DataArray.interp`, and + :py:meth:`~xarray.Dataset.interp`. + By `Spencer Clark `_ +- There is now a global option to either always keep or always discard + dataset and dataarray attrs upon operations. The option is set with + ``xarray.set_options(keep_attrs=True)``, and the default is to use the old + behaviour. + By `Tom Nicholas `_. +- Added a new backend for the GRIB file format based on ECMWF *cfgrib* + python driver and *ecCodes* C-library. (:issue:`2475`) + By `Alessandro Amici `_, + sponsored by `ECMWF `_. +- Resample now supports a dictionary mapping from dimension to frequency as + its first argument, e.g., ``data.resample({'time': '1D'}).mean()``. This is + consistent with other xarray functions that accept either dictionaries or + keyword arguments. By `Stephan Hoyer `_. + +- The preferred way to access tutorial data is now to load it lazily with + :py:meth:`xarray.tutorial.open_dataset`. 
+ :py:meth:`xarray.tutorial.load_dataset` calls `Dataset.load()` prior + to returning (and is now deprecated). This was changed in order to facilitate + using tutorial datasets with dask. + By `Joe Hamman `_. + +Bug fixes +~~~~~~~~~ + +- ``FacetGrid`` now properly uses the ``cbar_kwargs`` keyword argument. + (:issue:`1504`, :issue:`1717`) + By `Deepak Cherian `_. +- Addition and subtraction operators used with a CFTimeIndex now preserve the + index's type. (:issue:`2244`). + By `Spencer Clark `_. +- We now properly handle arrays of ``datetime.datetime`` and ``datetime.timedelta`` + provided as coordinates. (:issue:`2512`) + By `Deepak Cherian `_. +- ``xarray.plot()`` now properly accepts a ``norm`` argument and does not override + the norm's ``vmin`` and ``vmax``. (:issue:`2381`) + By `Deepak Cherian `_. +- ``xarray.DataArray.std()`` now correctly accepts ``ddof`` keyword argument. + (:issue:`2240`) + By `Keisuke Fujii `_. +- Restore matplotlib's default of plotting dashed negative contours when + a single color is passed to ``DataArray.contour()`` e.g. ``colors='k'``. + By `Deepak Cherian `_. + + +- Fix a bug that caused some indexing operations on arrays opened with + ``open_rasterio`` to error (:issue:`2454`). + By `Stephan Hoyer `_. + +- Subtracting one CFTimeIndex from another now returns a + ``pandas.TimedeltaIndex``, analogous to the behavior for DatetimeIndexes + (:issue:`2484`). By `Spencer Clark `_. +- Adding a TimedeltaIndex to, or subtracting a TimedeltaIndex from a + CFTimeIndex is now allowed (:issue:`2484`). + By `Spencer Clark `_. +- Avoid use of Dask's deprecated ``get=`` parameter in tests + by `Matthew Rocklin `_. +- An ``OverflowError`` is now accurately raised and caught during the + encoding process if a reference date is used that is so distant that + the dates must be encoded using cftime rather than NumPy (:issue:`2272`). + By `Spencer Clark `_. + +- Chunked datasets can now roundtrip to Zarr storage continually + with `to_zarr` and ``open_zarr`` (:issue:`2300`). + By `Lily Wang `_. + +.. _whats-new.0.10.9: + +v0.10.9 (21 September 2018) +--------------------------- + +This minor release contains a number of backwards compatible enhancements. + +Announcements of note: + +- Xarray is now a NumFOCUS fiscally sponsored project! Read + `the anouncement `_ + for more details. +- We have a new :doc:`roadmap` that outlines our future development plans. + +Enhancements +~~~~~~~~~~~~ + +- :py:meth:`~xarray.DataArray.differentiate` and + :py:meth:`~xarray.Dataset.differentiate` are newly added. + (:issue:`1332`) + By `Keisuke Fujii `_. +- Default colormap for sequential and divergent data can now be set via + :py:func:`~xarray.set_options()` + (:issue:`2394`) + By `Julius Busecke `_. + +- min_count option is newly supported in :py:meth:`~xarray.DataArray.sum`, + :py:meth:`~xarray.DataArray.prod` and :py:meth:`~xarray.Dataset.sum`, and + :py:meth:`~xarray.Dataset.prod`. + (:issue:`2230`) + By `Keisuke Fujii `_. + +- :py:meth:`plot()` now accepts the kwargs + ``xscale, yscale, xlim, ylim, xticks, yticks`` just like Pandas. Also ``xincrease=False, yincrease=False`` now use matplotlib's axis inverting methods instead of setting limits. + By `Deepak Cherian `_. (:issue:`2224`) + +- DataArray coordinates and Dataset coordinates and data variables are + now displayed as `a b ... y z` rather than `a b c d ...`. + (:issue:`1186`) + By `Seth P `_. +- A new CFTimeIndex-enabled :py:func:`cftime_range` function for use in + generating dates from standard or non-standard calendars. 
By `Spencer Clark + `_. + +- When interpolating over a ``datetime64`` axis, you can now provide a datetime string instead of a ``datetime64`` object. E.g. ``da.interp(time='1991-02-01')`` + (:issue:`2284`) + By `Deepak Cherian `_. + +- A clear error message is now displayed if a ``set`` or ``dict`` is passed in place of an array + (:issue:`2331`) + By `Maximilian Roos `_. + +- Applying ``unstack`` to a large DataArray or Dataset is now much faster if the MultiIndex has not been modified after stacking the indices. + (:issue:`1560`) + By `Maximilian Maahn `_. + +- You can now control whether or not to offset the coordinates when using + the ``roll`` method and the current behavior, coordinates rolled by default, + raises a deprecation warning unless explicitly setting the keyword argument. + (:issue:`1875`) + By `Andrew Huang `_. + +- You can now call ``unstack`` without arguments to unstack every MultiIndex in a DataArray or Dataset. + By `Julia Signell `_. + +- Added the ability to pass a data kwarg to ``copy`` to create a new object with the + same metadata as the original object but using new values. + By `Julia Signell `_. + +Bug fixes +~~~~~~~~~ + +- ``xarray.plot.imshow()`` correctly uses the ``origin`` argument. + (:issue:`2379`) + By `Deepak Cherian `_. + +- Fixed ``DataArray.to_iris()`` failure while creating ``DimCoord`` by + falling back to creating ``AuxCoord``. Fixed dependency on ``var_name`` + attribute being set. + (:issue:`2201`) + By `Thomas Voigt `_. +- Fixed a bug in ``zarr`` backend which prevented use with datasets with + invalid chunk size encoding after reading from an existing store + (:issue:`2278`). + By `Joe Hamman `_. + +- Tests can be run in parallel with pytest-xdist + By `Tony Tung `_. + +- Follow up the renamings in dask; from dask.ghost to dask.overlap + By `Keisuke Fujii `_. + +- Now raises a ValueError when there is a conflict between dimension names and + level names of MultiIndex. (:issue:`2299`) + By `Keisuke Fujii `_. + +- Follow up the renamings in dask; from dask.ghost to dask.overlap + By `Keisuke Fujii `_. + +- Now :py:func:`xr.apply_ufunc` raises a ValueError when the size of + ``input_core_dims`` is inconsistent with the number of arguments. + (:issue:`2341`) + By `Keisuke Fujii `_. + +- Fixed ``Dataset.filter_by_attrs()`` behavior not matching ``netCDF4.Dataset.get_variables_by_attributes()``. + When more than one ``key=value`` is passed into ``Dataset.filter_by_attrs()`` it will now return a Dataset with variables which pass + all the filters. + (:issue:`2315`) + By `Andrew Barna `_. + +.. _whats-new.0.10.8: + +v0.10.8 (18 July 2018) +---------------------- + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Xarray no longer supports python 3.4. Additionally, the minimum supported + versions of the following dependencies has been updated and/or clarified: + + - Pandas: 0.18 -> 0.19 + - NumPy: 1.11 -> 1.12 + - Dask: 0.9 -> 0.16 + - Matplotlib: unspecified -> 1.5 + + (:issue:`2204`). By `Joe Hamman `_. Enhancements ~~~~~~~~~~~~ @@ -46,7 +335,6 @@ Enhancements :py:meth:`~xarray.DataArray.from_cdms2` (:issue:`2262`). By `Stephane Raynaud `_. - Bug fixes ~~~~~~~~~ @@ -66,18 +354,9 @@ Bug fixes weren't monotonic (:issue:`2250`). By `Fabien Maussion `_. -Breaking changes -~~~~~~~~~~~~~~~~ - -- Xarray no longer supports python 3.4. 
Additionally, the minimum supported - versions of the following dependencies has been updated and/or clarified: - - - Pandas: 0.18 -> 0.19 - - NumPy: 1.11 -> 1.12 - - Dask: 0.9 -> 0.16 - - Matplotlib: unspecified -> 1.5 - - (:issue:`2204`). By `Joe Hamman `_. +- Fixed warning raised in :py:meth:`~Dataset.to_netcdf` due to deprecation of + `effective_get` in dask (:issue:`2238`). + By `Joe Hamman `_. .. _whats-new.0.10.7: diff --git a/properties/test_encode_decode.py b/properties/test_encode_decode.py index 8d84c0f6815..13f63f259cf 100644 --- a/properties/test_encode_decode.py +++ b/properties/test_encode_decode.py @@ -6,14 +6,15 @@ """ from __future__ import absolute_import, division, print_function -from hypothesis import given, settings -import hypothesis.strategies as st import hypothesis.extra.numpy as npst +import hypothesis.strategies as st +from hypothesis import given, settings import xarray as xr # Run for a while - arrays are a bigger search space than usual -settings.deadline = None +settings.register_profile("ci", deadline=None) +settings.load_profile("ci") an_array = npst.arrays( diff --git a/readthedocs.yml b/readthedocs.yml index 0129abe15aa..8e9c09c9414 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,5 +1,8 @@ +build: + image: latest conda: file: doc/environment.yml python: - version: 3 - setup_py_install: true + version: 3.6 + setup_py_install: true +formats: [] diff --git a/setup.py b/setup.py index e35611e01b1..3b56d9265af 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,8 @@ #!/usr/bin/env python import sys -from setuptools import find_packages, setup - import versioneer - +from setuptools import find_packages, setup DISTNAME = 'xarray' LICENSE = 'Apache' @@ -20,13 +18,13 @@ 'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Topic :: Scientific/Engineering', ] -INSTALL_REQUIRES = ['numpy >= 1.11', 'pandas >= 0.18.0'] +INSTALL_REQUIRES = ['numpy >= 1.12', 'pandas >= 0.19.2'] TESTS_REQUIRE = ['pytest >= 2.7.1'] if sys.version_info[0] < 3: TESTS_REQUIRE.append('mock') @@ -70,5 +68,6 @@ install_requires=INSTALL_REQUIRES, tests_require=TESTS_REQUIRE, url=URL, + python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*', packages=find_packages(), package_data={'xarray': ['tests/data/*']}) diff --git a/versioneer.py b/versioneer.py index 64fea1c8927..dffd66b69a6 100644 --- a/versioneer.py +++ b/versioneer.py @@ -277,10 +277,7 @@ """ from __future__ import print_function -try: - import configparser -except ImportError: - import ConfigParser as configparser + import errno import json import os @@ -288,6 +285,11 @@ import subprocess import sys +try: + import configparser +except ImportError: + import ConfigParser as configparser + class VersioneerConfig: """Container for Versioneer configuration parameters.""" diff --git a/xarray/__init__.py b/xarray/__init__.py index 7cc7811b783..59a961c6b56 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -10,7 +10,7 @@ from .core.alignment import align, broadcast, broadcast_arrays from .core.common import full_like, zeros_like, ones_like from .core.combine import concat, auto_combine -from .core.computation import apply_ufunc, where, dot +from .core.computation import apply_ufunc, dot, where from .core.extensions import (register_dataarray_accessor, register_dataset_accessor) from 
.core.variable import as_variable, Variable, IndexVariable, Coordinate @@ -26,6 +26,7 @@ from .conventions import decode_cf, SerializationWarning +from .coding.cftime_offsets import cftime_range from .coding.cftimeindex import CFTimeIndex from .util.print_versions import show_versions @@ -33,3 +34,5 @@ from . import tutorial from . import ufuncs from . import testing + +from .core.common import ALL_DIMS diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index 47a2011a3af..9b9e04d9346 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -4,6 +4,8 @@ formats. They should not be used directly, but rather through Dataset objects. """ from .common import AbstractDataStore +from .file_manager import FileManager, CachingFileManager, DummyFileManager +from .cfgrib_ import CfGribDataStore from .memory import InMemoryDataStore from .netCDF4_ import NetCDF4DataStore from .pydap_ import PydapDataStore @@ -15,6 +17,10 @@ __all__ = [ 'AbstractDataStore', + 'FileManager', + 'CachingFileManager', + 'CfGribDataStore', + 'DummyFileManager', 'InMemoryDataStore', 'NetCDF4DataStore', 'PydapDataStore', diff --git a/xarray/backends/api.py b/xarray/backends/api.py index d5e2e8bbc2c..c1ace7774f9 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -4,6 +4,7 @@ from glob import glob from io import BytesIO from numbers import Number +import warnings import numpy as np @@ -11,37 +12,80 @@ from ..core import indexing from ..core.combine import auto_combine from ..core.pycompat import basestring, path_type -from ..core.utils import close_on_error, is_remote_uri -from .common import ( - HDF5_LOCK, ArrayWriter, CombinedLock, get_scheduler, get_scheduler_lock) +from ..core.utils import close_on_error, is_remote_uri, is_grib_path +from .common import ArrayWriter +from .locks import _get_scheduler + DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' -def _get_default_engine(path, allow_remote=False): - if allow_remote and is_remote_uri(path): # pragma: no cover +def _get_default_engine_remote_uri(): + try: + import netCDF4 + engine = 'netcdf4' + except ImportError: # pragma: no cover try: - import netCDF4 - engine = 'netcdf4' + import pydap # flake8: noqa + engine = 'pydap' except ImportError: - try: - import pydap # flake8: noqa - engine = 'pydap' - except ImportError: - raise ValueError('netCDF4 or pydap is required for accessing ' - 'remote datasets via OPeNDAP') + raise ValueError('netCDF4 or pydap is required for accessing ' + 'remote datasets via OPeNDAP') + return engine + + +def _get_default_engine_grib(): + msgs = [] + try: + import Nio # flake8: noqa + msgs += ["set engine='pynio' to access GRIB files with PyNIO"] + except ImportError: # pragma: no cover + pass + try: + import cfgrib # flake8: noqa + msgs += ["set engine='cfgrib' to access GRIB files with cfgrib"] + except ImportError: # pragma: no cover + pass + if msgs: + raise ValueError(' or\n'.join(msgs)) else: + raise ValueError('PyNIO or cfgrib is required for accessing ' + 'GRIB files') + + +def _get_default_engine_gz(): + try: + import scipy # flake8: noqa + engine = 'scipy' + except ImportError: # pragma: no cover + raise ValueError('scipy is required for accessing .gz files') + return engine + + +def _get_default_engine_netcdf(): + try: + import netCDF4 # flake8: noqa + engine = 'netcdf4' + except ImportError: # pragma: no cover try: - import netCDF4 # flake8: noqa - engine = 'netcdf4' - except ImportError: # pragma: no cover - try: - import 
scipy.io.netcdf # flake8: noqa - engine = 'scipy' - except ImportError: - raise ValueError('cannot read or write netCDF files without ' - 'netCDF4-python or scipy installed') + import scipy.io.netcdf # flake8: noqa + engine = 'scipy' + except ImportError: + raise ValueError('cannot read or write netCDF files without ' + 'netCDF4-python or scipy installed') + return engine + + +def _get_default_engine(path, allow_remote=False): + if allow_remote and is_remote_uri(path): + engine = _get_default_engine_remote_uri() + elif is_grib_path(path): + engine = _get_default_engine_grib() + elif path.endswith('.gz'): + engine = _get_default_engine_gz() + else: + engine = _get_default_engine_netcdf() return engine @@ -52,27 +96,6 @@ def _normalize_path(path): return os.path.abspath(os.path.expanduser(path)) -def _default_lock(filename, engine): - if filename.endswith('.gz'): - lock = False - else: - if engine is None: - engine = _get_default_engine(filename, allow_remote=True) - - if engine == 'netcdf4': - if is_remote_uri(filename): - lock = False - else: - # TODO: identify netcdf3 files and don't use the global lock - # for them - lock = HDF5_LOCK - elif engine in {'h5netcdf', 'pynio'}: - lock = HDF5_LOCK - else: - lock = False - return lock - - def _validate_dataset_names(dataset): """DataArray.name and Dataset keys must be a string or None""" def check_name(name): @@ -90,7 +113,7 @@ def check_name(name): def _validate_attrs(dataset): - """`attrs` must have a string key and a value which is either: a number + """`attrs` must have a string key and a value which is either: a number, a string, an ndarray or a list/tuple of numbers/strings. """ def check_attr(name, value): @@ -105,8 +128,8 @@ def check_attr(name, value): if not isinstance(value, (basestring, Number, np.ndarray, np.number, list, tuple)): - raise TypeError('Invalid value for attr: {} must be a number ' - 'string, ndarray or a list/tuple of ' + raise TypeError('Invalid value for attr: {} must be a number, ' + 'a string, an ndarray or a list/tuple of ' 'numbers/strings for serialization to netCDF ' 'files'.format(value)) @@ -130,29 +153,14 @@ def _protect_dataset_variables_inplace(dataset, cache): variable.data = data -def _get_lock(engine, scheduler, format, path_or_file): - """ Get the lock(s) that apply to a particular scheduler/engine/format""" - - locks = [] - if format in ['NETCDF4', None] and engine in ['h5netcdf', 'netcdf4']: - locks.append(HDF5_LOCK) - locks.append(get_scheduler_lock(scheduler, path_or_file)) - - # When we have more than one lock, use the CombinedLock wrapper class - lock = CombinedLock(locks) if len(locks) > 1 else locks[0] - - return lock - - def _finalize_store(write, store): """ Finalize this store by explicitly syncing and closing""" del write # ensure writing is done first - store.sync() store.close() def open_dataset(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=None, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=None, concat_characters=True, decode_coords=True, engine=None, chunks=None, lock=None, cache=None, drop_variables=None, backend_kwargs=None): @@ -179,7 +187,7 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, taken from variable attributes (if they exist). If the `_FillValue` or `missing_value` attribute contains multiple values a warning will be issued and all array values matching one of the multiple values will - be replaced by NA. mask_and_scale defaults to True except for the + be replaced by NA. 
mask_and_scale defaults to True except for the pseudonetcdf backend. decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format @@ -196,7 +204,8 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, decode_coords : bool, optional If True, decode the 'coordinates' attribute to identify coordinates in the resulting dataset. - engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'pseudonetcdf'}, optional + engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'cfgrib', + 'pseudonetcdf'}, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. @@ -204,12 +213,11 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, If chunks is provided, it used to load the new dataset into dask arrays. ``chunks={}`` loads the dataset with dask using a single chunk for all arrays. - lock : False, True or threading.Lock, optional - If chunks is provided, this argument is passed on to - :py:func:`dask.array.from_array`. By default, a global lock is - used when reading data from netCDF files with the netcdf4 and h5netcdf - engines to avoid issues with concurrent access when using dask's - multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -223,7 +231,7 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, inconsistent values. backend_kwargs: dictionary, optional A dictionary of keyword arguments to pass on to the backend. This - may be useful when backend options would improve performance or + may be useful when backend options would improve performance or allow user control of dataset processing. Returns @@ -235,7 +243,15 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, -------- open_mfdataset """ - + if autoclose is not None: + warnings.warn( + 'The autoclose argument is no longer used by ' + 'xarray.open_dataset() and is now ignored; it will be removed in ' + 'xarray v0.12. 
If necessary, you can control the maximum number ' + 'of simultaneous open files with ' + 'xarray.set_options(file_cache_maxsize=...).', + FutureWarning, stacklevel=2) + if mask_and_scale is None: mask_and_scale = not engine == 'pseudonetcdf' @@ -272,18 +288,11 @@ def maybe_decode_store(store, lock=False): mask_and_scale, decode_times, concat_characters, decode_coords, engine, chunks, drop_variables) name_prefix = 'open_dataset-%s' % token - ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token, - lock=lock) + ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token) ds2._file_obj = ds._file_obj else: ds2 = ds - # protect so that dataset store isn't necessarily closed, e.g., - # streams like BytesIO can't be reopened - # datastore backend is responsible for determining this capability - if store._autoclose: - store.close() - return ds2 if isinstance(filename_or_obj, path_type): @@ -303,47 +312,35 @@ def maybe_decode_store(store, lock=False): elif isinstance(filename_or_obj, basestring): filename_or_obj = _normalize_path(filename_or_obj) - if filename_or_obj.endswith('.gz'): - if engine is not None and engine != 'scipy': - raise ValueError('can only read gzipped netCDF files with ' - "default engine or engine='scipy'") - else: - engine = 'scipy' - if engine is None: engine = _get_default_engine(filename_or_obj, allow_remote=True) if engine == 'netcdf4': - store = backends.NetCDF4DataStore.open(filename_or_obj, - group=group, - autoclose=autoclose, - **backend_kwargs) + store = backends.NetCDF4DataStore.open( + filename_or_obj, group=group, lock=lock, **backend_kwargs) elif engine == 'scipy': - store = backends.ScipyDataStore(filename_or_obj, - autoclose=autoclose, - **backend_kwargs) + store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) elif engine == 'pydap': - store = backends.PydapDataStore.open(filename_or_obj, - **backend_kwargs) + store = backends.PydapDataStore.open( + filename_or_obj, **backend_kwargs) elif engine == 'h5netcdf': - store = backends.H5NetCDFStore(filename_or_obj, group=group, - autoclose=autoclose, - **backend_kwargs) + store = backends.H5NetCDFStore( + filename_or_obj, group=group, lock=lock, **backend_kwargs) elif engine == 'pynio': - store = backends.NioDataStore(filename_or_obj, - autoclose=autoclose, - **backend_kwargs) + store = backends.NioDataStore( + filename_or_obj, lock=lock, **backend_kwargs) elif engine == 'pseudonetcdf': store = backends.PseudoNetCDFDataStore.open( - filename_or_obj, autoclose=autoclose, **backend_kwargs) + filename_or_obj, lock=lock, **backend_kwargs) + elif engine == 'cfgrib': + store = backends.CfGribDataStore( + filename_or_obj, lock=lock, **backend_kwargs) else: raise ValueError('unrecognized engine for open_dataset: %r' % engine) - if lock is None: - lock = _default_lock(filename_or_obj, engine) with close_on_error(store): - return maybe_decode_store(store, lock) + return maybe_decode_store(store) else: if engine is not None and engine != 'scipy': raise ValueError('can only read file-like objects with ' @@ -355,7 +352,7 @@ def maybe_decode_store(store, lock=False): def open_dataarray(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=None, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=None, concat_characters=True, decode_coords=True, engine=None, chunks=None, lock=None, cache=None, drop_variables=None, backend_kwargs=None): @@ -385,15 +382,11 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, taken from variable attributes (if they exist). 
If the `_FillValue` or `missing_value` attribute contains multiple values a warning will be issued and all array values matching one of the multiple values will - be replaced by NA. mask_and_scale defaults to True except for the + be replaced by NA. mask_and_scale defaults to True except for the pseudonetcdf backend. decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. - autoclose : bool, optional - If True, automatically close files to avoid OS Error of too many files - being open. However, this option doesn't work with streams, e.g., - BytesIO. concat_characters : bool, optional If True, concatenate along the last dimension of character arrays to form string arrays. Dimensions will only be concatenated over (and @@ -402,19 +395,19 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, decode_coords : bool, optional If True, decode the 'coordinates' attribute to identify coordinates in the resulting dataset. - engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio'}, optional + engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'cfgrib'}, + optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. chunks : int or dict, optional If chunks is provided, it used to load the new dataset into dask arrays. - lock : False, True or threading.Lock, optional - If chunks is provided, this argument is passed on to - :py:func:`dask.array.from_array`. By default, a global lock is - used when reading data from netCDF files with the netcdf4 and h5netcdf - engines to avoid issues with concurrent access when using dask's - multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -428,7 +421,7 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, inconsistent values. backend_kwargs: dictionary, optional A dictionary of keyword arguments to pass on to the backend. This - may be useful when backend options would improve performance or + may be useful when backend options would improve performance or allow user control of dataset processing. Notes @@ -490,7 +483,7 @@ def close(self): def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, compat='no_conflicts', preprocess=None, engine=None, lock=None, data_vars='all', coords='different', - autoclose=False, parallel=False, **kwargs): + autoclose=None, parallel=False, **kwargs): """Open multiple files as a single dataset. Requires dask to be installed. See documentation for details on dask [1]. @@ -533,19 +526,16 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, of all non-null values. preprocess : callable, optional If provided, call this function on each dataset prior to concatenation. - engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio'}, optional + engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'cfgrib'}, + optional Engine to use when reading files. 
If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. - autoclose : bool, optional - If True, automatically close files to avoid OS Error of too many files - being open. However, this option doesn't work with streams, e.g., - BytesIO. - lock : False, True or threading.Lock, optional - This argument is passed on to :py:func:`dask.array.from_array`. By - default, a per-variable lock is used when reading data from netCDF - files with the netcdf4 and h5netcdf engines to avoid issues with - concurrent access when using dask's multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. data_vars : {'minimal', 'different', 'all' or list of str}, optional These data variables will be concatenated together: * 'minimal': Only data variables in which the dimension already @@ -604,9 +594,6 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, if not paths: raise IOError('no files to open') - if lock is None: - lock = _default_lock(paths[0], engine) - open_kwargs = dict(engine=engine, chunks=chunks or {}, lock=lock, autoclose=autoclose, **kwargs) @@ -656,19 +643,21 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, - engine=None, writer=None, encoding=None, unlimited_dims=None, - compute=True): + engine=None, encoding=None, unlimited_dims=None, compute=True, + multifile=False): """This function creates an appropriate datastore for writing a dataset to disk as a netCDF file See `Dataset.to_netcdf` for full API docs. - The ``writer`` argument is only for the private use of save_mfdataset. + The ``multifile`` argument is only for the private use of save_mfdataset. """ if isinstance(path_or_file, path_type): path_or_file = str(path_or_file) + if encoding is None: encoding = {} + if path_or_file is None: if engine is None: engine = 'scipy' @@ -676,6 +665,10 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, raise ValueError('invalid engine for creating bytes with ' 'to_netcdf: %r. 
Only the default engine ' "or engine='scipy' is supported" % engine) + if not compute: + raise NotImplementedError( + 'to_netcdf() with compute=False is not yet implemented when ' + 'returning bytes') elif isinstance(path_or_file, basestring): if engine is None: engine = _get_default_engine(path_or_file) @@ -695,45 +688,78 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, if format is not None: format = format.upper() - # if a writer is provided, store asynchronously - sync = writer is None - # handle scheduler specific logic - scheduler = get_scheduler() + scheduler = _get_scheduler() have_chunks = any(v.chunks for v in dataset.variables.values()) - if (have_chunks and scheduler in ['distributed', 'multiprocessing'] and - engine != 'netcdf4'): + + autoclose = have_chunks and scheduler in ['distributed', 'multiprocessing'] + if autoclose and engine == 'scipy': raise NotImplementedError("Writing netCDF files with the %s backend " "is not currently supported with dask's %s " "scheduler" % (engine, scheduler)) - lock = _get_lock(engine, scheduler, format, path_or_file) - autoclose = (have_chunks and - scheduler in ['distributed', 'multiprocessing']) target = path_or_file if path_or_file is not None else BytesIO() - store = store_open(target, mode, format, group, writer, - autoclose=autoclose, lock=lock) + kwargs = dict(autoclose=True) if autoclose else {} + store = store_open(target, mode, format, group, **kwargs) if unlimited_dims is None: unlimited_dims = dataset.encoding.get('unlimited_dims', None) if isinstance(unlimited_dims, basestring): unlimited_dims = [unlimited_dims] + writer = ArrayWriter() + + # TODO: figure out how to refactor this logic (here and in save_mfdataset) + # to avoid this mess of conditionals try: - dataset.dump_to_store(store, sync=sync, encoding=encoding, - unlimited_dims=unlimited_dims, compute=compute) + # TODO: allow this work (setting up the file for writing array data) + # to be parallelized with dask + dump_to_store(dataset, store, writer, encoding=encoding, + unlimited_dims=unlimited_dims) + if autoclose: + store.close() + + if multifile: + return writer, store + + writes = writer.sync(compute=compute) + if path_or_file is None: + store.sync() return target.getvalue() finally: - if sync and isinstance(path_or_file, basestring): + if not multifile and compute: store.close() if not compute: import dask - return dask.delayed(_finalize_store)(store.delayed_store, store) + return dask.delayed(_finalize_store)(writes, store) + + +def dump_to_store(dataset, store, writer=None, encoder=None, + encoding=None, unlimited_dims=None): + """Store dataset contents to a backends.*DataStore object.""" + if writer is None: + writer = ArrayWriter() + + if encoding is None: + encoding = {} + + variables, attrs = conventions.encode_dataset_coordinates(dataset) + + check_encoding = set() + for k, enc in encoding.items(): + # no need to shallow copy the variable again; that already happened + # in encode_dataset_coordinates + variables[k].encoding = enc + check_encoding.add(k) + + if encoder: + variables, attrs = encoder(variables, attrs) + + store.store(variables, attrs, check_encoding, writer, + unlimited_dims=unlimited_dims) - if not sync: - return store def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, engine=None, compute=True): @@ -806,7 +832,7 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, for obj in datasets: if not isinstance(obj, Dataset): raise TypeError('save_mfdataset only supports writing 
Dataset ' - 'objects, recieved type %s' % type(obj)) + 'objects, received type %s' % type(obj)) if groups is None: groups = [None] * len(datasets) @@ -816,22 +842,22 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, 'datasets, paths and groups arguments to ' 'save_mfdataset') - writer = ArrayWriter() if compute else None - stores = [to_netcdf(ds, path, mode, format, group, engine, writer, - compute=compute) - for ds, path, group in zip(datasets, paths, groups)] - - if not compute: - import dask - return dask.delayed(stores) + writers, stores = zip(*[ + to_netcdf(ds, path, mode, format, group, engine, compute=compute, + multifile=True) + for ds, path, group in zip(datasets, paths, groups)]) try: - delayed = writer.sync(compute=compute) - for store in stores: - store.sync() + writes = [w.sync(compute=compute) for w in writers] finally: - for store in stores: - store.close() + if compute: + for store in stores: + store.close() + + if not compute: + import dask + return dask.delayed([dask.delayed(_finalize_store)(w, s) + for w, s in zip(writes, stores)]) def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, @@ -852,13 +878,14 @@ def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, store = backends.ZarrStore.open_group(store=store, mode=mode, synchronizer=synchronizer, - group=group, writer=None) + group=group) - # I think zarr stores should always be sync'd immediately + writer = ArrayWriter() # TODO: figure out how to properly handle unlimited_dims - dataset.dump_to_store(store, sync=True, encoding=encoding, compute=compute) + dump_to_store(dataset, store, writer, encoding=encoding) + writes = writer.sync(compute=compute) if not compute: import dask - return dask.delayed(_finalize_store)(store.delayed_store, store) + return dask.delayed(_finalize_store)(writes, store) return store diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py new file mode 100644 index 00000000000..0807900054a --- /dev/null +++ b/xarray/backends/cfgrib_.py @@ -0,0 +1,71 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + +from .. import Variable +from ..core import indexing +from ..core.utils import Frozen, FrozenOrderedDict +from .common import AbstractDataStore, BackendArray +from .locks import ensure_lock, SerializableLock + +# FIXME: Add a dedicated lock, even if ecCodes is supposed to be thread-safe +# in most circumstances. See: +# https://confluence.ecmwf.int/display/ECC/Frequently+Asked+Questions +ECCODES_LOCK = SerializableLock() + + +class CfGribArrayWrapper(BackendArray): + def __init__(self, datastore, array): + self.datastore = datastore + self.shape = array.shape + self.dtype = array.dtype + self.array = array + + def __getitem__(self, key): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER, self._getitem) + + def _getitem(self, key): + with self.datastore.lock: + return self.array[key] + + +class CfGribDataStore(AbstractDataStore): + """ + Implements the ``xr.AbstractDataStore`` read-only API for a GRIB file. 
+ """ + def __init__(self, filename, lock=None, **backend_kwargs): + import cfgrib + if lock is None: + lock = ECCODES_LOCK + self.lock = ensure_lock(lock) + self.ds = cfgrib.open_file(filename, **backend_kwargs) + + def open_store_variable(self, name, var): + if isinstance(var.data, np.ndarray): + data = var.data + else: + wrapped_array = CfGribArrayWrapper(self, var.data) + data = indexing.LazilyOuterIndexedArray(wrapped_array) + + encoding = self.ds.encoding.copy() + encoding['original_shape'] = var.data.shape + + return Variable(var.dimensions, data, var.attributes, encoding) + + def get_variables(self): + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) + + def get_attrs(self): + return Frozen(self.ds.attributes) + + def get_dimensions(self): + return Frozen(self.ds.dimensions) + + def get_encoding(self): + dims = self.get_dimensions() + encoding = { + 'unlimited_dims': {k for k, v in dims.items() if v is None}, + } + return encoding diff --git a/xarray/backends/common.py b/xarray/backends/common.py index d5eccd9be52..405d989f4af 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -1,14 +1,10 @@ from __future__ import absolute_import, division, print_function -import contextlib import logging -import multiprocessing -import threading import time import traceback import warnings from collections import Mapping, OrderedDict -from functools import partial import numpy as np @@ -17,13 +13,6 @@ from ..core.pycompat import dask_array_type, iteritems from ..core.utils import FrozenOrderedDict, NdimSizeLenMixin -# Import default lock -try: - from dask.utils import SerializableLock - HDF5_LOCK = SerializableLock() -except ImportError: - HDF5_LOCK = threading.Lock() - # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -31,56 +20,6 @@ NONE_VAR_NAME = '__values__' -def get_scheduler(get=None, collection=None): - """ Determine the dask scheduler that is being used. - - None is returned if not dask scheduler is active. - - See also - -------- - dask.utils.effective_get - """ - try: - from dask.utils import effective_get - actual_get = effective_get(get, collection) - try: - from dask.distributed import Client - if isinstance(actual_get.__self__, Client): - return 'distributed' - except (ImportError, AttributeError): - try: - import dask.multiprocessing - if actual_get == dask.multiprocessing.get: - return 'multiprocessing' - else: - return 'threaded' - except ImportError: - return 'threaded' - except ImportError: - return None - - -def get_scheduler_lock(scheduler, path_or_file=None): - """ Get the appropriate lock for a certain situation based onthe dask - scheduler used. - - See Also - -------- - dask.utils.get_scheduler_lock - """ - - if scheduler == 'distributed': - from dask.distributed import Lock - return Lock(path_or_file) - elif scheduler == 'multiprocessing': - return multiprocessing.Lock() - elif scheduler == 'threaded': - from dask.utils import SerializableLock - return SerializableLock() - else: - return threading.Lock() - - def _encode_variable_name(name): if name is None: name = NONE_VAR_NAME @@ -127,39 +66,6 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, time.sleep(1e-3 * next_delay) -class CombinedLock(object): - """A combination of multiple locks. - - Like a locked door, a CombinedLock is locked if any of its constituent - locks are locked. 
- """ - - def __init__(self, locks): - self.locks = tuple(set(locks)) # remove duplicates - - def acquire(self, *args): - return all(lock.acquire(*args) for lock in self.locks) - - def release(self, *args): - for lock in self.locks: - lock.release(*args) - - def __enter__(self): - for lock in self.locks: - lock.__enter__() - - def __exit__(self, *args): - for lock in self.locks: - lock.__exit__(*args) - - @property - def locked(self): - return any(lock.locked for lock in self.locks) - - def __repr__(self): - return "CombinedLock(%r)" % list(self.locks) - - class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): def __array__(self, dtype=None): @@ -168,9 +74,6 @@ def __array__(self, dtype=None): class AbstractDataStore(Mapping): - _autoclose = None - _ds = None - _isopen = False def __iter__(self): return iter(self.variables) @@ -253,7 +156,7 @@ def __exit__(self, exception_type, exception_value, traceback): class ArrayWriter(object): - def __init__(self, lock=HDF5_LOCK): + def __init__(self, lock=None): self.sources = [] self.targets = [] self.lock = lock @@ -268,6 +171,9 @@ def add(self, source, target): def sync(self, compute=True): if self.sources: import dask.array as da + # TODO: consider wrapping targets with dask.delayed, if this makes + # for any discernable difference in perforance, e.g., + # targets = [dask.delayed(t) for t in self.targets] delayed_store = da.store(self.sources, self.targets, lock=self.lock, compute=compute, flush=True) @@ -277,11 +183,6 @@ def sync(self, compute=True): class AbstractWritableDataStore(AbstractDataStore): - def __init__(self, writer=None, lock=HDF5_LOCK): - if writer is None: - writer = ArrayWriter(lock=lock) - self.writer = writer - self.delayed_store = None def encode(self, variables, attributes): """ @@ -323,12 +224,6 @@ def set_attribute(self, k, v): # pragma: no cover def set_variable(self, k, v): # pragma: no cover raise NotImplementedError - def sync(self, compute=True): - if self._isopen and self._autoclose: - # datastore will be reopened during write - self.close() - self.delayed_store = self.writer.sync(compute=compute) - def store_dataset(self, dataset): """ in stores, variables are all variables AND coordinates @@ -339,7 +234,7 @@ def store_dataset(self, dataset): self.store(dataset, dataset.attrs) def store(self, variables, attributes, check_encoding_set=frozenset(), - unlimited_dims=None): + writer=None, unlimited_dims=None): """ Top level method for putting data on this store, this method: - encodes variables/attributes @@ -355,16 +250,19 @@ def store(self, variables, attributes, check_encoding_set=frozenset(), check_encoding_set : list-like List of variables that should be checked for invalid encoding values + writer : ArrayWriter unlimited_dims : list-like List of dimension names that should be treated as unlimited dimensions. 
""" + if writer is None: + writer = ArrayWriter() variables, attributes = self.encode(variables, attributes) self.set_attributes(attributes) self.set_dimensions(variables, unlimited_dims=unlimited_dims) - self.set_variables(variables, check_encoding_set, + self.set_variables(variables, check_encoding_set, writer, unlimited_dims=unlimited_dims) def set_attributes(self, attributes): @@ -380,7 +278,7 @@ def set_attributes(self, attributes): for k, v in iteritems(attributes): self.set_attribute(k, v) - def set_variables(self, variables, check_encoding_set, + def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None): """ This provides a centralized method to set the variables on the data @@ -393,6 +291,7 @@ def set_variables(self, variables, check_encoding_set, check_encoding_set : list-like List of variables that should be checked for invalid encoding values + writer : ArrayWriter unlimited_dims : list-like List of dimension names that should be treated as unlimited dimensions. @@ -404,7 +303,7 @@ def set_variables(self, variables, check_encoding_set, target, source = self.prepare_variable( name, v, check, unlimited_dims=unlimited_dims) - self.writer.add(source, target) + writer.add(source, target) def set_dimensions(self, variables, unlimited_dims=None): """ @@ -451,87 +350,3 @@ def encode(self, variables, attributes): attributes = OrderedDict([(k, self.encode_attribute(v)) for k, v in attributes.items()]) return variables, attributes - - -class DataStorePickleMixin(object): - """Subclasses must define `ds`, `_opener` and `_mode` attributes. - - Do not subclass this class: it is not part of xarray's external API. - """ - - def __getstate__(self): - state = self.__dict__.copy() - del state['_ds'] - del state['_isopen'] - if self._mode == 'w': - # file has already been created, don't override when restoring - state['_mode'] = 'a' - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self._ds = None - self._isopen = False - - @property - def ds(self): - if self._ds is not None and self._isopen: - return self._ds - ds = self._opener(mode=self._mode) - self._isopen = True - return ds - - @contextlib.contextmanager - def ensure_open(self, autoclose=None): - """ - Helper function to make sure datasets are closed and opened - at appropriate times to avoid too many open file errors. - - Use requires `autoclose=True` argument to `open_mfdataset`. 
- """ - - if autoclose is None: - autoclose = self._autoclose - - if not self._isopen: - try: - self._ds = self._opener() - self._isopen = True - yield - finally: - if autoclose: - self.close() - else: - yield - - def assert_open(self): - if not self._isopen: - raise AssertionError('internal failure: file must be open ' - 'if `autoclose=True` is used.') - - -class PickleByReconstructionWrapper(object): - - def __init__(self, opener, file, mode='r', **kwargs): - self.opener = partial(opener, file, mode=mode, **kwargs) - self.mode = mode - self._ds = None - - @property - def value(self): - self._ds = self.opener() - return self._ds - - def __getstate__(self): - state = self.__dict__.copy() - del state['_ds'] - if self.mode == 'w': - # file has already been created, don't override when restoring - state['mode'] = 'a' - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - def close(self): - self._ds.close() diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py new file mode 100644 index 00000000000..a93285370b2 --- /dev/null +++ b/xarray/backends/file_manager.py @@ -0,0 +1,206 @@ +import threading + +from ..core import utils +from ..core.options import OPTIONS +from .lru_cache import LRUCache + + +# Global cache for storing open files. +FILE_CACHE = LRUCache( + OPTIONS['file_cache_maxsize'], on_evict=lambda k, v: v.close()) +assert FILE_CACHE.maxsize, 'file cache must be at least size one' + + +_DEFAULT_MODE = utils.ReprObject('') + + +class FileManager(object): + """Manager for acquiring and closing a file object. + + Use FileManager subclasses (CachingFileManager in particular) on backend + storage classes to automatically handle issues related to keeping track of + many open files and transferring them between multiple processes. + """ + + def acquire(self): + """Acquire the file object from this manager.""" + raise NotImplementedError + + def close(self, needs_lock=True): + """Close the file object associated with this manager, if needed.""" + raise NotImplementedError + + +class CachingFileManager(FileManager): + """Wrapper for automatically opening and closing file objects. + + Unlike files, CachingFileManager objects can be safely pickled and passed + between processes. They should be explicitly closed to release resources, + but a per-process least-recently-used cache for open files ensures that you + can safely create arbitrarily large numbers of FileManager objects. + + Don't directly close files acquired from a FileManager. Instead, call + FileManager.close(), which ensures that closed files are removed from the + cache as well. + + Example usage: + + manager = FileManager(open, 'example.txt', mode='w') + f = manager.acquire() + f.write(...) + manager.close() # ensures file is closed + + Note that as long as previous files are still cached, acquiring a file + multiple times from the same FileManager is essentially free: + + f1 = manager.acquire() + f2 = manager.acquire() + assert f1 is f2 + + """ + + def __init__(self, opener, *args, **keywords): + """Initialize a FileManager. + + Parameters + ---------- + opener : callable + Function that when called like ``opener(*args, **kwargs)`` returns + an open file object. The file object must implement a ``close()`` + method. + *args + Positional arguments for opener. A ``mode`` argument should be + provided as a keyword argument (see below). All arguments must be + hashable. + mode : optional + If provided, passed as a keyword argument to ``opener`` along with + ``**kwargs``. 
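To make the caching behaviour described in this docstring concrete, here is a hedged example that uses a tiny private LRUCache rather than the global FILE_CACHE; the filenames 'a.txt' and 'b.txt' are placeholders and the files are created in the working directory.

    from xarray.backends.file_manager import CachingFileManager
    from xarray.backends.lru_cache import LRUCache

    cache = LRUCache(maxsize=1, on_evict=lambda key, file: file.close())

    m1 = CachingFileManager(open, 'a.txt', mode='w', cache=cache)
    m2 = CachingFileManager(open, 'b.txt', mode='w', cache=cache)

    f1 = m1.acquire()
    f2 = m2.acquire()               # evicts and closes a.txt
    assert f1.closed
    assert m1.acquire() is not f1   # a.txt is reopened (now in mode 'a') on demand

    m1.close()
    m2.close()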
``mode='w' `` has special treatment: after the first + call it is replaced by ``mode='a'`` in all subsequent function to + avoid overriding the newly created file. + kwargs : dict, optional + Keyword arguments for opener, excluding ``mode``. All values must + be hashable. + lock : duck-compatible threading.Lock, optional + Lock to use when modifying the cache inside acquire() and close(). + By default, uses a new threading.Lock() object. If set, this object + should be pickleable. + cache : MutableMapping, optional + Mapping to use as a cache for open files. By default, uses xarray's + global LRU file cache. Because ``cache`` typically points to a + global variable and contains non-picklable file objects, an + unpickled FileManager objects will be restored with the default + cache. + """ + # TODO: replace with real keyword arguments when we drop Python 2 + # support + mode = keywords.pop('mode', _DEFAULT_MODE) + kwargs = keywords.pop('kwargs', None) + lock = keywords.pop('lock', None) + cache = keywords.pop('cache', FILE_CACHE) + if keywords: + raise TypeError('FileManager() got unexpected keyword arguments: ' + '%s' % list(keywords)) + + self._opener = opener + self._args = args + self._mode = mode + self._kwargs = {} if kwargs is None else dict(kwargs) + self._default_lock = lock is None or lock is False + self._lock = threading.Lock() if self._default_lock else lock + self._cache = cache + self._key = self._make_key() + + def _make_key(self): + """Make a key for caching files in the LRU cache.""" + value = (self._opener, + self._args, + self._mode, + tuple(sorted(self._kwargs.items()))) + return _HashedSequence(value) + + def acquire(self): + """Acquiring a file object from the manager. + + A new file is only opened if it has expired from the + least-recently-used cache. + + This method uses a reentrant lock, which ensures that it is + thread-safe. You can safely acquire a file in multiple threads at the + same time, as long as the underlying file object is thread-safe. + + Returns + ------- + An open file object, as returned by ``opener(*args, **kwargs)``. + """ + with self._lock: + try: + file = self._cache[self._key] + except KeyError: + kwargs = self._kwargs + if self._mode is not _DEFAULT_MODE: + kwargs = kwargs.copy() + kwargs['mode'] = self._mode + file = self._opener(*self._args, **kwargs) + if self._mode == 'w': + # ensure file doesn't get overriden when opened again + self._mode = 'a' + self._key = self._make_key() + self._cache[self._key] = file + return file + + def _close(self): + default = None + file = self._cache.pop(self._key, default) + if file is not None: + file.close() + + def close(self, needs_lock=True): + """Explicitly close any associated file object (if necessary).""" + # TODO: remove needs_lock if/when we have a reentrant lock in + # dask.distributed: https://github.com/dask/dask/issues/3832 + if needs_lock: + with self._lock: + self._close() + else: + self._close() + + def __getstate__(self): + """State for pickling.""" + lock = None if self._default_lock else self._lock + return (self._opener, self._args, self._mode, self._kwargs, lock) + + def __setstate__(self, state): + """Restore from a pickle.""" + opener, args, mode, kwargs, lock = state + self.__init__(opener, *args, mode=mode, kwargs=kwargs, lock=lock) + + +class _HashedSequence(list): + """Speedup repeated look-ups by caching hash values. + + Based on what Python uses internally in functools.lru_cache. 
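The __getstate__/__setstate__ pair above replaces the hand-rolled pickling in the old DataStorePickleMixin. A rough sketch of the resulting behaviour, assuming a writable working directory ('example.txt' is a placeholder):

    import pickle
    from xarray.backends.file_manager import CachingFileManager

    manager = CachingFileManager(open, 'example.txt', mode='w')
    manager.acquire().write('hello')

    # Only the recipe (opener, args, mode, kwargs, lock) is pickled, never the
    # open file itself; note that mode='w' was already downgraded to 'a' by the
    # first acquire(), so unpickling cannot truncate the file just written.
    clone = pickle.loads(pickle.dumps(manager))

    # Equivalent managers hash to the same cache key, so within one process
    # they share a single handle via the global FILE_CACHE.
    assert clone.acquire() is manager.acquire()

    manager.close()
    clone.close()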
+ + Python doesn't perform this optimization automatically: + https://bugs.python.org/issue1462796 + """ + + def __init__(self, tuple_value): + self[:] = tuple_value + self.hashvalue = hash(tuple_value) + + def __hash__(self): + return self.hashvalue + + +class DummyFileManager(FileManager): + """FileManager that simply wraps an open file in the FileManager interface. + """ + def __init__(self, value): + self._value = value + + def acquire(self): + return self._value + + def close(self, needs_lock=True): + del needs_lock # ignored + self._value.close() diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index ecc83e98691..59cd4e84793 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -8,29 +8,27 @@ from ..core import indexing from ..core.pycompat import OrderedDict, bytes_type, iteritems, unicode_type from ..core.utils import FrozenOrderedDict, close_on_error -from .common import ( - HDF5_LOCK, DataStorePickleMixin, WritableCFDataStore, find_root) +from .common import WritableCFDataStore +from .file_manager import CachingFileManager +from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from .netCDF4_ import ( - BaseNetCDF4Array, _encode_nc4_variable, _extract_nc4_variable_encoding, - _get_datatype, _nc4_require_group) + BaseNetCDF4Array, GroupWrapper, _encode_nc4_variable, + _extract_nc4_variable_encoding, _get_datatype, _nc4_require_group) class H5NetCDFArrayWrapper(BaseNetCDF4Array): def __getitem__(self, key): - key, np_inds = indexing.decompose_indexer( - key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR) + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, + self._getitem) + def _getitem(self, key): # h5py requires using lists for fancy indexing: # https://github.com/h5py/h5py/issues/992 - key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in - key.tuple) - with self.datastore.ensure_open(autoclose=True): - array = self.get_array()[key] - - if len(np_inds.tuple) > 0: - array = indexing.NumpyIndexingAdapter(array)[np_inds] - - return array + key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in key) + array = self.get_array() + with self.datastore.lock: + return array[key] def maybe_decode_bytes(txt): @@ -65,104 +63,102 @@ def _open_h5netcdf_group(filename, mode, group): import h5netcdf ds = h5netcdf.File(filename, mode=mode) with close_on_error(ds): - return _nc4_require_group( + ds = _nc4_require_group( ds, group, mode, create_group=_h5netcdf_create_group) + return GroupWrapper(ds) -class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin): +class H5NetCDFStore(WritableCFDataStore): """Store for reading and writing data via h5netcdf """ def __init__(self, filename, mode='r', format=None, group=None, - writer=None, autoclose=False, lock=HDF5_LOCK): + lock=None, autoclose=False): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') - opener = functools.partial(_open_h5netcdf_group, filename, mode=mode, - group=group) - self._ds = opener() - if autoclose: - raise NotImplementedError('autoclose=True is not implemented ' - 'for the h5netcdf backend pending ' - 'further exploration, e.g., bug fixes ' - '(in h5netcdf?)') - self._autoclose = False - self._isopen = True + self._manager = CachingFileManager( + _open_h5netcdf_group, filename, mode=mode, + kwargs=dict(group=group)) + + if lock is None: + if mode == 'r': + lock = HDF5_LOCK + else: + lock = combine_locks([HDF5_LOCK, 
get_write_lock(filename)]) + self.format = format - self._opener = opener self._filename = filename self._mode = mode - super(H5NetCDFStore, self).__init__(writer, lock=lock) + self.lock = ensure_lock(lock) + self.autoclose = autoclose + + @property + def ds(self): + return self._manager.acquire().value def open_store_variable(self, name, var): import h5py - with self.ensure_open(autoclose=False): - dimensions = var.dimensions - data = indexing.LazilyOuterIndexedArray( - H5NetCDFArrayWrapper(name, self)) - attrs = _read_attributes(var) - - # netCDF4 specific encoding - encoding = { - 'chunksizes': var.chunks, - 'fletcher32': var.fletcher32, - 'shuffle': var.shuffle, - } - # Convert h5py-style compression options to NetCDF4-Python - # style, if possible - if var.compression == 'gzip': - encoding['zlib'] = True - encoding['complevel'] = var.compression_opts - elif var.compression is not None: - encoding['compression'] = var.compression - encoding['compression_opts'] = var.compression_opts - - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape - - vlen_dtype = h5py.check_dtype(vlen=var.dtype) - if vlen_dtype is unicode_type: - encoding['dtype'] = str - elif vlen_dtype is not None: # pragma: no cover - # xarray doesn't support writing arbitrary vlen dtypes yet. - pass - else: - encoding['dtype'] = var.dtype + dimensions = var.dimensions + data = indexing.LazilyOuterIndexedArray( + H5NetCDFArrayWrapper(name, self)) + attrs = _read_attributes(var) + + # netCDF4 specific encoding + encoding = { + 'chunksizes': var.chunks, + 'fletcher32': var.fletcher32, + 'shuffle': var.shuffle, + } + # Convert h5py-style compression options to NetCDF4-Python + # style, if possible + if var.compression == 'gzip': + encoding['zlib'] = True + encoding['complevel'] = var.compression_opts + elif var.compression is not None: + encoding['compression'] = var.compression + encoding['compression_opts'] = var.compression_opts + + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape + + vlen_dtype = h5py.check_dtype(vlen=var.dtype) + if vlen_dtype is unicode_type: + encoding['dtype'] = str + elif vlen_dtype is not None: # pragma: no cover + # xarray doesn't support writing arbitrary vlen dtypes yet. 
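For reference, the h5py-to-netCDF4 encoding translation performed above can be read as a small pure function. The helper below is illustrative only and is not part of xarray.

    def _h5py_compression_to_nc4_encoding(compression, compression_opts):
        """Map h5py compression settings onto netCDF4-Python style encoding."""
        encoding = {}
        if compression == 'gzip':
            encoding['zlib'] = True
            encoding['complevel'] = compression_opts
        elif compression is not None:
            encoding['compression'] = compression
            encoding['compression_opts'] = compression_opts
        return encoding

    assert _h5py_compression_to_nc4_encoding('gzip', 4) == \
        {'zlib': True, 'complevel': 4}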
+ pass + else: + encoding['dtype'] = var.dtype return Variable(dimensions, data, attrs, encoding) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - with self.ensure_open(autoclose=True): - return FrozenOrderedDict(_read_attributes(self.ds)) + return FrozenOrderedDict(_read_attributes(self.ds)) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return self.ds.dimensions + return self.ds.dimensions def get_encoding(self): - with self.ensure_open(autoclose=True): - encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v is None} + encoding = {} + encoding['unlimited_dims'] = { + k for k, v in self.ds.dimensions.items() if v is None} return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - if is_unlimited: - self.ds.dimensions[name] = None - self.ds.resize_dimension(name, length) - else: - self.ds.dimensions[name] = length + if is_unlimited: + self.ds.dimensions[name] = None + self.ds.resize_dimension(name, length) + else: + self.ds.dimensions[name] = length def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - self.ds.attrs[key] = value + self.ds.attrs[key] = value def encode_variable(self, variable): return _encode_nc4_variable(variable) @@ -230,18 +226,11 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data - def sync(self, compute=True): - if not compute: - raise NotImplementedError( - 'compute=False is not supported for the h5netcdf backend yet') - with self.ensure_open(autoclose=True): - super(H5NetCDFStore, self).sync(compute=compute) - self.ds.sync() - - def close(self): - if self._isopen: - # netCDF4 only allows closing the root group - ds = find_root(self.ds) - if not ds._closed: - ds.close() - self._isopen = False + def sync(self): + self.ds.sync() + # if self.autoclose: + # self.close() + # super(H5NetCDFStore, self).sync(compute=compute) + + def close(self, **kwargs): + self._manager.close(**kwargs) diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py new file mode 100644 index 00000000000..f633280ef1d --- /dev/null +++ b/xarray/backends/locks.py @@ -0,0 +1,191 @@ +import multiprocessing +import threading +import weakref + +try: + from dask.utils import SerializableLock +except ImportError: + # no need to worry about serializing the lock + SerializableLock = threading.Lock + + +# Locks used by multiple backends. +# Neither HDF5 nor the netCDF-C library are thread-safe. +HDF5_LOCK = SerializableLock() +NETCDFC_LOCK = SerializableLock() + + +_FILE_LOCKS = weakref.WeakValueDictionary() + + +def _get_threaded_lock(key): + try: + lock = _FILE_LOCKS[key] + except KeyError: + lock = _FILE_LOCKS[key] = threading.Lock() + return lock + + +def _get_multiprocessing_lock(key): + # TODO: make use of the key -- maybe use locket.py? 
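A short aside on why the module-level locks are SerializableLock rather than plain threading.Lock, assuming dask is installed so the fallback above is not taken:

    import pickle
    from xarray.backends.locks import HDF5_LOCK

    with HDF5_LOCK:            # usable exactly like a threading.Lock
        pass

    # With dask installed, HDF5_LOCK is a dask SerializableLock and survives
    # pickling, which matters when objects holding it are shipped to workers;
    # the bare threading.Lock fallback would raise a TypeError here instead.
    roundtripped = pickle.loads(pickle.dumps(HDF5_LOCK))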
+ # https://github.com/mwilliamson/locket.py + del key # unused + return multiprocessing.Lock() + + +def _get_distributed_lock(key): + from dask.distributed import Lock + return Lock(key) + + +_LOCK_MAKERS = { + None: _get_threaded_lock, + 'threaded': _get_threaded_lock, + 'multiprocessing': _get_multiprocessing_lock, + 'distributed': _get_distributed_lock, +} + + +def _get_lock_maker(scheduler=None): + """Returns an appropriate function for creating resource locks. + + Parameters + ---------- + scheduler : str or None + Dask scheduler being used. + + See Also + -------- + dask.utils.get_scheduler_lock + """ + return _LOCK_MAKERS[scheduler] + + +def _get_scheduler(get=None, collection=None): + """Determine the dask scheduler that is being used. + + None is returned if no dask scheduler is active. + + See also + -------- + dask.base.get_scheduler + """ + try: + # dask 0.18.1 and later + from dask.base import get_scheduler + actual_get = get_scheduler(get, collection) + except ImportError: + try: + from dask.utils import effective_get + actual_get = effective_get(get, collection) + except ImportError: + return None + + try: + from dask.distributed import Client + if isinstance(actual_get.__self__, Client): + return 'distributed' + except (ImportError, AttributeError): + try: + import dask.multiprocessing + if actual_get == dask.multiprocessing.get: + return 'multiprocessing' + else: + return 'threaded' + except ImportError: + return 'threaded' + + +def get_write_lock(key): + """Get a scheduler appropriate lock for writing to the given resource. + + Parameters + ---------- + key : str + Name of the resource for which to acquire a lock. Typically a filename. + + Returns + ------- + Lock object that can be used like a threading.Lock object. + """ + scheduler = _get_scheduler() + lock_maker = _get_lock_maker(scheduler) + return lock_maker(key) + + +class CombinedLock(object): + """A combination of multiple locks. + + Like a locked door, a CombinedLock is locked if any of its constituent + locks are locked. 
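Putting the pieces above together, the lock handed to a writing backend depends on the active dask scheduler. A hedged sketch that leans on the private _get_scheduler helper defined above, with a placeholder filename:

    from xarray.backends.locks import _get_scheduler, get_write_lock

    # None with no dask scheduler active, 'threaded' under the default dask
    # scheduler, 'multiprocessing' or 'distributed' under those schedulers.
    scheduler = _get_scheduler()

    # The lock type follows the scheduler: a threading.Lock here, but a
    # multiprocessing.Lock or a distributed.Lock named after the file when
    # those schedulers are in use, so separate worker processes still
    # serialize their writes to the same target.
    lock = get_write_lock('out.nc')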
+ """ + + def __init__(self, locks): + self.locks = tuple(set(locks)) # remove duplicates + + def acquire(self, *args): + return all(lock.acquire(*args) for lock in self.locks) + + def release(self, *args): + for lock in self.locks: + lock.release(*args) + + def __enter__(self): + for lock in self.locks: + lock.__enter__() + + def __exit__(self, *args): + for lock in self.locks: + lock.__exit__(*args) + + @property + def locked(self): + return any(lock.locked for lock in self.locks) + + def __repr__(self): + return "CombinedLock(%r)" % list(self.locks) + + +class DummyLock(object): + """DummyLock provides the lock API without any actual locking.""" + + def acquire(self, *args): + pass + + def release(self, *args): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + @property + def locked(self): + return False + + +def combine_locks(locks): + """Combine a sequence of locks into a single lock.""" + all_locks = [] + for lock in locks: + if isinstance(lock, CombinedLock): + all_locks.extend(lock.locks) + elif lock is not None: + all_locks.append(lock) + + num_locks = len(all_locks) + if num_locks > 1: + return CombinedLock(all_locks) + elif num_locks == 1: + return all_locks[0] + else: + return DummyLock() + + +def ensure_lock(lock): + """Ensure that the given object is a lock.""" + if lock is None or lock is False: + return DummyLock() + return lock diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py new file mode 100644 index 00000000000..321a1ca4da4 --- /dev/null +++ b/xarray/backends/lru_cache.py @@ -0,0 +1,91 @@ +import collections +import threading + +from ..core.pycompat import move_to_end + + +class LRUCache(collections.MutableMapping): + """Thread-safe LRUCache based on an OrderedDict. + + All dict operations (__getitem__, __setitem__, __contains__) update the + priority of the relevant key and take O(1) time. The dict is iterated over + in order from the oldest to newest key, which means that a complete pass + over the dict should not affect the order of any entries. + + When a new item is set and the maximum size of the cache is exceeded, the + oldest item is dropped and called with ``on_evict(key, value)``. + + The ``maxsize`` property can be used to view or adjust the capacity of + the cache, e.g., ``cache.maxsize = new_size``. + """ + def __init__(self, maxsize, on_evict=None): + """ + Parameters + ---------- + maxsize : int + Integer maximum number of items to hold in the cache. + on_evict: callable, optional + Function to call like ``on_evict(key, value)`` when items are + evicted. 
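A small self-check of the helpers defined in this module; everything used here is defined above. combine_locks flattens nested CombinedLocks and drops None, and ensure_lock turns None/False into a no-op DummyLock so backend code can always write "with self.lock:".

    import threading
    from xarray.backends.locks import (
        CombinedLock, DummyLock, combine_locks, ensure_lock)

    a, b = threading.Lock(), threading.Lock()

    assert isinstance(combine_locks([a, b]), CombinedLock)
    assert combine_locks([a, None]) is a             # single lock passes through
    assert isinstance(combine_locks([]), DummyLock)  # nothing to lock
    assert isinstance(ensure_lock(None), DummyLock)

    with combine_locks([a, CombinedLock([b])]):      # nested locks are flattened
        assert a.locked() and b.locked()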
+ """ + if not isinstance(maxsize, int): + raise TypeError('maxsize must be an integer') + if maxsize < 0: + raise ValueError('maxsize must be non-negative') + self._maxsize = maxsize + self._on_evict = on_evict + self._cache = collections.OrderedDict() + self._lock = threading.RLock() + + def __getitem__(self, key): + # record recent use of the key by moving it to the front of the list + with self._lock: + value = self._cache[key] + move_to_end(self._cache, key) + return value + + def _enforce_size_limit(self, capacity): + """Shrink the cache if necessary, evicting the oldest items.""" + while len(self._cache) > capacity: + key, value = self._cache.popitem(last=False) + if self._on_evict is not None: + self._on_evict(key, value) + + def __setitem__(self, key, value): + with self._lock: + if key in self._cache: + # insert the new value at the end + del self._cache[key] + self._cache[key] = value + elif self._maxsize: + # make room if necessary + self._enforce_size_limit(self._maxsize - 1) + self._cache[key] = value + elif self._on_evict is not None: + # not saving, immediately evict + self._on_evict(key, value) + + def __delitem__(self, key): + del self._cache[key] + + def __iter__(self): + # create a list, so accessing the cache during iteration cannot change + # the iteration order + return iter(list(self._cache)) + + def __len__(self): + return len(self._cache) + + @property + def maxsize(self): + """Maximum number of items can be held in the cache.""" + return self._maxsize + + @maxsize.setter + def maxsize(self, size): + """Resize the cache, evicting the oldest items if necessary.""" + if size < 0: + raise ValueError('maxsize must be non-negative') + with self._lock: + self._enforce_size_limit(size) + self._maxsize = size diff --git a/xarray/backends/memory.py b/xarray/backends/memory.py index dcf092557b8..195d4647534 100644 --- a/xarray/backends/memory.py +++ b/xarray/backends/memory.py @@ -17,10 +17,9 @@ class InMemoryDataStore(AbstractWritableDataStore): This store exists purely for internal testing purposes. """ - def __init__(self, variables=None, attributes=None, writer=None): + def __init__(self, variables=None, attributes=None): self._variables = OrderedDict() if variables is None else variables self._attributes = OrderedDict() if attributes is None else attributes - super(InMemoryDataStore, self).__init__(writer) def get_attrs(self): return self._attributes diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index d26b2b5321e..08ba085b77e 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -10,12 +10,13 @@ from .. 
import Variable, coding from ..coding.variables import pop_to from ..core import indexing -from ..core.pycompat import ( - PY3, OrderedDict, basestring, iteritems, suppress) +from ..core.pycompat import PY3, OrderedDict, basestring, iteritems, suppress from ..core.utils import FrozenOrderedDict, close_on_error, is_remote_uri from .common import ( - HDF5_LOCK, BackendArray, DataStorePickleMixin, WritableCFDataStore, - find_root, robust_getitem) + BackendArray, WritableCFDataStore, find_root, robust_getitem) +from .locks import (NETCDFC_LOCK, HDF5_LOCK, + combine_locks, ensure_lock, get_write_lock) +from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable # This lookup table maps from dtype.byteorder to a readable endian @@ -26,6 +27,9 @@ '|': 'native'} +NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK]) + + class BaseNetCDF4Array(BackendArray): def __init__(self, variable_name, datastore): self.datastore = datastore @@ -43,42 +47,44 @@ def __init__(self, variable_name, datastore): self.dtype = dtype def __setitem__(self, key, value): - with self.datastore.ensure_open(autoclose=True): + with self.datastore.lock: data = self.get_array() data[key] = value + if self.datastore.autoclose: + self.datastore.close(needs_lock=False) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] class NetCDF4ArrayWrapper(BaseNetCDF4Array): def __getitem__(self, key): - key, np_inds = indexing.decompose_indexer( - key, self.shape, indexing.IndexingSupport.OUTER) + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER, + self._getitem) + + def _getitem(self, key): if self.datastore.is_remote: # pragma: no cover getitem = functools.partial(robust_getitem, catch=RuntimeError) else: getitem = operator.getitem - with self.datastore.ensure_open(autoclose=True): - try: - array = getitem(self.get_array(), key.tuple) - except IndexError: - # Catch IndexError in netCDF4 and return a more informative - # error message. This is most often called when an unsorted - # indexer is used before the data is loaded from disk. - msg = ('The indexing operation you are attempting to perform ' - 'is not valid on netCDF4.Variable object. Try loading ' - 'your data into memory first by calling .load().') - if not PY3: - import traceback - msg += '\n\nOriginal traceback:\n' + traceback.format_exc() - raise IndexError(msg) - - if len(np_inds.tuple) > 0: - array = indexing.NumpyIndexingAdapter(array)[np_inds] + original_array = self.get_array() + try: + with self.datastore.lock: + array = getitem(original_array, key) + except IndexError: + # Catch IndexError in netCDF4 and return a more informative + # error message. This is most often called when an unsorted + # indexer is used before the data is loaded from disk. + msg = ('The indexing operation you are attempting to perform ' + 'is not valid on netCDF4.Variable object. 
Try loading ' + 'your data into memory first by calling .load().') + if not PY3: + import traceback + msg += '\n\nOriginal traceback:\n' + traceback.format_exc() + raise IndexError(msg) return array @@ -225,7 +231,17 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, return encoding -def _open_netcdf4_group(filename, mode, group=None, **kwargs): +class GroupWrapper(object): + """Wrap netCDF4.Group objects so closing them closes the root group.""" + def __init__(self, value): + self.value = value + + def close(self): + # netCDF4 only allows closing the root group + find_root(self.value).close() + + +def _open_netcdf4_group(filename, lock, mode, group=None, **kwargs): import netCDF4 as nc4 ds = nc4.Dataset(filename, mode=mode, **kwargs) @@ -235,7 +251,7 @@ def _open_netcdf4_group(filename, mode, group=None, **kwargs): _disable_auto_decode_group(ds) - return ds + return GroupWrapper(ds) def _disable_auto_decode_variable(var): @@ -281,40 +297,33 @@ def _set_nc_attribute(obj, key, value): obj.setncattr(key, value) -class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin): +class NetCDF4DataStore(WritableCFDataStore): """Store for reading and writing data via the Python-NetCDF4 library. This store supports NetCDF3, NetCDF4 and OpenDAP datasets. """ - def __init__(self, netcdf4_dataset, mode='r', writer=None, opener=None, - autoclose=False, lock=HDF5_LOCK): - - if autoclose and opener is None: - raise ValueError('autoclose requires an opener') + def __init__(self, manager, lock=NETCDF4_PYTHON_LOCK, autoclose=False): + import netCDF4 - _disable_auto_decode_group(netcdf4_dataset) + if isinstance(manager, netCDF4.Dataset): + _disable_auto_decode_group(manager) + manager = DummyFileManager(GroupWrapper(manager)) - self._ds = netcdf4_dataset - self._autoclose = autoclose - self._isopen = True + self._manager = manager self.format = self.ds.data_model self._filename = self.ds.filepath() self.is_remote = is_remote_uri(self._filename) - self._mode = mode = 'a' if mode == 'w' else mode - if opener: - self._opener = functools.partial(opener, mode=self._mode) - else: - self._opener = opener - super(NetCDF4DataStore, self).__init__(writer, lock=lock) + self.lock = ensure_lock(lock) + self.autoclose = autoclose @classmethod def open(cls, filename, mode='r', format='NETCDF4', group=None, - writer=None, clobber=True, diskless=False, persist=False, - autoclose=False, lock=HDF5_LOCK): - import netCDF4 as nc4 + clobber=True, diskless=False, persist=False, + lock=None, lock_maker=None, autoclose=False): + import netCDF4 if (len(filename) == 88 and - LooseVersion(nc4.__version__) < "1.3.1"): + LooseVersion(netCDF4.__version__) < "1.3.1"): warnings.warn( 'A segmentation fault may occur when the ' 'file path has exactly 88 characters as it does ' @@ -325,86 +334,91 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None, 'https://github.com/pydata/xarray/issues/1745') if format is None: format = 'NETCDF4' - opener = functools.partial(_open_netcdf4_group, filename, mode=mode, - group=group, clobber=clobber, - diskless=diskless, persist=persist, - format=format) - ds = opener() - return cls(ds, mode=mode, writer=writer, opener=opener, - autoclose=autoclose, lock=lock) - def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - dimensions = var.dimensions - data = indexing.LazilyOuterIndexedArray( - NetCDF4ArrayWrapper(name, self)) - attributes = OrderedDict((k, var.getncattr(k)) - for k in var.ncattrs()) - _ensure_fill_value_valid(data, attributes) - # 
netCDF4 specific encoding; save _FillValue for later - encoding = {} - filters = var.filters() - if filters is not None: - encoding.update(filters) - chunking = var.chunking() - if chunking is not None: - if chunking == 'contiguous': - encoding['contiguous'] = True - encoding['chunksizes'] = None + if lock is None: + if mode == 'r': + if is_remote_uri(filename): + lock = NETCDFC_LOCK + else: + lock = NETCDF4_PYTHON_LOCK + else: + if format is None or format.startswith('NETCDF4'): + base_lock = NETCDF4_PYTHON_LOCK else: - encoding['contiguous'] = False - encoding['chunksizes'] = tuple(chunking) - # TODO: figure out how to round-trip "endian-ness" without raising - # warnings from netCDF4 - # encoding['endian'] = var.endian() - pop_to(attributes, encoding, 'least_significant_digit') - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape - encoding['dtype'] = var.dtype + base_lock = NETCDFC_LOCK + lock = combine_locks([base_lock, get_write_lock(filename)]) + + manager = CachingFileManager( + _open_netcdf4_group, filename, lock, mode=mode, + kwargs=dict(group=group, clobber=clobber, diskless=diskless, + persist=persist, format=format)) + return cls(manager, lock=lock, autoclose=autoclose) + + @property + def ds(self): + return self._manager.acquire().value + + def open_store_variable(self, name, var): + dimensions = var.dimensions + data = indexing.LazilyOuterIndexedArray( + NetCDF4ArrayWrapper(name, self)) + attributes = OrderedDict((k, var.getncattr(k)) + for k in var.ncattrs()) + _ensure_fill_value_valid(data, attributes) + # netCDF4 specific encoding; save _FillValue for later + encoding = {} + filters = var.filters() + if filters is not None: + encoding.update(filters) + chunking = var.chunking() + if chunking is not None: + if chunking == 'contiguous': + encoding['contiguous'] = True + encoding['chunksizes'] = None + else: + encoding['contiguous'] = False + encoding['chunksizes'] = tuple(chunking) + # TODO: figure out how to round-trip "endian-ness" without raising + # warnings from netCDF4 + # encoding['endian'] = var.endian() + pop_to(attributes, encoding, 'least_significant_digit') + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape + encoding['dtype'] = var.dtype return Variable(dimensions, data, attributes, encoding) def get_variables(self): - with self.ensure_open(autoclose=False): - dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in - iteritems(self.ds.variables)) + dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in + iteritems(self.ds.variables)) return dsvars def get_attrs(self): - with self.ensure_open(autoclose=True): - attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) - for k in self.ds.ncattrs()) + attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) + for k in self.ds.ncattrs()) return attrs def get_dimensions(self): - with self.ensure_open(autoclose=True): - dims = FrozenOrderedDict((k, len(v)) - for k, v in iteritems(self.ds.dimensions)) + dims = FrozenOrderedDict((k, len(v)) + for k, v in iteritems(self.ds.dimensions)) return dims def get_encoding(self): - with self.ensure_open(autoclose=True): - encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v.isunlimited()} + encoding = {} + encoding['unlimited_dims'] = { + k for k, v in self.ds.dimensions.items() if v.isunlimited()} return encoding def 
set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - dim_length = length if not is_unlimited else None - self.ds.createDimension(name, size=dim_length) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, size=dim_length) def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - if self.format != 'NETCDF4': - value = encode_nc3_attr_value(value) - _set_nc_attribute(self.ds, key, value) - - def set_variables(self, *args, **kwargs): - with self.ensure_open(autoclose=False): - super(NetCDF4DataStore, self).set_variables(*args, **kwargs) + if self.format != 'NETCDF4': + value = encode_nc3_attr_value(value) + _set_nc_attribute(self.ds, key, value) def encode_variable(self, variable): variable = _force_native_endianness(variable) @@ -462,15 +476,8 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data - def sync(self, compute=True): - with self.ensure_open(autoclose=True): - super(NetCDF4DataStore, self).sync(compute=compute) - self.ds.sync() + def sync(self): + self.ds.sync() - def close(self): - if self._isopen: - # netCDF4 only allows closing the root group - ds = find_root(self.ds) - if ds._isopen: - ds.close() - self._isopen = False + def close(self, **kwargs): + self._manager.close(**kwargs) diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index c481bf848b9..606ed5251ac 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -1,17 +1,18 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools +from __future__ import absolute_import, division, print_function import numpy as np from .. 
import Variable -from ..core.pycompat import OrderedDict -from ..core.utils import (FrozenOrderedDict, Frozen) from ..core import indexing +from ..core.pycompat import OrderedDict +from ..core.utils import Frozen, FrozenOrderedDict +from .common import AbstractDataStore, BackendArray +from .file_manager import CachingFileManager +from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock + -from .common import AbstractDataStore, DataStorePickleMixin, BackendArray +# psuedonetcdf can invoke netCDF libraries internally +PNETCDF_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK]) class PncArrayWrapper(BackendArray): @@ -24,69 +25,63 @@ def __init__(self, variable_name, datastore): self.dtype = np.dtype(array.dtype) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] def __getitem__(self, key): - key, np_inds = indexing.decompose_indexer( - key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR) + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, + self._getitem) - with self.datastore.ensure_open(autoclose=True): - array = self.get_array()[key.tuple] # index backend array - - if len(np_inds.tuple) > 0: - # index the loaded np.ndarray - array = indexing.NumpyIndexingAdapter(array)[np_inds] - return array + def _getitem(self, key): + array = self.get_array() + with self.datastore.lock: + return array[key] -class PseudoNetCDFDataStore(AbstractDataStore, DataStorePickleMixin): +class PseudoNetCDFDataStore(AbstractDataStore): """Store for accessing datasets via PseudoNetCDF """ @classmethod - def open(cls, filename, format=None, writer=None, - autoclose=False, **format_kwds): + def open(cls, filename, lock=None, **format_kwds): from PseudoNetCDF import pncopen - opener = functools.partial(pncopen, filename, **format_kwds) - ds = opener() - mode = format_kwds.get('mode', 'r') - return cls(ds, mode=mode, writer=writer, opener=opener, - autoclose=autoclose) - def __init__(self, pnc_dataset, mode='r', writer=None, opener=None, - autoclose=False): + keywords = dict(kwargs=format_kwds) + # only include mode if explicitly passed + mode = format_kwds.pop('mode', None) + if mode is not None: + keywords['mode'] = mode + + if lock is None: + lock = PNETCDF_LOCK + + manager = CachingFileManager(pncopen, filename, lock=lock, **keywords) + return cls(manager, lock) - if autoclose and opener is None: - raise ValueError('autoclose requires an opener') + def __init__(self, manager, lock=None): + self._manager = manager + self.lock = ensure_lock(lock) - self._ds = pnc_dataset - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode - super(PseudoNetCDFDataStore, self).__init__() + @property + def ds(self): + return self._manager.acquire() def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - data = indexing.LazilyOuterIndexedArray( - PncArrayWrapper(name, self) - ) + data = indexing.LazilyOuterIndexedArray( + PncArrayWrapper(name, self) + ) attrs = OrderedDict((k, getattr(var, k)) for k in var.ncattrs()) return Variable(var.dimensions, data, attrs) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(dict([(k, getattr(self.ds, k)) - for k 
in self.ds.ncattrs()])) + return Frozen(dict([(k, getattr(self.ds, k)) + for k in self.ds.ncattrs()])) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -96,6 +91,4 @@ def get_encoding(self): return encoding def close(self): - if self._isopen: - self.ds.close() - self._isopen = False + self._manager.close() diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index 4a932e3dad2..71ea4841b71 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -22,22 +22,20 @@ def dtype(self): return self.array.dtype def __getitem__(self, key): - key, np_inds = indexing.decompose_indexer( - key, self.shape, indexing.IndexingSupport.BASIC) + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.BASIC, self._getitem) + def _getitem(self, key): # pull the data from the array attribute if possible, to avoid # downloading coordinate data twice array = getattr(self.array, 'array', self.array) - result = robust_getitem(array, key.tuple, catch=ValueError) + result = robust_getitem(array, key, catch=ValueError) # pydap doesn't squeeze axes automatically like numpy - axis = tuple(n for n, k in enumerate(key.tuple) + axis = tuple(n for n, k in enumerate(key) if isinstance(k, integer_types)) if len(axis) > 0: result = np.squeeze(result, axis) - if len(np_inds.tuple) > 0: - result = indexing.NumpyIndexingAdapter(np.asarray(result))[np_inds] - return result diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 3c638b6b057..574fff744e3 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -1,13 +1,20 @@ from __future__ import absolute_import, division, print_function -import functools - import numpy as np from .. import Variable from ..core import indexing from ..core.utils import Frozen, FrozenOrderedDict -from .common import AbstractDataStore, BackendArray, DataStorePickleMixin +from .common import AbstractDataStore, BackendArray +from .file_manager import CachingFileManager +from .locks import ( + HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, SerializableLock) + + +# PyNIO can invoke netCDF libraries internally +# Add a dedicated lock just in case NCL as well isn't thread-safe. 
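The pydap and pynio wrappers here follow the same refactoring as the netCDF4 and h5netcdf ones: __getitem__ delegates to indexing.explicit_indexing_adapter, which splits an indexer into the part the backend natively supports plus a numpy fix-up applied afterwards, and the backend only implements _getitem. A toy wrapper, purely for illustration:

    import numpy as np
    from xarray.backends.common import BackendArray
    from xarray.core import indexing

    class ToyArrayWrapper(BackendArray):
        def __init__(self, array):
            self._array = array
            self.shape = array.shape
            self.dtype = array.dtype

        def _getitem(self, key):
            # `key` is a plain tuple that only uses basic (int/slice) indexing
            return self._array[key]

        def __getitem__(self, key):
            return indexing.explicit_indexing_adapter(
                key, self.shape, indexing.IndexingSupport.BASIC, self._getitem)

    wrapped = ToyArrayWrapper(np.arange(12).reshape(3, 4))
    outer = indexing.OuterIndexer((np.array([0, 2]), slice(None)))
    assert wrapped[outer].shape == (2, 4)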
+NCL_LOCK = SerializableLock() +PYNIO_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK, NCL_LOCK]) class NioArrayWrapper(BackendArray): @@ -20,57 +27,52 @@ def __init__(self, variable_name, datastore): self.dtype = np.dtype(array.typecode()) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] def __getitem__(self, key): - key, np_inds = indexing.decompose_indexer( - key, self.shape, indexing.IndexingSupport.BASIC) + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.BASIC, self._getitem) - with self.datastore.ensure_open(autoclose=True): - array = self.get_array() - if key.tuple == () and self.ndim == 0: + def _getitem(self, key): + array = self.get_array() + with self.datastore.lock: + if key == () and self.ndim == 0: return array.get_value() + return array[key] - array = array[key.tuple] - if len(np_inds.tuple) > 0: - array = indexing.NumpyIndexingAdapter(array)[np_inds] - return array - - -class NioDataStore(AbstractDataStore, DataStorePickleMixin): +class NioDataStore(AbstractDataStore): """Store for accessing datasets via PyNIO """ - def __init__(self, filename, mode='r', autoclose=False): + def __init__(self, filename, mode='r', lock=None): import Nio - opener = functools.partial(Nio.open_file, filename, mode=mode) - self._ds = opener() - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode + if lock is None: + lock = PYNIO_LOCK + self.lock = ensure_lock(lock) + self._manager = CachingFileManager( + Nio.open_file, filename, lock=lock, mode=mode) # xarray provides its own support for FillValue, # so turn off PyNIO's support for the same. self.ds.set_option('MaskedArrayMode', 'MaskedNever') + @property + def ds(self): + return self._manager.acquire() + def open_store_variable(self, name, var): data = indexing.LazilyOuterIndexedArray(NioArrayWrapper(name, self)) return Variable(var.dimensions, data, var.attributes) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.attributes) + return Frozen(self.ds.attributes) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -79,6 +81,4 @@ def get_encoding(self): return encoding def close(self): - if self._isopen: - self.ds.close() - self._isopen = False + self._manager.close() diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 0f19a1b51be..7a343a6529e 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -1,21 +1,20 @@ import os +import warnings from collections import OrderedDict from distutils.version import LooseVersion -import warnings import numpy as np from .. import DataArray from ..core import indexing from ..core.utils import is_scalar -from .common import BackendArray, PickleByReconstructionWrapper +from .common import BackendArray +from .file_manager import CachingFileManager +from .locks import SerializableLock -try: - from dask.utils import SerializableLock as Lock -except ImportError: - from threading import Lock -RASTERIO_LOCK = Lock() +# TODO: should this be GDAL_LOCK instead? 
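NioDataStore above, and the rasterio changes below, repeat the structural pattern this PR introduces everywhere: hold a CachingFileManager plus a lock, expose the handle through a ds property that re-acquires it, and forward close() to the manager. A minimal, hypothetical skeleton (MinimalStore and its arguments are inventions for illustration; real stores also implement get_variables(), get_attrs() and get_dimensions()):

    from xarray.backends.common import AbstractDataStore
    from xarray.backends.file_manager import CachingFileManager
    from xarray.backends.locks import SerializableLock, ensure_lock

    class MinimalStore(AbstractDataStore):
        def __init__(self, opener, filename, lock=None):
            if lock is None:
                lock = SerializableLock()
            self.lock = ensure_lock(lock)
            self._manager = CachingFileManager(opener, filename, mode='r')

        @property
        def ds(self):
            # re-acquired on every access, so a handle evicted from the LRU
            # cache is transparently reopened
            return self._manager.acquire()

        def close(self):
            self._manager.close()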
+RASTERIO_LOCK = SerializableLock() _ERROR_MSG = ('The kind of indexing operation you are trying to do is not ' 'valid on rasterio files. Try to load your data with ds.load()' @@ -25,18 +24,22 @@ class RasterioArrayWrapper(BackendArray): """A wrapper around rasterio dataset objects""" - def __init__(self, riods): - self.riods = riods - self._shape = (riods.value.count, riods.value.height, - riods.value.width) - self._ndims = len(self.shape) + def __init__(self, manager): + self.manager = manager - @property - def dtype(self): - dtypes = self.riods.value.dtypes + # cannot save riods as an attribute: this would break pickleability + riods = manager.acquire() + + self._shape = (riods.count, riods.height, riods.width) + + dtypes = riods.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError('All bands should have the same dtype') - return np.dtype(dtypes[0]) + self._dtype = np.dtype(dtypes[0]) + + @property + def dtype(self): + return self._dtype @property def shape(self): @@ -47,7 +50,7 @@ def _get_indexer(self, key): Parameter --------- - key: ExplicitIndexer + key: tuple of int Returns ------- @@ -60,13 +63,11 @@ def _get_indexer(self, key): -------- indexing.decompose_indexer """ - key, np_inds = indexing.decompose_indexer( - key, self.shape, indexing.IndexingSupport.OUTER) + assert len(key) == 3, 'rasterio datasets should always be 3D' # bands cannot be windowed but they can be listed - band_key = key.tuple[0] - new_shape = [] - np_inds2 = [] + band_key = key[0] + np_inds = [] # bands (axis=0) cannot be windowed but they can be listed if isinstance(band_key, slice): start, stop, step = band_key.indices(self.shape[0]) @@ -74,18 +75,16 @@ def _get_indexer(self, key): # be sure we give out a list band_key = (np.asarray(band_key) + 1).tolist() if isinstance(band_key, list): # if band_key is not a scalar - new_shape.append(len(band_key)) - np_inds2.append(slice(None)) + np_inds.append(slice(None)) # but other dims can only be windowed window = [] squeeze_axis = [] - for i, (k, n) in enumerate(zip(key.tuple[1:], self.shape[1:])): + for i, (k, n) in enumerate(zip(key[1:], self.shape[1:])): if isinstance(k, slice): # step is always positive. 
see indexing.decompose_indexer start, stop, step = k.indices(n) - np_inds2.append(slice(None, None, step)) - new_shape.append(stop - start) + np_inds.append(slice(None, None, step)) elif is_scalar(k): # windowed operations will always return an array # we will have to squeeze it later @@ -94,21 +93,34 @@ def _get_indexer(self, key): stop = k + 1 else: start, stop = np.min(k), np.max(k) + 1 - np_inds2.append(k - start) - new_shape.append(stop - start) + np_inds.append(k - start) window.append((start, stop)) - np_inds = indexing._combine_indexers( - indexing.OuterIndexer(tuple(np_inds2)), new_shape, np_inds) - return band_key, window, tuple(squeeze_axis), np_inds + if isinstance(key[1], np.ndarray) and isinstance(key[2], np.ndarray): + # do outer-style indexing + np_inds[-2:] = np.ix_(*np_inds[-2:]) - def __getitem__(self, key): + return band_key, tuple(window), tuple(squeeze_axis), tuple(np_inds) + + def _getitem(self, key): band_key, window, squeeze_axis, np_inds = self._get_indexer(key) - out = self.riods.value.read(band_key, window=tuple(window)) + if not band_key or any(start == stop for (start, stop) in window): + # no need to do IO + shape = (len(band_key),) + tuple( + stop - start for (start, stop) in window) + out = np.zeros(shape, dtype=self.dtype) + else: + riods = self.manager.acquire() + out = riods.read(band_key, window=window) + if squeeze_axis: out = np.squeeze(out, axis=squeeze_axis) - return indexing.NumpyIndexingAdapter(out)[np_inds] + return out[np_inds] + + def __getitem__(self, key): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER, self._getitem) def _parse_envi(meta): @@ -157,7 +169,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, from affine import Affine da = xr.open_rasterio('path_to_file.tif') - transform = Affine(*da.attrs['transform']) + transform = Affine.from_gdal(*da.attrs['transform']) nx, ny = da.sizes['x'], da.sizes['y'] x, y = np.meshgrid(np.arange(nx)+0.5, np.arange(ny)+0.5) * transform @@ -195,7 +207,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, import rasterio - riods = PickleByReconstructionWrapper(rasterio.open, filename, mode='r') + manager = CachingFileManager(rasterio.open, filename, mode='r') + riods = manager.acquire() if cache is None: cache = chunks is None @@ -203,20 +216,20 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, coords = OrderedDict() # Get bands - if riods.value.count < 1: + if riods.count < 1: raise ValueError('Unknown dims') - coords['band'] = np.asarray(riods.value.indexes) + coords['band'] = np.asarray(riods.indexes) # Get coordinates if LooseVersion(rasterio.__version__) < '1.0': - transform = riods.value.affine + transform = riods.affine else: - transform = riods.value.transform + transform = riods.transform if transform.is_rectilinear: # 1d coordinates parse = True if parse_coordinates is None else parse_coordinates if parse: - nx, ny = riods.value.width, riods.value.height + nx, ny = riods.width, riods.height # xarray coordinates are pixel centered x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform @@ -226,64 +239,60 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # 2d coordinates parse = False if (parse_coordinates is None) else parse_coordinates if parse: - warnings.warn("The file coordinates' transformation isn't " - "rectilinear: xarray won't parse the coordinates " - "in 
this case. Set `parse_coordinates=False` to " - "suppress this warning.", - RuntimeWarning, stacklevel=3) + warnings.warn( + "The file coordinates' transformation isn't " + "rectilinear: xarray won't parse the coordinates " + "in this case. Set `parse_coordinates=False` to " + "suppress this warning.", + RuntimeWarning, stacklevel=3) # Attributes attrs = dict() # Affine transformation matrix (always available) # This describes coefficients mapping pixel coordinates to CRS # For serialization store as tuple of 6 floats, the last row being - # always (0, 0, 1) per definition (see https://github.com/sgillies/affine) + # always (0, 0, 1) per definition (see + # https://github.com/sgillies/affine) attrs['transform'] = tuple(transform)[:6] - if hasattr(riods.value, 'crs') and riods.value.crs: + if hasattr(riods, 'crs') and riods.crs: # CRS is a dict-like object specific to rasterio # If CRS is not None, we convert it back to a PROJ4 string using # rasterio itself - attrs['crs'] = riods.value.crs.to_string() - if hasattr(riods.value, 'res'): + attrs['crs'] = riods.crs.to_string() + if hasattr(riods, 'res'): # (width, height) tuple of pixels in units of CRS - attrs['res'] = riods.value.res - if hasattr(riods.value, 'is_tiled'): + attrs['res'] = riods.res + if hasattr(riods, 'is_tiled'): # Is the TIF tiled? (bool) # We cast it to an int for netCDF compatibility - attrs['is_tiled'] = np.uint8(riods.value.is_tiled) - with warnings.catch_warnings(): - # casting riods.value.transform to a tuple makes this future proof - warnings.simplefilter('ignore', FutureWarning) - if hasattr(riods.value, 'transform'): - # Affine transformation matrix (tuple of floats) - # Describes coefficients mapping pixel coordinates to CRS - attrs['transform'] = tuple(riods.value.transform) - if hasattr(riods.value, 'nodatavals'): + attrs['is_tiled'] = np.uint8(riods.is_tiled) + if hasattr(riods, 'nodatavals'): # The nodata values for the raster bands - attrs['nodatavals'] = tuple([np.nan if nodataval is None else nodataval - for nodataval in riods.value.nodatavals]) + attrs['nodatavals'] = tuple( + np.nan if nodataval is None else nodataval + for nodataval in riods.nodatavals) # Parse extra metadata from tags, if supported parsers = {'ENVI': _parse_envi} - driver = riods.value.driver + driver = riods.driver if driver in parsers: - meta = parsers[driver](riods.value.tags(ns=driver)) + meta = parsers[driver](riods.tags(ns=driver)) for k, v in meta.items(): # Add values as coordinates if they match the band count, # as attributes otherwise if (isinstance(v, (list, np.ndarray)) and - len(v) == riods.value.count): + len(v) == riods.count): coords[k] = ('band', np.asarray(v)) else: attrs[k] = v - data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(riods)) + data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager)) # this lets you write arrays loaded with rasterio data = indexing.CopyOnWriteArray(data) - if cache and (chunks is None): + if cache and chunks is None: data = indexing.MemoryCachedArray(data) result = DataArray(data=data, dims=('band', 'y', 'x'), @@ -305,6 +314,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, lock=lock) # Make the file closeable - result._file_obj = riods + result._file_obj = manager return result diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index cd84431f6b7..b009342efb6 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -1,6 +1,5 @@ from __future__ import absolute_import, division, print_function -import 
functools import warnings from distutils.version import LooseVersion from io import BytesIO @@ -11,7 +10,9 @@ from ..core.indexing import NumpyIndexingAdapter from ..core.pycompat import OrderedDict, basestring, iteritems from ..core.utils import Frozen, FrozenOrderedDict -from .common import BackendArray, DataStorePickleMixin, WritableCFDataStore +from .common import BackendArray, WritableCFDataStore +from .locks import get_write_lock +from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import ( encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name) @@ -40,31 +41,26 @@ def __init__(self, variable_name, datastore): str(array.dtype.itemsize)) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name].data def __getitem__(self, key): - with self.datastore.ensure_open(autoclose=True): - data = NumpyIndexingAdapter(self.get_array())[key] - # Copy data if the source file is mmapped. - # This makes things consistent - # with the netCDF4 library by ensuring - # we can safely read arrays even - # after closing associated files. - copy = self.datastore.ds.use_mmap - return np.array(data, dtype=self.dtype, copy=copy) + data = NumpyIndexingAdapter(self.get_array())[key] + # Copy data if the source file is mmapped. This makes things consistent + # with the netCDF4 library by ensuring we can safely read arrays even + # after closing associated files. + copy = self.datastore.ds.use_mmap + return np.array(data, dtype=self.dtype, copy=copy) def __setitem__(self, key, value): - with self.datastore.ensure_open(autoclose=True): - data = self.datastore.ds.variables[self.variable_name] - try: - data[key] = value - except TypeError: - if key is Ellipsis: - # workaround for GH: scipy/scipy#6880 - data[:] = value - else: - raise + data = self.datastore.ds.variables[self.variable_name] + try: + data[key] = value + except TypeError: + if key is Ellipsis: + # workaround for GH: scipy/scipy#6880 + data[:] = value + else: + raise def _open_scipy_netcdf(filename, mode, mmap, version): @@ -106,7 +102,7 @@ def _open_scipy_netcdf(filename, mode, mmap, version): raise -class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): +class ScipyDataStore(WritableCFDataStore): """Store for reading and writing data via scipy.io.netcdf. 
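ScipyDataStore (continued below) is also where DummyFileManager, defined earlier in file_manager.py, earns its keep: objects that are opened eagerly, such as a scipy dataset built from an in-memory buffer or an existing netCDF4.Dataset, are wrapped so they present the same FileManager interface as a cached file. A small sketch using a BytesIO as a stand-in for any object with a close() method:

    from io import BytesIO
    from xarray.backends.file_manager import DummyFileManager

    buffer = BytesIO(b'not really a netCDF file')
    manager = DummyFileManager(buffer)
    assert manager.acquire() is buffer   # no caching, no reopening
    manager.close()                      # simply calls buffer.close()
    assert buffer.closed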
This store has the advantage of being able to be initialized with a @@ -116,7 +112,7 @@ class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): """ def __init__(self, filename_or_obj, mode='r', format=None, group=None, - writer=None, mmap=None, autoclose=False, lock=None): + mmap=None, lock=None): import scipy import scipy.io @@ -140,34 +136,38 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) - opener = functools.partial(_open_scipy_netcdf, - filename=filename_or_obj, - mode=mode, mmap=mmap, version=version) - self._ds = opener() - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode + if (lock is None and mode != 'r' and + isinstance(filename_or_obj, basestring)): + lock = get_write_lock(filename_or_obj) + + if isinstance(filename_or_obj, basestring): + manager = CachingFileManager( + _open_scipy_netcdf, filename_or_obj, mode=mode, lock=lock, + kwargs=dict(mmap=mmap, version=version)) + else: + scipy_dataset = _open_scipy_netcdf( + filename_or_obj, mode=mode, mmap=mmap, version=version) + manager = DummyFileManager(scipy_dataset) + + self._manager = manager - super(ScipyDataStore, self).__init__(writer, lock=lock) + @property + def ds(self): + return self._manager.acquire() def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - return Variable(var.dimensions, ScipyArrayWrapper(name, self), - _decode_attrs(var._attributes)) + return Variable(var.dimensions, ScipyArrayWrapper(name, self), + _decode_attrs(var._attributes)) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(_decode_attrs(self.ds._attributes)) + return Frozen(_decode_attrs(self.ds._attributes)) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -176,22 +176,20 @@ def get_encoding(self): return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - if name in self.ds.dimensions: - raise ValueError('%s does not support modifying dimensions' - % type(self).__name__) - dim_length = length if not is_unlimited else None - self.ds.createDimension(name, dim_length) + if name in self.ds.dimensions: + raise ValueError('%s does not support modifying dimensions' + % type(self).__name__) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, dim_length) def _validate_attr_key(self, key): if not is_valid_nc3_name(key): raise ValueError("Not a valid attribute name") def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - self._validate_attr_key(key) - value = encode_nc3_attr_value(value) - setattr(self.ds, key, value) + self._validate_attr_key(key) + value = encode_nc3_attr_value(value) + setattr(self.ds, key, value) def encode_variable(self, variable): variable = encode_nc3_variable(variable) @@ -219,27 +217,8 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, data - def sync(self, compute=True): - if not compute: - raise NotImplementedError( - 'compute=False is not supported for the scipy 
backend yet') - with self.ensure_open(autoclose=True): - super(ScipyDataStore, self).sync(compute=compute) - self.ds.flush() + def sync(self): + self.ds.sync() def close(self): - self.ds.close() - self._isopen = False - - def __exit__(self, type, value, tb): - self.close() - - def __setstate__(self, state): - filename = state['_opener'].keywords['filename'] - if hasattr(filename, 'seek'): - # it's a file-like object - # seek to the start of the file so scipy can read it - filename.seek(0) - super(ScipyDataStore, self).__setstate__(state) - self._ds = None - self._isopen = False + self._manager.close() diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index c5043ce8a47..06fe7f04e4f 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -1,6 +1,5 @@ from __future__ import absolute_import, division, print_function -from itertools import product from distutils.version import LooseVersion import numpy as np @@ -80,14 +79,14 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): if var_chunks and enc_chunks is None: if any(len(set(chunks[:-1])) > 1 for chunks in var_chunks): raise ValueError( - "Zarr requires uniform chunk sizes excpet for final chunk." - " Variable %r has incompatible chunks. Consider " + "Zarr requires uniform chunk sizes except for final chunk." + " Variable dask chunks %r are incompatible. Consider " "rechunking using `chunk()`." % (var_chunks,)) if any((chunks[0] < chunks[-1]) for chunks in var_chunks): raise ValueError( - "Final chunk of Zarr array must be smaller than first. " - "Variable %r has incompatible chunks. Consider rechunking " - "using `chunk()`." % var_chunks) + "Final chunk of Zarr array must be the same size or smaller " + "than the first. Variable Dask chunks %r are incompatible. " + "Consider rechunking using `chunk()`." % var_chunks) # return the first chunk for each dimension return tuple(chunk[0] for chunk in var_chunks) @@ -103,9 +102,8 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): enc_chunks_tuple = tuple(enc_chunks) if len(enc_chunks_tuple) != ndim: - raise ValueError("zarr chunks tuple %r must have same length as " - "variable.ndim %g" % - (enc_chunks_tuple, ndim)) + # throw away encoding chunks, start over + return _determine_zarr_chunks(None, var_chunks, ndim) for x in enc_chunks_tuple: if not isinstance(x, int): @@ -128,7 +126,7 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): # threads if var_chunks and enc_chunks_tuple: for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks): - for dchunk in dchunks: + for dchunk in dchunks[:-1]: if dchunk % zchunk: raise NotImplementedError( "Specified zarr chunks %r would overlap multiple dask " @@ -136,6 +134,13 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): " Consider rechunking the data using " "`chunk()` or specifying different chunks in encoding." % (enc_chunks_tuple, var_chunks)) + if dchunks[-1] > zchunk: + raise ValueError( + "Final chunk of Zarr array must be the same size or " + "smaller than the first. The specified Zarr chunk " + "encoding is %r, but %r in variable Dask chunks %r is " + "incompatible. Consider rechunking using `chunk()`." 
+ % (enc_chunks_tuple, dchunks, var_chunks)) return enc_chunks_tuple raise AssertionError( @@ -219,8 +224,7 @@ class ZarrStore(AbstractWritableDataStore): """ @classmethod - def open_group(cls, store, mode='r', synchronizer=None, group=None, - writer=None): + def open_group(cls, store, mode='r', synchronizer=None, group=None): import zarr min_zarr = '2.2' @@ -232,24 +236,14 @@ def open_group(cls, store, mode='r', synchronizer=None, group=None, "#installation" % min_zarr) zarr_group = zarr.open_group(store=store, mode=mode, synchronizer=synchronizer, path=group) - return cls(zarr_group, writer=writer) + return cls(zarr_group) - def __init__(self, zarr_group, writer=None): + def __init__(self, zarr_group): self.ds = zarr_group self._read_only = self.ds.read_only self._synchronizer = self.ds.synchronizer self._group = self.ds.path - if writer is None: - # by default, we should not need a lock for writing zarr because - # we do not (yet) allow overlapping chunks during write - zarr_writer = ArrayWriter(lock=False) - else: - zarr_writer = writer - - # do we need to define attributes for all of the opener keyword args? - super(ZarrStore, self).__init__(zarr_writer) - def open_store_variable(self, name, zarr_array): data = indexing.LazilyOuterIndexedArray(ZarrArrayWrapper(name, self)) dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array, @@ -336,8 +330,8 @@ def store(self, variables, attributes, *args, **kwargs): AbstractWritableDataStore.store(self, variables, attributes, *args, **kwargs) - def sync(self, compute=True): - self.delayed_store = self.writer.sync(compute=compute) + def sync(self): + pass def open_zarr(store, group=None, synchronizer=None, auto_chunk=True, diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py new file mode 100644 index 00000000000..83e8c7a7e4b --- /dev/null +++ b/xarray/coding/cftime_offsets.py @@ -0,0 +1,735 @@ +"""Time offset classes for use with cftime.datetime objects""" +# The offset classes and mechanisms for generating time ranges defined in +# this module were copied/adapted from those defined in pandas. See in +# particular the objects and methods defined in pandas.tseries.offsets +# and pandas.core.indexes.datetimes. + +# For reference, here is a copy of the pandas copyright notice: + +# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team +# All rights reserved. + +# Copyright (c) 2008-2011 AQR Capital Management, LLC +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. + +# * Neither the name of the copyright holder nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import re +from datetime import timedelta +from functools import partial + +import numpy as np + +from ..core.pycompat import basestring +from .cftimeindex import CFTimeIndex, _parse_iso8601_with_reso +from .times import format_cftime_datetime + + +def get_date_type(calendar): + """Return the cftime date type for a given calendar name.""" + try: + import cftime + except ImportError: + raise ImportError( + 'cftime is required for dates with non-standard calendars') + else: + calendars = { + 'noleap': cftime.DatetimeNoLeap, + '360_day': cftime.Datetime360Day, + '365_day': cftime.DatetimeNoLeap, + '366_day': cftime.DatetimeAllLeap, + 'gregorian': cftime.DatetimeGregorian, + 'proleptic_gregorian': cftime.DatetimeProlepticGregorian, + 'julian': cftime.DatetimeJulian, + 'all_leap': cftime.DatetimeAllLeap, + 'standard': cftime.DatetimeProlepticGregorian + } + return calendars[calendar] + + +class BaseCFTimeOffset(object): + _freq = None + + def __init__(self, n=1): + if not isinstance(n, int): + raise TypeError( + "The provided multiple 'n' must be an integer. " + "Instead a value of type {!r} was provided.".format(type(n))) + self.n = n + + def rule_code(self): + return self._freq + + def __eq__(self, other): + return self.n == other.n and self.rule_code() == other.rule_code() + + def __ne__(self, other): + return not self == other + + def __add__(self, other): + return self.__apply__(other) + + def __sub__(self, other): + import cftime + + if isinstance(other, cftime.datetime): + raise TypeError('Cannot subtract a cftime.datetime ' + 'from a time offset.') + elif type(other) == type(self): + return type(self)(self.n - other.n) + else: + return NotImplemented + + def __mul__(self, other): + return type(self)(n=other * self.n) + + def __neg__(self): + return self * -1 + + def __rmul__(self, other): + return self.__mul__(other) + + def __radd__(self, other): + return self.__add__(other) + + def __rsub__(self, other): + if isinstance(other, BaseCFTimeOffset) and type(self) != type(other): + raise TypeError('Cannot subtract cftime offsets of differing ' + 'types') + return -self + other + + def __apply__(self): + return NotImplemented + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + test_date = (self + date) - self + return date == test_date + + def rollforward(self, date): + if self.onOffset(date): + return date + else: + return date + type(self)() + + def rollback(self, date): + if self.onOffset(date): + return date + else: + return date - type(self)() + + def __str__(self): + return '<{}: n={}>'.format(type(self).__name__, self.n) + + def __repr__(self): + return str(self) + + +def _days_in_month(date): + """The number of days in the month of the given date""" + if date.month == 12: + reference = type(date)(date.year + 1, 1, 1) + else: + reference = type(date)(date.year, date.month + 1, 1) + return (reference - timedelta(days=1)).day + + +def 
_adjust_n_months(other_day, n, reference_day): + """Adjust the number of times a monthly offset is applied based + on the day of a given date, and the reference day provided. + """ + if n > 0 and other_day < reference_day: + n = n - 1 + elif n <= 0 and other_day > reference_day: + n = n + 1 + return n + + +def _adjust_n_years(other, n, month, reference_day): + """Adjust the number of times an annual offset is applied based on + another date, and the reference day provided""" + if n > 0: + if other.month < month or (other.month == month and + other.day < reference_day): + n -= 1 + else: + if other.month > month or (other.month == month and + other.day > reference_day): + n += 1 + return n + + +def _shift_months(date, months, day_option='start'): + """Shift the date to a month start or end a given number of months away. + """ + delta_year = (date.month + months) // 12 + month = (date.month + months) % 12 + + if month == 0: + month = 12 + delta_year = delta_year - 1 + year = date.year + delta_year + + if day_option == 'start': + day = 1 + elif day_option == 'end': + reference = type(date)(year, month, 1) + day = _days_in_month(reference) + else: + raise ValueError(day_option) + return date.replace(year=year, month=month, day=day) + + +class MonthBegin(BaseCFTimeOffset): + _freq = 'MS' + + def __apply__(self, other): + n = _adjust_n_months(other.day, self.n, 1) + return _shift_months(other, n, 'start') + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == 1 + + +class MonthEnd(BaseCFTimeOffset): + _freq = 'M' + + def __apply__(self, other): + n = _adjust_n_months(other.day, self.n, _days_in_month(other)) + return _shift_months(other, n, 'end') + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == _days_in_month(date) + + +_MONTH_ABBREVIATIONS = { + 1: 'JAN', + 2: 'FEB', + 3: 'MAR', + 4: 'APR', + 5: 'MAY', + 6: 'JUN', + 7: 'JUL', + 8: 'AUG', + 9: 'SEP', + 10: 'OCT', + 11: 'NOV', + 12: 'DEC' +} + + +class YearOffset(BaseCFTimeOffset): + _freq = None + _day_option = None + _default_month = None + + def __init__(self, n=1, month=None): + BaseCFTimeOffset.__init__(self, n) + if month is None: + self.month = self._default_month + else: + self.month = month + if not isinstance(self.month, int): + raise TypeError("'self.month' must be an integer value between 1 " + "and 12. Instead, it was set to a value of " + "{!r}".format(self.month)) + elif not (1 <= self.month <= 12): + raise ValueError("'self.month' must be an integer value between 1 " + "and 12. 
Instead, it was set to a value of " + "{!r}".format(self.month)) + + def __apply__(self, other): + if self._day_option == 'start': + reference_day = 1 + elif self._day_option == 'end': + reference_day = _days_in_month(other) + else: + raise ValueError(self._day_option) + years = _adjust_n_years(other, self.n, self.month, reference_day) + months = years * 12 + (self.month - other.month) + return _shift_months(other, months, self._day_option) + + def __sub__(self, other): + import cftime + + if isinstance(other, cftime.datetime): + raise TypeError('Cannot subtract cftime.datetime from offset.') + elif type(other) == type(self) and other.month == self.month: + return type(self)(self.n - other.n, month=self.month) + else: + return NotImplemented + + def __mul__(self, other): + return type(self)(n=other * self.n, month=self.month) + + def rule_code(self): + return '{}-{}'.format(self._freq, _MONTH_ABBREVIATIONS[self.month]) + + def __str__(self): + return '<{}: n={}, month={}>'.format( + type(self).__name__, self.n, self.month) + + +class YearBegin(YearOffset): + _freq = 'AS' + _day_option = 'start' + _default_month = 1 + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == 1 and date.month == self.month + + def rollforward(self, date): + """Roll date forward to nearest start of year""" + if self.onOffset(date): + return date + else: + return date + YearBegin(month=self.month) + + def rollback(self, date): + """Roll date backward to nearest start of year""" + if self.onOffset(date): + return date + else: + return date - YearBegin(month=self.month) + + +class YearEnd(YearOffset): + _freq = 'A' + _day_option = 'end' + _default_month = 12 + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == _days_in_month(date) and date.month == self.month + + def rollforward(self, date): + """Roll date forward to nearest end of year""" + if self.onOffset(date): + return date + else: + return date + YearEnd(month=self.month) + + def rollback(self, date): + """Roll date backward to nearest end of year""" + if self.onOffset(date): + return date + else: + return date - YearEnd(month=self.month) + + +class Day(BaseCFTimeOffset): + _freq = 'D' + + def __apply__(self, other): + return other + timedelta(days=self.n) + + +class Hour(BaseCFTimeOffset): + _freq = 'H' + + def __apply__(self, other): + return other + timedelta(hours=self.n) + + +class Minute(BaseCFTimeOffset): + _freq = 'T' + + def __apply__(self, other): + return other + timedelta(minutes=self.n) + + +class Second(BaseCFTimeOffset): + _freq = 'S' + + def __apply__(self, other): + return other + timedelta(seconds=self.n) + + +_FREQUENCIES = { + 'A': YearEnd, + 'AS': YearBegin, + 'Y': YearEnd, + 'YS': YearBegin, + 'M': MonthEnd, + 'MS': MonthBegin, + 'D': Day, + 'H': Hour, + 'T': Minute, + 'min': Minute, + 'S': Second, + 'AS-JAN': partial(YearBegin, month=1), + 'AS-FEB': partial(YearBegin, month=2), + 'AS-MAR': partial(YearBegin, month=3), + 'AS-APR': partial(YearBegin, month=4), + 'AS-MAY': partial(YearBegin, month=5), + 'AS-JUN': partial(YearBegin, month=6), + 'AS-JUL': partial(YearBegin, month=7), + 'AS-AUG': partial(YearBegin, month=8), + 'AS-SEP': partial(YearBegin, month=9), + 'AS-OCT': partial(YearBegin, month=10), + 'AS-NOV': partial(YearBegin, month=11), + 'AS-DEC': partial(YearBegin, month=12), + 'A-JAN': 
partial(YearEnd, month=1), + 'A-FEB': partial(YearEnd, month=2), + 'A-MAR': partial(YearEnd, month=3), + 'A-APR': partial(YearEnd, month=4), + 'A-MAY': partial(YearEnd, month=5), + 'A-JUN': partial(YearEnd, month=6), + 'A-JUL': partial(YearEnd, month=7), + 'A-AUG': partial(YearEnd, month=8), + 'A-SEP': partial(YearEnd, month=9), + 'A-OCT': partial(YearEnd, month=10), + 'A-NOV': partial(YearEnd, month=11), + 'A-DEC': partial(YearEnd, month=12) +} + + +_FREQUENCY_CONDITION = '|'.join(_FREQUENCIES.keys()) +_PATTERN = '^((?P<multiple>\d+)|())(?P<freq>({0}))$'.format( + _FREQUENCY_CONDITION) + + +def to_offset(freq): + """Convert a frequency string to the appropriate subclass of + BaseCFTimeOffset.""" + if isinstance(freq, BaseCFTimeOffset): + return freq + else: + try: + freq_data = re.match(_PATTERN, freq).groupdict() + except AttributeError: + raise ValueError('Invalid frequency string provided') + + freq = freq_data['freq'] + multiples = freq_data['multiple'] + if multiples is None: + multiples = 1 + else: + multiples = int(multiples) + + return _FREQUENCIES[freq](n=multiples) + + +def to_cftime_datetime(date_str_or_date, calendar=None): + import cftime + + if isinstance(date_str_or_date, basestring): + if calendar is None: + raise ValueError( + 'If converting a string to a cftime.datetime object, ' + 'a calendar type must be provided') + date, _ = _parse_iso8601_with_reso(get_date_type(calendar), + date_str_or_date) + return date + elif isinstance(date_str_or_date, cftime.datetime): + return date_str_or_date + else: + raise TypeError("date_str_or_date must be a string or a " + 'subclass of cftime.datetime. Instead got ' + '{!r}.'.format(date_str_or_date)) + + +def normalize_date(date): + """Round datetime down to midnight.""" + return date.replace(hour=0, minute=0, second=0, microsecond=0) + + +def _maybe_normalize_date(date, normalize): + """Round datetime down to midnight if normalize is True.""" + if normalize: + return normalize_date(date) + else: + return date + + +def _generate_linear_range(start, end, periods): + """Generate an equally-spaced sequence of cftime.datetime objects between + and including two dates (whose length equals the number of periods).""" + import cftime + + total_seconds = (end - start).total_seconds() + values = np.linspace(0., total_seconds, periods, endpoint=True) + units = 'seconds since {}'.format(format_cftime_datetime(start)) + calendar = start.calendar + return cftime.num2date(values, units=units, calendar=calendar, + only_use_cftime_datetimes=True) + + +def _generate_range(start, end, periods, offset): + """Generate a regular range of cftime.datetime objects with a + given time offset. + + Adapted from pandas.tseries.offsets.generate_range.
+ + Parameters + ---------- + start : cftime.datetime, or None + Start of range + end : cftime.datetime, or None + End of range + periods : int, or None + Number of elements in the sequence + offset : BaseCFTimeOffset + An offset class designed for working with cftime.datetime objects + + Returns + ------- + A generator object + """ + if start: + start = offset.rollforward(start) + + if end: + end = offset.rollback(end) + + if periods is None and end < start: + end = None + periods = 0 + + if end is None: + end = start + (periods - 1) * offset + + if start is None: + start = end - (periods - 1) * offset + + current = start + if offset.n >= 0: + while current <= end: + yield current + + next_date = current + offset + if next_date <= current: + raise ValueError('Offset {offset} did not increment date' + .format(offset=offset)) + current = next_date + else: + while current >= end: + yield current + + next_date = current + offset + if next_date >= current: + raise ValueError('Offset {offset} did not decrement date' + .format(offset=offset)) + current = next_date + + +def _count_not_none(*args): + """Compute the number of non-None arguments.""" + return sum([arg is not None for arg in args]) + + +def cftime_range(start=None, end=None, periods=None, freq='D', + tz=None, normalize=False, name=None, closed=None, + calendar='standard'): + """Return a fixed frequency CFTimeIndex. + + Parameters + ---------- + start : str or cftime.datetime, optional + Left bound for generating dates. + end : str or cftime.datetime, optional + Right bound for generating dates. + periods : integer, optional + Number of periods to generate. + freq : str, default 'D', BaseCFTimeOffset, or None + Frequency strings can have multiples, e.g. '5H'. + normalize : bool, default False + Normalize start/end dates to midnight before generating date range. + name : str, default None + Name of the resulting index + closed : {None, 'left', 'right'}, optional + Make the interval closed with respect to the given frequency to the + 'left', 'right', or both sides (None, the default). + calendar : str + Calendar type for the datetimes (default 'standard'). + + Returns + ------- + CFTimeIndex + + Notes + ----- + + This function is an analog of ``pandas.date_range`` for use in generating + sequences of ``cftime.datetime`` objects. It supports most of the + features of ``pandas.date_range`` (e.g. specifying how the index is + ``closed`` on either side, or whether or not to ``normalize`` the start and + end bounds); however, there are some notable exceptions: + + - You cannot specify a ``tz`` (time zone) argument. + - Start or end dates specified as partial-datetime strings must use the + `ISO-8601 format `_. + - It supports many, but not all, frequencies supported by + ``pandas.date_range``. For example it does not currently support any of + the business-related, semi-monthly, or sub-second frequencies. + - Compound sub-monthly frequencies are not supported, e.g. '1H1min', as + these can easily be written in terms of the finest common resolution, + e.g. '61min'. + + Valid simple frequency strings for use with ``cftime``-calendars include + any multiples of the following. 
+ + +--------+-----------------------+ + | Alias | Description | + +========+=======================+ + | A, Y | Year-end frequency | + +--------+-----------------------+ + | AS, YS | Year-start frequency | + +--------+-----------------------+ + | M | Month-end frequency | + +--------+-----------------------+ + | MS | Month-start frequency | + +--------+-----------------------+ + | D | Day frequency | + +--------+-----------------------+ + | H | Hour frequency | + +--------+-----------------------+ + | T, min | Minute frequency | + +--------+-----------------------+ + | S | Second frequency | + +--------+-----------------------+ + + Any multiples of the following anchored offsets are also supported. + + +----------+-------------------------------------------------------------------+ + | Alias | Description | + +==========+===================================================================+ + | A(S)-JAN | Annual frequency, anchored at the end (or beginning) of January | + +----------+-------------------------------------------------------------------+ + | A(S)-FEB | Annual frequency, anchored at the end (or beginning) of February | + +----------+-------------------------------------------------------------------+ + | A(S)-MAR | Annual frequency, anchored at the end (or beginning) of March | + +----------+-------------------------------------------------------------------+ + | A(S)-APR | Annual frequency, anchored at the end (or beginning) of April | + +----------+-------------------------------------------------------------------+ + | A(S)-MAY | Annual frequency, anchored at the end (or beginning) of May | + +----------+-------------------------------------------------------------------+ + | A(S)-JUN | Annual frequency, anchored at the end (or beginning) of June | + +----------+-------------------------------------------------------------------+ + | A(S)-JUL | Annual frequency, anchored at the end (or beginning) of July | + +----------+-------------------------------------------------------------------+ + | A(S)-AUG | Annual frequency, anchored at the end (or beginning) of August | + +----------+-------------------------------------------------------------------+ + | A(S)-SEP | Annual frequency, anchored at the end (or beginning) of September | + +----------+-------------------------------------------------------------------+ + | A(S)-OCT | Annual frequency, anchored at the end (or beginning) of October | + +----------+-------------------------------------------------------------------+ + | A(S)-NOV | Annual frequency, anchored at the end (or beginning) of November | + +----------+-------------------------------------------------------------------+ + | A(S)-DEC | Annual frequency, anchored at the end (or beginning) of December | + +----------+-------------------------------------------------------------------+ + + Finally, the following calendar aliases are supported. 
+ + +--------------------------------+---------------------------------------+ + | Alias | Date type | + +================================+=======================================+ + | standard, proleptic_gregorian | ``cftime.DatetimeProlepticGregorian`` | + +--------------------------------+---------------------------------------+ + | gregorian | ``cftime.DatetimeGregorian`` | + +--------------------------------+---------------------------------------+ + | noleap, 365_day | ``cftime.DatetimeNoLeap`` | + +--------------------------------+---------------------------------------+ + | all_leap, 366_day | ``cftime.DatetimeAllLeap`` | + +--------------------------------+---------------------------------------+ + | 360_day | ``cftime.Datetime360Day`` | + +--------------------------------+---------------------------------------+ + | julian | ``cftime.DatetimeJulian`` | + +--------------------------------+---------------------------------------+ + + Examples + -------- + + This function returns a ``CFTimeIndex``, populated with ``cftime.datetime`` + objects associated with the specified calendar type, e.g. + + >>> xr.cftime_range(start='2000', periods=6, freq='2MS', calendar='noleap') + CFTimeIndex([2000-01-01 00:00:00, 2000-03-01 00:00:00, 2000-05-01 00:00:00, + 2000-07-01 00:00:00, 2000-09-01 00:00:00, 2000-11-01 00:00:00], + dtype='object') + + As in the standard pandas function, three of the ``start``, ``end``, + ``periods``, or ``freq`` arguments must be specified at a given time, with + the other set to ``None``. See the `pandas documentation + `_ + for more examples of the behavior of ``date_range`` with each of the + parameters. + + See Also + -------- + pandas.date_range + """ # noqa: E501 + # Adapted from pandas.core.indexes.datetimes._generate_range. + if _count_not_none(start, end, periods, freq) != 3: + raise ValueError( + "Of the arguments 'start', 'end', 'periods', and 'freq', three " + "must be specified at a time.") + + if start is not None: + start = to_cftime_datetime(start, calendar) + start = _maybe_normalize_date(start, normalize) + if end is not None: + end = to_cftime_datetime(end, calendar) + end = _maybe_normalize_date(end, normalize) + + if freq is None: + dates = _generate_linear_range(start, end, periods) + else: + offset = to_offset(freq) + dates = np.array(list(_generate_range(start, end, periods, offset))) + + left_closed = False + right_closed = False + + if closed is None: + left_closed = True + right_closed = True + elif closed == 'left': + left_closed = True + elif closed == 'right': + right_closed = True + else: + raise ValueError("Closed must be either 'left', 'right' or None") + + if (not left_closed and len(dates) and + start is not None and dates[0] == start): + dates = dates[1:] + if (not right_closed and len(dates) and + end is not None and dates[-1] == end): + dates = dates[:-1] + + return CFTimeIndex(dates, name=name) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index eb8cae2f398..2ce996b2bd2 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -1,5 +1,48 @@ +"""DatetimeIndex analog for cftime.datetime objects""" +# The pandas.Index subclass defined here was copied and adapted for +# use with cftime.datetime objects based on the source code defining +# pandas.DatetimeIndex. + +# For reference, here is a copy of the pandas copyright notice: + +# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team +# All rights reserved. 
+ +# Copyright (c) 2008-2011 AQR Capital Management, LLC +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. + +# * Neither the name of the copyright holder nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + from __future__ import absolute_import + import re +import warnings from datetime import timedelta import numpy as np @@ -8,6 +51,8 @@ from xarray.core import pycompat from xarray.core.utils import is_scalar +from .times import cftime_to_nptime, infer_calendar_name, _STANDARD_CALENDARS + def named(name, pattern): return '(?P<' + name + '>' + pattern + ')' @@ -116,28 +161,43 @@ def f(self): def get_date_type(self): - return type(self._data[0]) + if self._data.size: + return type(self._data[0]) + else: + return None def assert_all_valid_date_type(data): import cftime - sample = data[0] - date_type = type(sample) - if not isinstance(sample, cftime.datetime): - raise TypeError( - 'CFTimeIndex requires cftime.datetime ' - 'objects. Got object of {}.'.format(date_type)) - if not all(isinstance(value, date_type) for value in data): - raise TypeError( - 'CFTimeIndex requires using datetime ' - 'objects of all the same type. Got\n{}.'.format(data)) + if data.size: + sample = data[0] + date_type = type(sample) + if not isinstance(sample, cftime.datetime): + raise TypeError( + 'CFTimeIndex requires cftime.datetime ' + 'objects. Got object of {}.'.format(date_type)) + if not all(isinstance(value, date_type) for value in data): + raise TypeError( + 'CFTimeIndex requires using datetime ' + 'objects of all the same type. Got\n{}.'.format(data)) class CFTimeIndex(pd.Index): """Custom Index for working with CF calendars and dates All elements of a CFTimeIndex must be cftime.datetime objects. 
+ + Parameters + ---------- + data : array or CFTimeIndex + Sequence of cftime.datetime objects to use in index + name : str, default None + Name of the resulting index + + See Also + -------- + cftime_range """ year = _field_accessor('year', 'The year of the datetime') month = _field_accessor('month', 'The month of the datetime') @@ -149,10 +209,14 @@ class CFTimeIndex(pd.Index): 'The microseconds of the datetime') date_type = property(get_date_type) - def __new__(cls, data): + def __new__(cls, data, name=None): + if name is None and hasattr(data, 'name'): + name = data.name + result = object.__new__(cls) - assert_all_valid_date_type(data) - result._data = np.array(data) + result._data = np.array(data, dtype='O') + assert_all_valid_date_type(result._data) + result.name = name return result def _partial_date_slice(self, resolution, parsed): @@ -254,3 +318,144 @@ def __contains__(self, key): def contains(self, key): """Needed for .loc based partial-string indexing""" return self.__contains__(key) + + def shift(self, n, freq): + """Shift the CFTimeIndex a multiple of the given frequency. + + See the documentation for :py:func:`~xarray.cftime_range` for a + complete listing of valid frequency strings. + + Parameters + ---------- + n : int + Periods to shift by + freq : str or datetime.timedelta + A frequency string or datetime.timedelta object to shift by + + Returns + ------- + CFTimeIndex + + See also + -------- + pandas.DatetimeIndex.shift + + Examples + -------- + >>> index = xr.cftime_range('2000', periods=1, freq='M') + >>> index + CFTimeIndex([2000-01-31 00:00:00], dtype='object') + >>> index.shift(1, 'M') + CFTimeIndex([2000-02-29 00:00:00], dtype='object') + """ + from .cftime_offsets import to_offset + + if not isinstance(n, int): + raise TypeError("'n' must be an int, got {}.".format(n)) + if isinstance(freq, timedelta): + return self + n * freq + elif isinstance(freq, pycompat.basestring): + return self + n * to_offset(freq) + else: + raise TypeError( + "'freq' must be of type " + "str or datetime.timedelta, got {}.".format(freq)) + + def __add__(self, other): + if isinstance(other, pd.TimedeltaIndex): + other = other.to_pytimedelta() + return CFTimeIndex(np.array(self) + other) + + def __radd__(self, other): + if isinstance(other, pd.TimedeltaIndex): + other = other.to_pytimedelta() + return CFTimeIndex(other + np.array(self)) + + def __sub__(self, other): + if isinstance(other, CFTimeIndex): + return pd.TimedeltaIndex(np.array(self) - np.array(other)) + elif isinstance(other, pd.TimedeltaIndex): + return CFTimeIndex(np.array(self) - other.to_pytimedelta()) + else: + return CFTimeIndex(np.array(self) - other) + + def _add_delta(self, deltas): + # To support TimedeltaIndex + CFTimeIndex with older versions of + # pandas. No longer used as of pandas 0.23. + return self + deltas + + def to_datetimeindex(self, unsafe=False): + """If possible, convert this index to a pandas.DatetimeIndex. + + Parameters + ---------- + unsafe : bool + Flag to turn off warning when converting from a CFTimeIndex with + a non-standard calendar to a DatetimeIndex (default ``False``). + + Returns + ------- + pandas.DatetimeIndex + + Raises + ------ + ValueError + If the CFTimeIndex contains dates that are not possible in the + standard calendar or outside the pandas.Timestamp-valid range. + + Warns + ----- + RuntimeWarning + If converting from a non-standard calendar to a DatetimeIndex. + + Warnings + -------- + Note that for non-standard calendars, this will change the calendar + type of the index. 
In that case the result of this method should be + used with caution. + + Examples + -------- + >>> import xarray as xr + >>> times = xr.cftime_range('2000', periods=2, calendar='gregorian') + >>> times + CFTimeIndex([2000-01-01 00:00:00, 2000-01-02 00:00:00], dtype='object') + >>> times.to_datetimeindex() + DatetimeIndex(['2000-01-01', '2000-01-02'], dtype='datetime64[ns]', freq=None) + """ # noqa: E501 + nptimes = cftime_to_nptime(self) + calendar = infer_calendar_name(self) + if calendar not in _STANDARD_CALENDARS and not unsafe: + warnings.warn( + 'Converting a CFTimeIndex with dates from a non-standard ' + 'calendar, {!r}, to a pandas.DatetimeIndex, which uses dates ' + 'from the standard calendar. This may lead to subtle errors ' + 'in operations that depend on the length of time between ' + 'dates.'.format(calendar), RuntimeWarning) + return pd.DatetimeIndex(nptimes) + + +def _parse_iso8601_without_reso(date_type, datetime_str): + date, _ = _parse_iso8601_with_reso(date_type, datetime_str) + return date + + +def _parse_array_of_cftime_strings(strings, date_type): + """Create a numpy array from an array of strings. + + For use in generating dates from strings for use with interp. Assumes the + array is either 0-dimensional or 1-dimensional. + + Parameters + ---------- + strings : array of strings + Strings to convert to dates + date_type : cftime.datetime type + Calendar type to use for dates + + Returns + ------- + np.array + """ + return np.array([_parse_iso8601_without_reso(date_type, s) + for s in strings.ravel()]).reshape(strings.shape) diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index 87b17d9175e..3502fd773d7 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -9,8 +9,8 @@ from ..core.pycompat import bytes_type, dask_array_type, unicode_type from ..core.variable import Variable from .variables import ( - VariableCoder, lazy_elemwise_func, pop_to, - safe_setitem, unpack_for_decoding, unpack_for_encoding) + VariableCoder, lazy_elemwise_func, pop_to, safe_setitem, + unpack_for_decoding, unpack_for_encoding) def create_vlen_dtype(element_type): diff --git a/xarray/coding/times.py b/xarray/coding/times.py index d946e2ed378..dfc4b2fb023 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -9,10 +9,9 @@ import numpy as np import pandas as pd -from ..core.common import contains_cftime_datetimes from ..core import indexing +from ..core.common import contains_cftime_datetimes from ..core.formatting import first_n_items, format_timestamp, last_item -from ..core.options import OPTIONS from ..core.pycompat import PY3 from ..core.variable import Variable from .variables import ( @@ -61,8 +60,9 @@ def _require_standalone_cftime(): try: import cftime # noqa: F401 except ImportError: - raise ImportError('Using a CFTimeIndex requires the standalone ' - 'version of the cftime library.') + raise ImportError('Decoding times with non-standard calendars ' + 'or outside the pandas.Timestamp-valid range ' + 'requires the standalone cftime package.') def _netcdf_to_numpy_timeunit(units): @@ -84,41 +84,32 @@ def _unpack_netcdf_time_units(units): return delta_units, ref_date -def _decode_datetime_with_cftime(num_dates, units, calendar, - enable_cftimeindex): +def _decode_datetime_with_cftime(num_dates, units, calendar): cftime = _import_cftime() - if enable_cftimeindex: - _require_standalone_cftime() + + if cftime.__name__ == 'cftime': dates = np.asarray(cftime.num2date(num_dates, units, calendar, only_use_cftime_datetimes=True)) else: + # Must be 
using num2date from an old version of netCDF4 which + # does not have the only_use_cftime_datetimes option. dates = np.asarray(cftime.num2date(num_dates, units, calendar)) if (dates[np.nanargmin(num_dates)].year < 1678 or dates[np.nanargmax(num_dates)].year >= 2262): - if not enable_cftimeindex or calendar in _STANDARD_CALENDARS: + if calendar in _STANDARD_CALENDARS: warnings.warn( 'Unable to decode time axis into full ' 'numpy.datetime64 objects, continuing using dummy ' 'cftime.datetime objects instead, reason: dates out ' 'of range', SerializationWarning, stacklevel=3) else: - if enable_cftimeindex: - if calendar in _STANDARD_CALENDARS: - dates = cftime_to_nptime(dates) - else: - try: - dates = cftime_to_nptime(dates) - except ValueError as e: - warnings.warn( - 'Unable to decode time axis into full ' - 'numpy.datetime64 objects, continuing using ' - 'dummy cftime.datetime objects instead, reason:' - '{0}'.format(e), SerializationWarning, stacklevel=3) + if calendar in _STANDARD_CALENDARS: + dates = cftime_to_nptime(dates) return dates -def _decode_cf_datetime_dtype(data, units, calendar, enable_cftimeindex): +def _decode_cf_datetime_dtype(data, units, calendar): # Verify that at least the first and last date can be decoded # successfully. Otherwise, tracebacks end up swallowed by # Dataset.__repr__ when users try to view their lazily decoded array. @@ -128,8 +119,7 @@ def _decode_cf_datetime_dtype(data, units, calendar, enable_cftimeindex): last_item(values) or [0]]) try: - result = decode_cf_datetime(example_value, units, calendar, - enable_cftimeindex) + result = decode_cf_datetime(example_value, units, calendar) except Exception: calendar_msg = ('the default calendar' if calendar is None else 'calendar %r' % calendar) @@ -145,8 +135,7 @@ def _decode_cf_datetime_dtype(data, units, calendar, enable_cftimeindex): return dtype -def decode_cf_datetime(num_dates, units, calendar=None, - enable_cftimeindex=False): +def decode_cf_datetime(num_dates, units, calendar=None): """Given an array of numeric dates in netCDF format, convert it into a numpy array of date time objects. 
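Not part of the patch: a minimal standalone sketch of how decode_cf_datetime (the function changed in the hunks above) behaves once the enable_cftimeindex flag is removed, assuming the standalone cftime package is installed; the units and values below are illustrative only. Standard-calendar dates inside the pandas.Timestamp-valid range still decode to numpy.datetime64[ns], while non-standard calendars or out-of-range dates always decode to cftime.datetime objects.

import numpy as np
from xarray.coding.times import decode_cf_datetime

# Standard calendar, in-range dates: decoded via pandas to datetime64[ns].
decoded = decode_cf_datetime(np.array([0, 1]), 'days since 2000-01-01',
                             calendar='standard')
print(decoded.dtype)  # datetime64[ns]

# Non-standard calendar (here also out of the Timestamp-valid range):
# decoded with cftime and returned as an object array of cftime.datetime
# instances (cftime.DatetimeNoLeap in this case).
decoded = decode_cf_datetime(np.array([0, 1]), 'days since 0001-01-01',
                             calendar='noleap')
print(type(decoded[0]))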
@@ -183,8 +172,11 @@ def decode_cf_datetime(num_dates, units, calendar=None, # fixes: https://github.com/pydata/pandas/issues/14068 # these lines check if the the lowest or the highest value in dates # cause an OutOfBoundsDatetime (Overflow) error - pd.to_timedelta(flat_num_dates.min(), delta) + ref_date - pd.to_timedelta(flat_num_dates.max(), delta) + ref_date + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'invalid value encountered', + RuntimeWarning) + pd.to_timedelta(flat_num_dates.min(), delta) + ref_date + pd.to_timedelta(flat_num_dates.max(), delta) + ref_date # Cast input dates to integers of nanoseconds because `pd.to_datetime` # works much faster when dealing with integers @@ -197,8 +189,7 @@ def decode_cf_datetime(num_dates, units, calendar=None, except (OutOfBoundsDatetime, OverflowError): dates = _decode_datetime_with_cftime( - flat_num_dates.astype(np.float), units, calendar, - enable_cftimeindex) + flat_num_dates.astype(np.float), units, calendar) return dates.reshape(num_dates.shape) @@ -288,7 +279,16 @@ def cftime_to_nptime(times): times = np.asarray(times) new = np.empty(times.shape, dtype='M8[ns]') for i, t in np.ndenumerate(times): - dt = datetime(t.year, t.month, t.day, t.hour, t.minute, t.second) + try: + # Use pandas.Timestamp in place of datetime.datetime, because + # NumPy casts it safely it np.datetime64[ns] for dates outside + # 1678 to 2262 (this is not currently the case for + # datetime.datetime). + dt = pd.Timestamp(t.year, t.month, t.day, t.hour, t.minute, + t.second, t.microsecond) + except ValueError as e: + raise ValueError('Cannot convert date {} to a date in the ' + 'standard calendar. Reason: {}.'.format(t, e)) new[i] = np.datetime64(dt) return new @@ -358,7 +358,12 @@ def encode_cf_datetime(dates, units=None, calendar=None): delta_units = _netcdf_to_numpy_timeunit(delta) time_delta = np.timedelta64(1, delta_units).astype('timedelta64[ns]') ref_date = np.datetime64(pd.Timestamp(ref_date)) - num = (dates - ref_date) / time_delta + + # Wrap the dates in a DatetimeIndex to do the subtraction to ensure + # an OverflowError is raised if the ref_date is too far away from + # dates to be encoded (GH 2272). 
+ num = (pd.DatetimeIndex(dates.ravel()) - ref_date) / time_delta + num = num.values.reshape(dates.shape) except (OutOfBoundsDatetime, OverflowError): num = _encode_datetime_with_cftime(dates, units, calendar) @@ -396,15 +401,12 @@ def encode(self, variable, name=None): def decode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_decoding(variable) - enable_cftimeindex = OPTIONS['enable_cftimeindex'] if 'units' in attrs and 'since' in attrs['units']: units = pop_to(attrs, encoding, 'units') calendar = pop_to(attrs, encoding, 'calendar') - dtype = _decode_cf_datetime_dtype( - data, units, calendar, enable_cftimeindex) + dtype = _decode_cf_datetime_dtype(data, units, calendar) transform = partial( - decode_cf_datetime, units=units, calendar=calendar, - enable_cftimeindex=enable_cftimeindex) + decode_cf_datetime, units=units, calendar=calendar) data = lazy_elemwise_func(data, transform, dtype) return Variable(dims, data, attrs, encoding) diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 1207f5743cb..b86b77a3707 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -63,7 +63,10 @@ def dtype(self): return np.dtype(self._dtype) def __getitem__(self, key): - return self.func(self.array[key]) + return type(self)(self.array[key], self.func, self.dtype) + + def __array__(self, dtype=None): + return self.func(self.array) def __repr__(self): return ("%s(%r, func=%r, dtype=%r)" % diff --git a/xarray/conventions.py b/xarray/conventions.py index 67dcb8d6d4e..f60ee6b2c15 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -6,11 +6,11 @@ import numpy as np import pandas as pd -from .coding import times, strings, variables +from .coding import strings, times, variables from .coding.variables import SerializationWarning from .core import duck_array_ops, indexing from .core.pycompat import ( - OrderedDict, basestring, bytes_type, iteritems, dask_array_type, + OrderedDict, basestring, bytes_type, dask_array_type, iteritems, unicode_type) from .core.variable import IndexVariable, Variable, as_variable diff --git a/xarray/convert.py b/xarray/convert.py index f3a5ccb2ce5..6cff72103ff 100644 --- a/xarray/convert.py +++ b/xarray/convert.py @@ -2,7 +2,8 @@ """ from __future__ import absolute_import, division, print_function -from collections import OrderedDict +from collections import Counter + import numpy as np import pandas as pd @@ -156,8 +157,12 @@ def to_iris(dataarray): if coord.dims: axis = dataarray.get_axis_num(coord.dims) if coord_name in dataarray.dims: - iris_coord = iris.coords.DimCoord(coord.values, **coord_args) - dim_coords.append((iris_coord, axis)) + try: + iris_coord = iris.coords.DimCoord(coord.values, **coord_args) + dim_coords.append((iris_coord, axis)) + except ValueError: + iris_coord = iris.coords.AuxCoord(coord.values, **coord_args) + aux_coords.append((iris_coord, axis)) else: iris_coord = iris.coords.AuxCoord(coord.values, **coord_args) aux_coords.append((iris_coord, axis)) @@ -183,7 +188,7 @@ def _iris_obj_to_attrs(obj): 'long_name': obj.long_name} if obj.units.calendar: attrs['calendar'] = obj.units.calendar - if obj.units.origin != '1': + if obj.units.origin != '1' and not obj.units.is_unknown(): attrs['units'] = obj.units.origin attrs.update(obj.attributes) return dict((k, v) for k, v in attrs.items() if v is not None) @@ -206,34 +211,46 @@ def _iris_cell_methods_to_str(cell_methods_obj): return ' '.join(cell_methods) +def _name(iris_obj, default='unknown'): + """ Mimicks `iris_obj.name()` but with 
different name resolution order. + + Similar to iris_obj.name() method, but using iris_obj.var_name first to + enable roundtripping. + """ + return (iris_obj.var_name or iris_obj.standard_name or + iris_obj.long_name or default) + + def from_iris(cube): """ Convert a Iris cube into an DataArray """ import iris.exceptions from xarray.core.pycompat import dask_array_type - name = cube.var_name + name = _name(cube) + if name == 'unknown': + name = None dims = [] for i in range(cube.ndim): try: dim_coord = cube.coord(dim_coords=True, dimensions=(i,)) - dims.append(dim_coord.var_name) + dims.append(_name(dim_coord)) except iris.exceptions.CoordinateNotFoundError: dims.append("dim_{}".format(i)) + if len(set(dims)) != len(dims): + duplicates = [k for k, v in Counter(dims).items() if v > 1] + raise ValueError('Duplicate coordinate name {}.'.format(duplicates)) + coords = OrderedDict() for coord in cube.coords(): coord_attrs = _iris_obj_to_attrs(coord) coord_dims = [dims[i] for i in cube.coord_dims(coord)] - if not coord.var_name: - raise ValueError("Coordinate '{}' has no " - "var_name attribute".format(coord.name())) if coord_dims: - coords[coord.var_name] = (coord_dims, coord.points, coord_attrs) + coords[_name(coord)] = (coord_dims, coord.points, coord_attrs) else: - coords[coord.var_name] = ((), - np.asscalar(coord.points), coord_attrs) + coords[_name(coord)] = ((), np.asscalar(coord.points), coord_attrs) array_attrs = _iris_obj_to_attrs(cube) cell_methods = _iris_cell_methods_to_str(cube.cell_methods) diff --git a/xarray/core/accessors.py b/xarray/core/accessors.py index 81af0532d93..72791ed73ec 100644 --- a/xarray/core/accessors.py +++ b/xarray/core/accessors.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd -from .common import is_np_datetime_like, _contains_datetime_like_objects +from .common import _contains_datetime_like_objects, is_np_datetime_like from .pycompat import dask_array_type diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index b0d2a49c29f..f82ddef25ba 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -174,11 +174,14 @@ def deep_align(objects, join='inner', copy=True, indexes=None, This function is not public API. """ + from .dataarray import DataArray + from .dataset import Dataset + if indexes is None: indexes = {} def is_alignable(obj): - return hasattr(obj, 'indexes') and hasattr(obj, 'reindex') + return isinstance(obj, (DataArray, Dataset)) positions = [] keys = [] diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 430f0e564d6..6853939c02d 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -8,8 +8,8 @@ from .alignment import align from .merge import merge from .pycompat import OrderedDict, basestring, iteritems -from .variable import concat as concat_vars from .variable import IndexVariable, Variable, as_variable +from .variable import concat as concat_vars def concat(objs, dim=None, data_vars='all', coords='different', @@ -125,16 +125,17 @@ def _calc_concat_dim_coord(dim): Infer the dimension name and 1d coordinate variable (if appropriate) for concatenating along the new dimension. 
""" + from .dataarray import DataArray + if isinstance(dim, basestring): coord = None - elif not hasattr(dim, 'dims'): - # dim is not a DataArray or IndexVariable + elif not isinstance(dim, (DataArray, Variable)): dim_name = getattr(dim, 'name', None) if dim_name is None: dim_name = 'concat_dim' coord = IndexVariable(dim_name, dim) dim = dim_name - elif not hasattr(dim, 'name'): + elif not isinstance(dim, DataArray): coord = as_variable(dim).to_index_variable() dim, = coord.dims else: diff --git a/xarray/core/common.py b/xarray/core/common.py index 3f934fcc769..34057e3715d 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -2,14 +2,19 @@ import warnings from distutils.version import LooseVersion +from textwrap import dedent import numpy as np import pandas as pd -from . import duck_array_ops, dtypes, formatting, ops +from . import dtypes, duck_array_ops, formatting, ops from .arithmetic import SupportsArithmetic from .pycompat import OrderedDict, basestring, dask_array_type, suppress -from .utils import Frozen, SortedKeysDict +from .utils import Frozen, ReprObject, SortedKeysDict, either_dict_or_kwargs +from .options import _get_keep_attrs + +# Used as a sentinel value to indicate a all dimensions +ALL_DIMS = ReprObject('') class ImplementsArrayReduce(object): @@ -17,44 +22,44 @@ class ImplementsArrayReduce(object): def _reduce_method(cls, func, include_skipna, numeric_only): if include_skipna: def wrapped_func(self, dim=None, axis=None, skipna=None, - keep_attrs=False, **kwargs): - return self.reduce(func, dim, axis, keep_attrs=keep_attrs, + **kwargs): + return self.reduce(func, dim, axis, skipna=skipna, allow_lazy=True, **kwargs) else: - def wrapped_func(self, dim=None, axis=None, keep_attrs=False, + def wrapped_func(self, dim=None, axis=None, **kwargs): - return self.reduce(func, dim, axis, keep_attrs=keep_attrs, + return self.reduce(func, dim, axis, allow_lazy=True, **kwargs) return wrapped_func - _reduce_extra_args_docstring = \ - """dim : str or sequence of str, optional + _reduce_extra_args_docstring = dedent("""\ + dim : str or sequence of str, optional Dimension(s) over which to apply `{name}`. axis : int or sequence of int, optional Axis(es) over which to apply `{name}`. Only one of the 'dim' and 'axis' arguments can be supplied. If neither are supplied, then - `{name}` is calculated over axes.""" + `{name}` is calculated over axes.""") - _cum_extra_args_docstring = \ - """dim : str or sequence of str, optional + _cum_extra_args_docstring = dedent("""\ + dim : str or sequence of str, optional Dimension over which to apply `{name}`. axis : int or sequence of int, optional Axis over which to apply `{name}`. 
Only one of the 'dim' - and 'axis' arguments can be supplied.""" + and 'axis' arguments can be supplied.""") class ImplementsDatasetReduce(object): @classmethod def _reduce_method(cls, func, include_skipna, numeric_only): if include_skipna: - def wrapped_func(self, dim=None, keep_attrs=False, skipna=None, + def wrapped_func(self, dim=None, skipna=None, **kwargs): - return self.reduce(func, dim, keep_attrs, skipna=skipna, + return self.reduce(func, dim, skipna=skipna, numeric_only=numeric_only, allow_lazy=True, **kwargs) else: - def wrapped_func(self, dim=None, keep_attrs=False, **kwargs): - return self.reduce(func, dim, keep_attrs, + def wrapped_func(self, dim=None, **kwargs): + return self.reduce(func, dim, numeric_only=numeric_only, allow_lazy=True, **kwargs) return wrapped_func @@ -308,12 +313,12 @@ def assign_coords(self, **kwargs): assigned : same type as caller A new object with the new coordinates in addition to the existing data. - + Examples -------- - + Convert longitude coordinates from 0-359 to -180-179: - + >>> da = xr.DataArray(np.random.rand(4), ... coords=[np.array([358, 359, 0, 1])], ... dims='lon') @@ -339,6 +344,7 @@ def assign_coords(self, **kwargs): See also -------- Dataset.assign + Dataset.swap_dims """ data = self.copy(deep=False) results = self._calc_assign_results(kwargs) @@ -445,11 +451,11 @@ def groupby(self, group, squeeze=True): grouped : GroupBy A `GroupBy` object patterned after `pandas.GroupBy` that can be iterated over in the form of `(unique_value, grouped_array)` pairs. - + Examples -------- Calculate daily anomalies for daily data: - + >>> da = xr.DataArray(np.linspace(0, 1826, num=1827), ... coords=[pd.date_range('1/1/2000', '31/12/2004', ... freq='D')], @@ -465,7 +471,7 @@ def groupby(self, group, squeeze=True): Coordinates: * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ... dayofyear (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 ... - + See Also -------- core.groupby.DataArrayGroupBy @@ -525,24 +531,24 @@ def groupby_bins(self, group, bins, right=True, labels=None, precision=3, 'precision': precision, 'include_lowest': include_lowest}) - def rolling(self, min_periods=None, center=False, **windows): + def rolling(self, dim=None, min_periods=None, center=False, **dim_kwargs): """ Rolling window object. Parameters ---------- + dim: dict, optional + Mapping from the dimension name to create the rolling iterator + along (e.g. `time`) to its moving window size. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. + **dim_kwargs : optional + The keyword arguments form of ``dim``. + One of dim or dim_kwargs must be provided. 
Returns ------- @@ -581,19 +587,21 @@ def rolling(self, min_periods=None, center=False, **windows): core.rolling.DataArrayRolling core.rolling.DatasetRolling """ + dim = either_dict_or_kwargs(dim, dim_kwargs, 'rolling') + return self._rolling_cls(self, dim, min_periods=min_periods, + center=center) - return self._rolling_cls(self, min_periods=min_periods, - center=center, **windows) - - def resample(self, freq=None, dim=None, how=None, skipna=None, - closed=None, label=None, base=0, keep_attrs=False, **indexer): + def resample(self, indexer=None, skipna=None, closed=None, label=None, + base=0, keep_attrs=None, **indexer_kwargs): """Returns a Resample object for performing resampling operations. - Handles both downsampling and upsampling. If any intervals contain no + Handles both downsampling and upsampling. If any intervals contain no values from the original object, they will be given the value ``NaN``. Parameters ---------- + indexer : {dim: freq}, optional + Mapping from the dimension name to resample frequency. skipna : bool, optional Whether to skip missing values when aggregating in downsampling. closed : 'left' or 'right', optional @@ -608,19 +616,19 @@ def resample(self, freq=None, dim=None, how=None, skipna=None, If True, the object's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. - **indexer : {dim: freq} - Dictionary with a key indicating the dimension name to resample - over and a value corresponding to the resampling frequency. + **indexer_kwargs : {dim: freq} + The keyword arguments form of ``indexer``. + One of indexer or indexer_kwargs must be provided. Returns ------- resampled : same type as caller This object resampled. - + Examples -------- Downsample monthly time-series data to seasonal data: - + >>> da = xr.DataArray(np.linspace(0, 11, num=12), ... coords=[pd.date_range('15/12/1999', ... periods=12, freq=pd.DateOffset(months=1))], @@ -637,46 +645,61 @@ def resample(self, freq=None, dim=None, how=None, skipna=None, * time (time) datetime64[ns] 1999-12-01 2000-03-01 2000-06-01 2000-09-01 Upsample monthly time-series data to daily data: - + >>> da.resample(time='1D').interpolate('linear') array([ 0. , 0.032258, 0.064516, ..., 10.935484, 10.967742, 11. ]) Coordinates: * time (time) datetime64[ns] 1999-12-15 1999-12-16 1999-12-17 ... - + References ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases """ + # TODO support non-string indexer after removing the old API. + from .dataarray import DataArray from .resample import RESAMPLE_DIM + from ..coding.cftimeindex import CFTimeIndex - if dim is not None: - if how is None: - how = 'mean' - return self._resample_immediately(freq, dim, how, skipna, closed, - label, base, keep_attrs) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) - if (how is not None) and indexer: - raise TypeError("If passing an 'indexer' then 'dim' " - "and 'how' should not be used") + # note: the second argument (now 'skipna') use to be 'dim' + if ((skipna is not None and not isinstance(skipna, bool)) + or ('how' in indexer_kwargs and 'how' not in self.dims) + or ('dim' in indexer_kwargs and 'dim' not in self.dims)): + raise TypeError('resample() no longer supports the `how` or ' + '`dim` arguments. 
Instead call methods on resample ' + "objects, e.g., data.resample(time='1D').mean()") + + indexer = either_dict_or_kwargs(indexer, indexer_kwargs, 'resample') - # More than one indexer is ambiguous, but we do in fact need one if - # "dim" was not provided, until the old API is fully deprecated if len(indexer) != 1: raise ValueError( "Resampling only supported along single dimensions." ) dim, freq = indexer.popitem() - if isinstance(dim, basestring): - dim_name = dim - dim = self[dim] - else: - raise TypeError("Dimension name should be a string; " - "was passed %r" % dim) - group = DataArray(dim, [(dim.dims, dim)], name=RESAMPLE_DIM) + dim_name = dim + dim_coord = self[dim] + + if isinstance(self.indexes[dim_name], CFTimeIndex): + raise NotImplementedError( + 'Resample is currently not supported along a dimension ' + 'indexed by a CFTimeIndex. For certain kinds of downsampling ' + 'it may be possible to work around this by converting your ' + 'time index to a DatetimeIndex using ' + 'CFTimeIndex.to_datetimeindex. Use caution when doing this ' + 'however, because switching to a DatetimeIndex from a ' + 'CFTimeIndex with a non-standard calendar entails a change ' + 'in the calendar type, which could lead to subtle and silent ' + 'errors.' + ) + + group = DataArray(dim_coord, coords=dim_coord.coords, + dims=dim_coord.dims, name=RESAMPLE_DIM) grouper = pd.Grouper(freq=freq, closed=closed, label=label, base=base) resampler = self._resample_cls(self, group=group, dim=dim_name, grouper=grouper, @@ -684,39 +707,6 @@ def resample(self, freq=None, dim=None, how=None, skipna=None, return resampler - def _resample_immediately(self, freq, dim, how, skipna, - closed, label, base, keep_attrs): - """Implement the original version of .resample() which immediately - executes the desired resampling operation. """ - from .dataarray import DataArray - RESAMPLE_DIM = '__resample_dim__' - - warnings.warn("\n.resample() has been modified to defer " - "calculations. Instead of passing 'dim' and " - "how=\"{how}\", instead consider using " - ".resample({dim}=\"{freq}\").{how}('{dim}') ".format( - dim=dim, freq=freq, how=how), - FutureWarning, stacklevel=3) - - if isinstance(dim, basestring): - dim = self[dim] - group = DataArray(dim, [(dim.dims, dim)], name=RESAMPLE_DIM) - grouper = pd.Grouper(freq=freq, how=how, closed=closed, label=label, - base=base) - gb = self._groupby_cls(self, group, grouper=grouper) - if isinstance(how, basestring): - f = getattr(gb, how) - if how in ['first', 'last']: - result = f(skipna=skipna, keep_attrs=keep_attrs) - elif how == 'count': - result = f(dim=dim.name, keep_attrs=keep_attrs) - else: - result = f(dim=dim.name, skipna=skipna, keep_attrs=keep_attrs) - else: - result = gb.reduce(how, dim=dim.name, keep_attrs=keep_attrs) - result = result.rename({RESAMPLE_DIM: dim.name}) - return result - def where(self, cond, other=dtypes.NA, drop=False): """Filter elements from this object according to a condition. @@ -957,8 +947,8 @@ def contains_cftime_datetimes(var): sample = sample.item() return isinstance(sample, cftime_datetime) else: - return False - + return False + def _contains_datetime_like_objects(var): """Check if a variable contains datetime like objects (either diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 9b251bb2c4b..7998cc4f72f 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -2,18 +2,19 @@ Functions for applying functions that act on arrays to xarray's labeled data. 
""" from __future__ import absolute_import, division, print_function -from distutils.version import LooseVersion + import functools import itertools import operator from collections import Counter +from distutils.version import LooseVersion import numpy as np from . import duck_array_ops, utils from .alignment import deep_align from .merge import expand_and_merge_variables -from .pycompat import OrderedDict, dask_array_type, basestring +from .pycompat import OrderedDict, basestring, dask_array_type from .utils import is_dict_like _DEFAULT_FROZEN_SET = frozenset() @@ -919,6 +920,11 @@ def earth_mover_distance(first_samples, if input_core_dims is None: input_core_dims = ((),) * (len(args)) + elif len(input_core_dims) != len(args): + raise ValueError( + 'input_core_dims must be None or a tuple with the length same to ' + 'the number of arguments. Given input_core_dims: {}, ' + 'number of args: {}.'.format(input_core_dims, len(args))) signature = _UFuncSignature(input_core_dims, output_core_dims) diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index c2417345f55..6b53dcffe6e 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -1,7 +1,10 @@ from __future__ import absolute_import, division, print_function -import numpy as np +from distutils.version import LooseVersion + import dask.array as da +import numpy as np +from dask import __version__ as dask_version try: from dask.array import isin @@ -30,3 +33,130 @@ def isin(element, test_elements, assume_unique=False, invert=False): if invert: result = ~result return result + + +if LooseVersion(dask_version) > LooseVersion('1.19.2'): + gradient = da.gradient + +else: # pragma: no cover + # Copied from dask v0.19.2 + # Used under the terms of Dask's license, see licenses/DASK_LICENSE. + import math + from numbers import Integral, Real + + try: + AxisError = np.AxisError + except AttributeError: + try: + np.array([0]).sum(axis=5) + except Exception as e: + AxisError = type(e) + + def validate_axis(axis, ndim): + """ Validate an input to axis= keywords """ + if isinstance(axis, (tuple, list)): + return tuple(validate_axis(ax, ndim) for ax in axis) + if not isinstance(axis, Integral): + raise TypeError("Axis value must be an integer, got %s" % axis) + if axis < -ndim or axis >= ndim: + raise AxisError("Axis %d is out of bounds for array of dimension " + "%d" % (axis, ndim)) + if axis < 0: + axis += ndim + return axis + + def _gradient_kernel(x, block_id, coord, axis, array_locs, grad_kwargs): + """ + x: nd-array + array of one block + coord: 1d-array or scalar + coordinate along which the gradient is computed. + axis: int + axis along which the gradient is computed + array_locs: + actual location along axis. 
None if coordinate is scalar + grad_kwargs: + keyword to be passed to np.gradient + """ + block_loc = block_id[axis] + if array_locs is not None: + coord = coord[array_locs[0][block_loc]:array_locs[1][block_loc]] + grad = np.gradient(x, coord, axis=axis, **grad_kwargs) + return grad + + def gradient(f, *varargs, **kwargs): + f = da.asarray(f) + + kwargs["edge_order"] = math.ceil(kwargs.get("edge_order", 1)) + if kwargs["edge_order"] > 2: + raise ValueError("edge_order must be less than or equal to 2.") + + drop_result_list = False + axis = kwargs.pop("axis", None) + if axis is None: + axis = tuple(range(f.ndim)) + elif isinstance(axis, Integral): + drop_result_list = True + axis = (axis,) + + axis = validate_axis(axis, f.ndim) + + if len(axis) != len(set(axis)): + raise ValueError("duplicate axes not allowed") + + axis = tuple(ax % f.ndim for ax in axis) + + if varargs == (): + varargs = (1,) + if len(varargs) == 1: + varargs = len(axis) * varargs + if len(varargs) != len(axis): + raise TypeError( + "Spacing must either be a single scalar, or a scalar / " + "1d-array per axis" + ) + + if issubclass(f.dtype.type, (np.bool8, Integral)): + f = f.astype(float) + elif issubclass(f.dtype.type, Real) and f.dtype.itemsize < 4: + f = f.astype(float) + + results = [] + for i, ax in enumerate(axis): + for c in f.chunks[ax]: + if np.min(c) < kwargs["edge_order"] + 1: + raise ValueError( + 'Chunk size must be larger than edge_order + 1. ' + 'Minimum chunk for axis {} is {}. Rechunk to ' + 'proceed.'.format(ax, np.min(c))) + + if np.isscalar(varargs[i]): + array_locs = None + else: + if isinstance(varargs[i], da.Array): + raise NotImplementedError( + 'dask array coordinates are not supported.') + # coordinate position for each block taking overlap into + # account + chunk = np.array(f.chunks[ax]) + array_loc_stop = np.cumsum(chunk) + 1 + array_loc_start = array_loc_stop - chunk - 2 + array_loc_stop[-1] -= 1 + array_loc_start[0] = 0 + array_locs = (array_loc_start, array_loc_stop) + + results.append(f.map_overlap( + _gradient_kernel, + dtype=f.dtype, + depth={j: 1 if j == ax else 0 for j in range(f.ndim)}, + boundary="none", + coord=varargs[i], + axis=ax, + array_locs=array_locs, + grad_kwargs=kwargs, + )) + + if drop_result_list: + results = results[0] + + return results diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 55ba1c1cbc6..25c572edd54 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -1,12 +1,21 @@ from __future__ import absolute_import, division, print_function +from distutils.version import LooseVersion + import numpy as np -from . import nputils -from . import dtypes +from . 
import dtypes, nputils try: + import dask import dask.array as da + # Note: dask has used `ghost` before 0.18.2 + if LooseVersion(dask.__version__) <= LooseVersion('0.18.2'): + overlap = da.ghost.ghost + trim_internal = da.ghost.trim_internal + else: + overlap = da.overlap.overlap + trim_internal = da.overlap.trim_internal except ImportError: pass @@ -15,26 +24,25 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): '''wrapper to apply bottleneck moving window funcs on dask arrays''' dtype, fill_value = dtypes.maybe_promote(a.dtype) a = a.astype(dtype) - # inputs for ghost + # inputs for overlap if axis < 0: axis = a.ndim + axis depth = {d: 0 for d in range(a.ndim)} depth[axis] = (window + 1) // 2 boundary = {d: fill_value for d in range(a.ndim)} - # create ghosted arrays - ag = da.ghost.ghost(a, depth=depth, boundary=boundary) + # Create overlap array. + ag = overlap(a, depth=depth, boundary=boundary) # apply rolling func out = ag.map_blocks(moving_func, window, min_count=min_count, axis=axis, dtype=a.dtype) # trim array - result = da.ghost.trim_internal(out, depth) + result = trim_internal(out, depth) return result def rolling_window(a, axis, window, center, fill_value): """ Dask's equivalence to np.utils.rolling_window """ orig_shape = a.shape - # inputs for ghost if axis < 0: axis = a.ndim + axis depth = {d: 0 for d in range(a.ndim)} @@ -50,7 +58,7 @@ def rolling_window(a, axis, window, center, fill_value): "more evenly divides the shape of your array." % (window, depth[axis], min(a.chunks[axis]))) - # Although dask.ghost pads values to boundaries of the array, + # Although dask.overlap pads values to boundaries of the array, # the size of the generated array is smaller than what we want # if center == False. if center: @@ -60,12 +68,12 @@ def rolling_window(a, axis, window, center, fill_value): start, end = window - 1, 0 pad_size = max(start, end) + offset - depth[axis] drop_size = 0 - # pad_size becomes more than 0 when the ghosted array is smaller than + # pad_size becomes more than 0 when the overlapped array is smaller than # needed. In this case, we need to enlarge the original array by padding - # before ghosting. + # before overlapping. if pad_size > 0: if pad_size < depth[axis]: - # Ghosting requires each chunk larger than depth. If pad_size is + # overlapping requires each chunk larger than depth. If pad_size is # smaller than the depth, we enlarge this and truncate it later. 
drop_size = depth[axis] - pad_size pad_size = depth[axis] @@ -78,8 +86,8 @@ def rolling_window(a, axis, window, center, fill_value): boundary = {d: fill_value for d in range(a.ndim)} - # create ghosted arrays - ag = da.ghost.ghost(a, depth=depth, boundary=boundary) + # create overlap arrays + ag = overlap(a, depth=depth, boundary=boundary) # apply rolling func def func(x, window, axis=-1): diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 35def72c64a..17af3cf2cd1 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -16,10 +16,11 @@ assert_coordinate_consistent, remap_label_indexers) from .dataset import Dataset, merge_indexes, split_indexes from .formatting import format_item -from .options import OPTIONS +from .options import OPTIONS, _get_keep_attrs from .pycompat import OrderedDict, basestring, iteritems, range, zip from .utils import ( - decode_numpy_dict_values, either_dict_or_kwargs, ensure_us_time_resolution) + _check_inplace, decode_numpy_dict_values, either_dict_or_kwargs, + ensure_us_time_resolution) from .variable import ( IndexVariable, Variable, as_compatible_data, as_variable, assert_unique_multiindex_level_names) @@ -503,11 +504,7 @@ def _item_sources(self): LevelCoordinatesSource(self)] def __contains__(self, key): - warnings.warn( - 'xarray.DataArray.__contains__ currently checks membership in ' - 'DataArray.coords, but in xarray v0.11 will change to check ' - 'membership in array values.', FutureWarning, stacklevel=2) - return key in self._coords + return key in self.data @property def loc(self): @@ -546,7 +543,7 @@ def coords(self): """ return DataArrayCoordinates(self) - def reset_coords(self, names=None, drop=False, inplace=False): + def reset_coords(self, names=None, drop=False, inplace=None): """Given names of coordinates, reset them to become variables. Parameters @@ -565,6 +562,7 @@ def reset_coords(self, names=None, drop=False, inplace=False): ------- Dataset, or DataArray if ``drop == True`` """ + inplace = _check_inplace(inplace) if inplace and not drop: raise ValueError('cannot reset coordinates in-place on a ' 'DataArray without ``drop == True``') @@ -677,14 +675,77 @@ def persist(self, **kwargs): ds = self._to_temp_dataset().persist(**kwargs) return self._from_temp_dataset(ds) - def copy(self, deep=True): + def copy(self, deep=True, data=None): """Returns a copy of this array. - If `deep=True`, a deep copy is made of all variables in the underlying - dataset. Otherwise, a shallow copy is made, so each variable in the new + If `deep=True`, a deep copy is made of the data array. + Otherwise, a shallow copy is made, so each variable in the new array's dataset is also a variable in this array's dataset. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, optional + Whether the data array and its coordinates are loaded into memory + and copied onto the new object. Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + When `data` is used, `deep` is ignored for all data variables, + and only used for coords. + + Returns + ------- + object : DataArray + New object with dimensions, attributes, coordinates, name, + encoding, and optionally data copied from original. + + Examples + -------- + + Shallow versus deep copy + + >>> array = xr.DataArray([1, 2, 3], dims='x', + ... 
coords={'x': ['a', 'b', 'c']}) + >>> array.copy() + + array([1, 2, 3]) + Coordinates: + * x (x) >> array_0 = array.copy(deep=False) + >>> array_0[0] = 7 + >>> array_0 + + array([7, 2, 3]) + Coordinates: + * x (x) >> array + + array([7, 2, 3]) + Coordinates: + * x (x) >> array.copy(data=[0.1, 0.2, 0.3]) + + array([ 0.1, 0.2, 0.3]) + Coordinates: + * x (x) >> array + + array([1, 2, 3]) + Coordinates: + * x (x) >> da = xr.DataArray([1, 3], [('x', np.arange(2))]) + >>> da.interp(x=0.5) + + array(2.0) + Coordinates: + x float64 0.5 """ if self.dtype.kind not in 'uifc': raise TypeError('interp only works for a numeric type array. ' @@ -1083,22 +1153,26 @@ def expand_dims(self, dim, axis=None): ds = self._to_temp_dataset().expand_dims(dim, axis) return self._from_temp_dataset(ds) - def set_index(self, append=False, inplace=False, **indexes): + def set_index(self, indexes=None, append=False, inplace=None, + **indexes_kwargs): """Set DataArray (multi-)indexes using one or more existing coordinates. Parameters ---------- + indexes : {dim: index, ...} + Mapping from names matching dimensions and values given + by (lists of) the names of existing coordinates or variables to set + as new (multi-)index. append : bool, optional If True, append the supplied index(es) to the existing index(es). Otherwise replace the existing index(es) (default). inplace : bool, optional If True, set new index(es) in-place. Otherwise, return a new DataArray object. - **indexes : {dim: index, ...} - Keyword arguments with names matching dimensions and values given - by (lists of) the names of existing coordinates or variables to set - as new (multi-)index. + **indexes_kwargs: optional + The keyword arguments form of ``indexes``. + One of indexes or indexes_kwargs must be provided. Returns ------- @@ -1109,13 +1183,15 @@ def set_index(self, append=False, inplace=False, **indexes): -------- DataArray.reset_index """ + inplace = _check_inplace(inplace) + indexes = either_dict_or_kwargs(indexes, indexes_kwargs, 'set_index') coords, _ = merge_indexes(indexes, self._coords, set(), append=append) if inplace: self._coords = coords else: return self._replace(coords=coords) - def reset_index(self, dims_or_levels, drop=False, inplace=False): + def reset_index(self, dims_or_levels, drop=False, inplace=None): """Reset the specified index(es) or multi-index level(s). Parameters @@ -1140,6 +1216,7 @@ def reset_index(self, dims_or_levels, drop=False, inplace=False): -------- DataArray.set_index """ + inplace = _check_inplace(inplace) coords, _ = split_indexes(dims_or_levels, self._coords, set(), self._level_coords, drop=drop) if inplace: @@ -1147,18 +1224,22 @@ def reset_index(self, dims_or_levels, drop=False, inplace=False): else: return self._replace(coords=coords) - def reorder_levels(self, inplace=False, **dim_order): + def reorder_levels(self, dim_order=None, inplace=None, + **dim_order_kwargs): """Rearrange index levels using input order. Parameters ---------- + dim_order : optional + Mapping from names matching dimensions and values given + by lists representing new level orders. Every given dimension + must have a multi-index. inplace : bool, optional If True, modify the dataarray in-place. Otherwise, return a new DataArray object. - **dim_order : optional - Keyword arguments with names matching dimensions and values given - by lists representing new level orders. Every given dimension - must have a multi-index. + **dim_order_kwargs: optional + The keyword arguments form of ``dim_order``. 
+ One of dim_order or dim_order_kwargs must be provided. Returns ------- @@ -1166,6 +1247,9 @@ def reorder_levels(self, inplace=False, **dim_order): Another dataarray, with this dataarray's data but replaced coordinates. """ + inplace = _check_inplace(inplace) + dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, + 'reorder_levels') replace_coords = {} for dim, order in dim_order.items(): coord = self._coords[dim] @@ -1181,7 +1265,7 @@ def reorder_levels(self, inplace=False, **dim_order): else: return self._replace(coords=coords) - def stack(self, **dimensions): + def stack(self, dimensions=None, **dimensions_kwargs): """ Stack any number of existing dimensions into a single new dimension. @@ -1190,9 +1274,12 @@ def stack(self, **dimensions): Parameters ---------- - **dimensions : keyword arguments of the form new_name=(dim1, dim2, ...) + dimensions : Mapping of the form new_name=(dim1, dim2, ...) Names of new dimensions, and the existing dimensions that they replace. + **dimensions_kwargs: + The keyword arguments form of ``dimensions``. + One of dimensions or dimensions_kwargs must be provided. Returns ------- @@ -1221,26 +1308,48 @@ def stack(self, **dimensions): -------- DataArray.unstack """ - ds = self._to_temp_dataset().stack(**dimensions) + ds = self._to_temp_dataset().stack(dimensions, **dimensions_kwargs) return self._from_temp_dataset(ds) - def unstack(self, dim): + def unstack(self, dim=None): """ - Unstack an existing dimension corresponding to a MultiIndex into + Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions. New dimensions will be added at the end. Parameters ---------- - dim : str - Name of the existing dimension to unstack. + dim : str or sequence of str, optional + Dimension(s) over which to unstack. By default unstacks all + MultiIndexes. Returns ------- unstacked : DataArray Array with unstacked data. + Examples + -------- + + >>> arr = DataArray(np.arange(6).reshape(2, 3), + ... coords=[('x', ['a', 'b']), ('y', [0, 1, 2])]) + >>> arr + + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * x (x) |S1 'a' 'b' + * y (y) int64 0 1 2 + >>> stacked = arr.stack(z=('x', 'y')) + >>> stacked.indexes['z'] + MultiIndex(levels=[[u'a', u'b'], [0, 1, 2]], + labels=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], + names=[u'x', u'y']) + >>> roundtripped = stacked.unstack() + >>> arr.identical(roundtripped) + True + See also -------- DataArray.stack @@ -1451,7 +1560,7 @@ def combine_first(self, other): """ return ops.fillna(self, other, join="outer") - def reduce(self, func, dim=None, axis=None, keep_attrs=False, **kwargs): + def reduce(self, func, dim=None, axis=None, keep_attrs=None, **kwargs): """Reduce this array by applying `func` along some dimension(s). Parameters @@ -1480,6 +1589,7 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False, **kwargs): DataArray with this object's array replaced with an array with summarized data and the indicated dimension(s) removed. 
""" + var = self.variable.reduce(func, dim, axis, keep_attrs, **kwargs) return self._replace_maybe_drop_dims(var) @@ -1847,7 +1957,7 @@ def _binary_op(f, reflexive=False, join=None, **ignored_kwargs): def func(self, other): if isinstance(other, (Dataset, groupby.GroupBy)): return NotImplemented - if hasattr(other, 'indexes'): + if isinstance(other, DataArray): align_type = (OPTIONS['arithmetic_join'] if join is None else join) self, other = align(self, other, join=align_type, copy=False) @@ -1965,11 +2075,14 @@ def diff(self, dim, n=1, label='upper'): Coordinates: * x (x) int64 3 4 + See Also + -------- + DataArray.differentiate """ ds = self._to_temp_dataset().diff(n=n, dim=dim, label=label) return self._from_temp_dataset(ds) - def shift(self, **shifts): + def shift(self, shifts=None, **shifts_kwargs): """Shift this array by an offset along one or more dimensions. Only the data is moved; coordinates stay in place. Values shifted from @@ -1978,10 +2091,13 @@ def shift(self, **shifts): Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : Mapping with the form of {dim: offset} Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. + **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. Returns ------- @@ -2003,17 +2119,23 @@ def shift(self, **shifts): Coordinates: * x (x) int64 0 1 2 """ - variable = self.variable.shift(**shifts) - return self._replace(variable) + ds = self._to_temp_dataset().shift(shifts=shifts, **shifts_kwargs) + return self._from_temp_dataset(ds) - def roll(self, **shifts): + def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): """Roll this array by an offset along one or more dimensions. - Unlike shift, roll rotates all variables, including coordinates. The - direction of rotation is consistent with :py:func:`numpy.roll`. + Unlike shift, roll may rotate all variables, including coordinates + if specified. The direction of rotation is consistent with + :py:func:`numpy.roll`. Parameters ---------- + roll_coords : bool + Indicates whether to roll the coordinates by the offset + The current default of roll_coords (None, equivalent to True) is + deprecated and will change to False in a future version. + Explicitly pass roll_coords to silence the warning. **shifts : keyword arguments of the form {dim: offset} Integer offset to rotate each of the given dimensions. Positive offsets roll to the right; negative offsets roll to the left. @@ -2037,7 +2159,8 @@ def roll(self, **shifts): Coordinates: * x (x) int64 2 0 1 """ - ds = self._to_temp_dataset().roll(**shifts) + ds = self._to_temp_dataset().roll( + shifts=shifts, roll_coords=roll_coords, **shifts_kwargs) return self._from_temp_dataset(ds) @property @@ -2149,7 +2272,7 @@ def sortby(self, variables, ascending=True): ds = self._to_temp_dataset().sortby(variables, ascending=ascending) return self._from_temp_dataset(ds) - def quantile(self, q, dim=None, interpolation='linear', keep_attrs=False): + def quantile(self, q, dim=None, interpolation='linear', keep_attrs=None): """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements. 
@@ -2195,7 +2318,7 @@ def quantile(self, q, dim=None, interpolation='linear', keep_attrs=False): q, dim=dim, keep_attrs=keep_attrs, interpolation=interpolation) return self._from_temp_dataset(ds) - def rank(self, dim, pct=False, keep_attrs=False): + def rank(self, dim, pct=False, keep_attrs=None): """Ranks the data. Equal values are assigned a rank that is the average of the ranks that @@ -2231,9 +2354,65 @@ def rank(self, dim, pct=False, keep_attrs=False): array([ 1., 2., 3.]) Dimensions without coordinates: x """ + ds = self._to_temp_dataset().rank(dim, pct=pct, keep_attrs=keep_attrs) return self._from_temp_dataset(ds) + def differentiate(self, coord, edge_order=1, datetime_unit=None): + """ Differentiate the array with the second order accurate central + differences. + + .. note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. + + Parameters + ---------- + coord: str + The coordinate to be used to compute the gradient. + edge_order: 1 or 2. Default 1 + N-th order accurate differences at the boundaries. + datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', + 'us', 'ns', 'ps', 'fs', 'as'} + Unit to compute gradient. Only valid for datetime coordinate. + + Returns + ------- + differentiated: DataArray + + See also + -------- + numpy.gradient: corresponding numpy function + + Examples + -------- + + >>> da = xr.DataArray(np.arange(12).reshape(4, 3), dims=['x', 'y'], + ... coords={'x': [0, 0.1, 1.1, 1.2]}) + >>> da + + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * x (x) float64 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + >>> + >>> da.differentiate('x') + + array([[30. , 30. , 30. ], + [27.545455, 27.545455, 27.545455], + [27.545455, 27.545455, 27.545455], + [30. , 30. , 30. ]]) + Coordinates: + * x (x) float64 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + """ + ds = self._to_temp_dataset().differentiate( + coord, edge_order, datetime_unit) + return self._from_temp_dataset(ds) + # priority most be higher than Variable to properly work with binary ufuncs ops.inject_all_ops_and_reduce_methods(DataArray, priority=60) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 407e4b4f11e..1fd710f9552 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -13,12 +13,14 @@ import xarray as xr from . import ( - alignment, duck_array_ops, formatting, groupby, indexing, ops, resample, - rolling, utils) + alignment, computation, duck_array_ops, formatting, groupby, indexing, ops, + resample, rolling, utils) from .. 
import conventions +from ..coding.cftimeindex import _parse_array_of_cftime_strings from .alignment import align from .common import ( - DataWithCoords, ImplementsDatasetReduce, _contains_datetime_like_objects) + ALL_DIMS, DataWithCoords, ImplementsDatasetReduce, + _contains_datetime_like_objects) from .coordinates import ( DatasetCoordinates, Indexes, LevelCoordinatesSource, assert_coordinate_consistent, remap_label_indexers) @@ -26,12 +28,13 @@ from .merge import ( dataset_merge_method, dataset_update_method, merge_data_and_coords, merge_variables) -from .options import OPTIONS +from .options import OPTIONS, _get_keep_attrs from .pycompat import ( OrderedDict, basestring, dask_array_type, integer_types, iteritems, range) from .utils import ( - Frozen, SortedKeysDict, either_dict_or_kwargs, decode_numpy_dict_values, - ensure_us_time_resolution, hashable, maybe_wrap_array) + _check_inplace, Frozen, SortedKeysDict, datetime_to_numeric, + decode_numpy_dict_values, either_dict_or_kwargs, ensure_us_time_resolution, + hashable, maybe_wrap_array) from .variable import IndexVariable, Variable, as_variable, broadcast_variables from ..plot.plot import _Dataset_PlotMethods @@ -710,16 +713,120 @@ def _replace_indexes(self, indexes): obj = obj.rename(dim_names) return obj - def copy(self, deep=False): + def copy(self, deep=False, data=None): """Returns a copy of this dataset. If `deep=True`, a deep copy is made of each of the component variables. Otherwise, a shallow copy of each of the component variable is made, so that the underlying memory region of the new dataset is the same as in the original dataset. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, optional + Whether each component variable is loaded into memory and copied onto + the new object. Default is False. + data : dict-like, optional + Data to use in the new object. Each item in `data` must have same + shape as corresponding data variable in original. When `data` is + used, `deep` is ignored for the data variables and only used for + coords. + + Returns + ------- + object : Dataset + New object with dimensions, attributes, coordinates, name, encoding, + and optionally data copied from original. 
+ + Examples + -------- + + Shallow copy versus deep copy + + >>> da = xr.DataArray(np.random.randn(2, 3)) + >>> ds = xr.Dataset({'foo': da, 'bar': ('x', [-1, 2])}, + coords={'x': ['one', 'two']}) + >>> ds.copy() + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds_0 = ds.copy(deep=False) + >>> ds_0['foo'][0, 0] = 7 + >>> ds_0 + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds.copy(data={'foo': np.arange(6).reshape(2, 3), 'bar': ['a', 'b']}) + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) = n_desired) stop = int(np.ceil(float(n_desired) / np.r_[1, cum_items][n_steps])) - indexer = ((0,) * (len(shape) - 1 - n_steps) + - (slice(stop),) + + indexer = (((-1 if from_end else 0),) * (len(shape) - 1 - n_steps) + + ((slice(-stop, None) if from_end else slice(stop)),) + (slice(None),) * n_steps) return indexer @@ -89,11 +90,28 @@ def first_n_items(array, n_desired): return [] if n_desired < array.size: - indexer = _get_indexer_at_least_n_items(array.shape, n_desired) + indexer = _get_indexer_at_least_n_items(array.shape, n_desired, + from_end=False) array = array[indexer] return np.asarray(array).flat[:n_desired] +def last_n_items(array, n_desired): + """Returns the last n_desired items of an array""" + # Unfortunately, we can't just do array.flat[-n_desired:] here because it + # might not be a numpy.ndarray. Moreover, access to elements of the array + # could be very expensive (e.g. if it's only available over DAP), so go out + # of our way to get them in a single call to __getitem__ using only slices. + if (n_desired == 0) or (array.size == 0): + return [] + + if n_desired < array.size: + indexer = _get_indexer_at_least_n_items(array.shape, n_desired, + from_end=True) + array = array[indexer] + return np.asarray(array).flat[-n_desired:] + + def last_item(array): """Returns the last item of an array in a list or an empty list.""" if array.size == 0: @@ -164,7 +182,7 @@ def format_items(x): day_part = (x[~pd.isnull(x)] .astype('timedelta64[D]') .astype('timedelta64[ns]')) - time_needed = x != day_part + time_needed = x[~pd.isnull(x)] != day_part day_needed = day_part != np.timedelta64(0, 'ns') if np.logical_not(day_needed).all(): timedelta_format = 'time' @@ -180,20 +198,36 @@ def format_array_flat(array, max_width): array that will fit within max_width characters. """ # every item will take up at least two characters, but we always want to - # print at least one item - max_possibly_relevant = max(int(np.ceil(max_width / 2.0)), 1) - relevant_items = first_n_items(array, max_possibly_relevant) - pprint_items = format_items(relevant_items) - - cum_len = np.cumsum([len(s) + 1 for s in pprint_items]) - 1 - if (max_possibly_relevant < array.size or (cum_len > max_width).any()): - end_padding = u' ...' 
- count = max(np.argmax((cum_len + len(end_padding)) > max_width), 1) - pprint_items = pprint_items[:count] + # print at least first and last items + max_possibly_relevant = min(max(array.size, 1), + max(int(np.ceil(max_width / 2.)), 2)) + relevant_front_items = format_items( + first_n_items(array, (max_possibly_relevant + 1) // 2)) + relevant_back_items = format_items( + last_n_items(array, max_possibly_relevant // 2)) + # interleave relevant front and back items: + # [a, b, c] and [y, z] -> [a, z, b, y, c] + relevant_items = sum(zip_longest(relevant_front_items, + reversed(relevant_back_items)), + ())[:max_possibly_relevant] + + cum_len = np.cumsum([len(s) + 1 for s in relevant_items]) - 1 + if (array.size > 2) and ((max_possibly_relevant < array.size) or + (cum_len > max_width).any()): + padding = u' ... ' + count = min(array.size, + max(np.argmax(cum_len + len(padding) - 1 > max_width), 2)) else: - end_padding = u'' - - pprint_str = u' '.join(pprint_items) + end_padding + count = array.size + padding = u'' if (count <= 1) else u' ' + + num_front = (count + 1) // 2 + num_back = count - num_front + # note that num_back is 0 <--> array.size is 0 or 1 + # <--> relevant_back_items is [] + pprint_str = (u' '.join(relevant_front_items[:num_front]) + + padding + + u' '.join(relevant_back_items[-num_back:])) return pprint_str diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 7068f8e6cae..defe72ab3ee 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1,17 +1,19 @@ from __future__ import absolute_import, division, print_function import functools +import warnings import numpy as np import pandas as pd -from . import dtypes, duck_array_ops, nputils, ops +from . import dtypes, duck_array_ops, nputils, ops, utils from .arithmetic import SupportsArithmetic from .combine import concat -from .common import ImplementsArrayReduce, ImplementsDatasetReduce +from .common import ALL_DIMS, ImplementsArrayReduce, ImplementsDatasetReduce from .pycompat import integer_types, range, zip from .utils import hashable, maybe_wrap_array, peek_at, safe_cast_to_index from .variable import IndexVariable, Variable, as_variable +from .options import _get_keep_attrs def unique_value_groups(ar, sort=True): @@ -403,15 +405,17 @@ def _first_or_last(self, op, skipna, keep_attrs): # NB. 
this is currently only used for reductions along an existing # dimension return self._obj + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) return self.reduce(op, self._group_dim, skipna=skipna, keep_attrs=keep_attrs, allow_lazy=True) - def first(self, skipna=None, keep_attrs=True): + def first(self, skipna=None, keep_attrs=None): """Return the first element of each group along the group dimension """ return self._first_or_last(duck_array_ops.first, skipna, keep_attrs) - def last(self, skipna=None, keep_attrs=True): + def last(self, skipna=None, keep_attrs=None): """Return the last element of each group along the group dimension """ return self._first_or_last(duck_array_ops.last, skipna, keep_attrs) @@ -422,6 +426,7 @@ def assign_coords(self, **kwargs): See also -------- Dataset.assign_coords + Dataset.swap_dims """ return self.apply(lambda ds: ds.assign_coords(**kwargs)) @@ -537,8 +542,8 @@ def _combine(self, applied, shortcut=False): combined = self._maybe_unstack(combined) return combined - def reduce(self, func, dim=None, axis=None, keep_attrs=False, - shortcut=True, **kwargs): + def reduce(self, func, dim=None, axis=None, + keep_attrs=None, shortcut=True, **kwargs): """Reduce the items in this group by applying `func` along some dimension(s). @@ -567,10 +572,42 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False, Array with summarized data and the indicated dimension(s) removed. """ + if dim == DEFAULT_DIMS: + dim = ALL_DIMS + # TODO change this to dim = self._group_dim after + # the deprecation process + if self._obj.ndim > 1: + warnings.warn( + "Default reduction dimension will be changed to the " + "grouped dimension after xarray 0.12. To silence this " + "warning, pass dim=xarray.ALL_DIMS explicitly.", + FutureWarning, stacklevel=2) + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + def reduce_array(ar): return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs) return self.apply(reduce_array, shortcut=shortcut) + # TODO remove the following class method and DEFAULT_DIMS after the + # deprecation cycle + @classmethod + def _reduce_method(cls, func, include_skipna, numeric_only): + if include_skipna: + def wrapped_func(self, dim=DEFAULT_DIMS, axis=None, skipna=None, + keep_attrs=None, **kwargs): + return self.reduce(func, dim, axis, keep_attrs=keep_attrs, + skipna=skipna, allow_lazy=True, **kwargs) + else: + def wrapped_func(self, dim=DEFAULT_DIMS, axis=None, + keep_attrs=None, **kwargs): + return self.reduce(func, dim, axis, keep_attrs=keep_attrs, + allow_lazy=True, **kwargs) + return wrapped_func + + +DEFAULT_DIMS = utils.ReprObject('') ops.inject_reduce_methods(DataArrayGroupBy) ops.inject_binary_ops(DataArrayGroupBy) @@ -620,7 +657,7 @@ def _combine(self, applied): combined = self._maybe_unstack(combined) return combined - def reduce(self, func, dim=None, keep_attrs=False, **kwargs): + def reduce(self, func, dim=None, keep_attrs=None, **kwargs): """Reduce the items in this group by applying `func` along some dimension(s). @@ -649,10 +686,43 @@ def reduce(self, func, dim=None, keep_attrs=False, **kwargs): Array with summarized data and the indicated dimension(s) removed. """ + if dim == DEFAULT_DIMS: + dim = ALL_DIMS + # TODO change this to dim = self._group_dim after + # the deprecation process. Do not forget to remove _reduce_method + warnings.warn( + "Default reduction dimension will be changed to the " + "grouped dimension after xarray 0.12. 
To silence this " + "warning, pass dim=xarray.ALL_DIMS explicitly.", + FutureWarning, stacklevel=2) + elif dim is None: + dim = self._group_dim + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + def reduce_dataset(ds): return ds.reduce(func, dim, keep_attrs, **kwargs) return self.apply(reduce_dataset) + # TODO remove the following class method and DEFAULT_DIMS after the + # deprecation cycle + @classmethod + def _reduce_method(cls, func, include_skipna, numeric_only): + if include_skipna: + def wrapped_func(self, dim=DEFAULT_DIMS, + skipna=None, **kwargs): + return self.reduce(func, dim, + skipna=skipna, numeric_only=numeric_only, + allow_lazy=True, **kwargs) + else: + def wrapped_func(self, dim=DEFAULT_DIMS, + **kwargs): + return self.reduce(func, dim, + numeric_only=numeric_only, allow_lazy=True, + **kwargs) + return wrapped_func + def assign(self, **kwargs): """Assign data variables by group. diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 2c1f08379ab..d51da471c8d 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -749,6 +749,37 @@ class IndexingSupport(object): # could inherit from enum.Enum on Python 3 VECTORIZED = 'VECTORIZED' +def explicit_indexing_adapter( + key, shape, indexing_support, raw_indexing_method): + """Support explicit indexing by delegating to a raw indexing method. + + Outer and/or vectorized indexers are supported by indexing a second time + with a NumPy array. + + Parameters + ---------- + key : ExplicitIndexer + Explicit indexing object. + shape : Tuple[int, ...] + Shape of the indexed array. + indexing_support : IndexingSupport enum + Form of indexing supported by raw_indexing_method. + raw_indexing_method: callable + Function (like ndarray.__getitem__) that when called with indexing key + in the form of a tuple returns an indexed array. + + Returns + ------- + Indexing result, in the form of a duck numpy-array. + """ + raw_key, numpy_indices = decompose_indexer(key, shape, indexing_support) + result = raw_indexing_method(raw_key.tuple) + if numpy_indices.tuple: + # index the loaded np.ndarray + result = NumpyIndexingAdapter(np.asarray(result))[numpy_indices] + return result + + def decompose_indexer(indexer, shape, indexing_support): if isinstance(indexer, VectorizedIndexer): return _decompose_vectorized_indexer(indexer, shape, indexing_support) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index f823717a8af..984dd2fa204 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -190,10 +190,13 @@ def expand_variable_dicts(list_of_variable_dicts): an input's values. The values of each ordered dictionary are all xarray.Variable objects. """ + from .dataarray import DataArray + from .dataset import Dataset + var_dicts = [] for variables in list_of_variable_dicts: - if hasattr(variables, 'variables'): # duck-type Dataset + if isinstance(variables, Dataset): sanitized_vars = variables.variables else: # append coords to var_dicts before appending sanitized_vars, @@ -201,7 +204,7 @@ def expand_variable_dicts(list_of_variable_dicts): sanitized_vars = OrderedDict() for name, var in variables.items(): - if hasattr(var, '_coords'): # duck-type DataArray + if isinstance(var, DataArray): # use private API for speed coords = var._coords.copy() # explicitly overwritten variables should take precedence @@ -232,17 +235,19 @@ def determine_coords(list_of_variable_dicts): All variable found in the input should appear in either the set of coordinate or non-coordinate names. 
""" + from .dataarray import DataArray + from .dataset import Dataset + coord_names = set() noncoord_names = set() for variables in list_of_variable_dicts: - if hasattr(variables, 'coords') and hasattr(variables, 'data_vars'): - # duck-type Dataset + if isinstance(variables, Dataset): coord_names.update(variables.coords) noncoord_names.update(variables.data_vars) else: for name, var in variables.items(): - if hasattr(var, '_coords'): # duck-type DataArray + if isinstance(var, DataArray): coords = set(var._coords) # use private API for speed # explicitly overwritten variables should take precedence coords.discard(name) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index bec9e2e1931..3f4e0fc3ac9 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function +import warnings from collections import Iterable from functools import partial @@ -7,11 +8,12 @@ import pandas as pd from . import rolling +from .common import _contains_datetime_like_objects from .computation import apply_ufunc +from .duck_array_ops import dask_array_type from .pycompat import iteritems -from .utils import is_scalar, OrderedSet +from .utils import OrderedSet, datetime_to_numeric, is_scalar from .variable import Variable, broadcast_variables -from .duck_array_ops import dask_array_type class BaseInterpolator(object): @@ -57,7 +59,7 @@ def __init__(self, xi, yi, method='linear', fill_value=None, **kwargs): if self.cons_kwargs: raise ValueError( - 'recieved invalid kwargs: %r' % self.cons_kwargs.keys()) + 'received invalid kwargs: %r' % self.cons_kwargs.keys()) if fill_value is None: self._left = np.nan @@ -207,13 +209,16 @@ def interp_na(self, dim=None, use_coordinate=True, method='linear', limit=None, interp_class, kwargs = _get_interpolator(method, **kwargs) interpolator = partial(func_interpolate_na, interp_class, **kwargs) - arr = apply_ufunc(interpolator, index, self, - input_core_dims=[[dim], [dim]], - output_core_dims=[[dim]], - output_dtypes=[self.dtype], - dask='parallelized', - vectorize=True, - keep_attrs=True).transpose(*self.dims) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'overflow', RuntimeWarning) + warnings.filterwarnings('ignore', 'invalid value', RuntimeWarning) + arr = apply_ufunc(interpolator, index, self, + input_core_dims=[[dim], [dim]], + output_core_dims=[[dim]], + output_dtypes=[self.dtype], + dask='parallelized', + vectorize=True, + keep_attrs=True).transpose(*self.dims) if limit is not None: arr = arr.where(valids) @@ -402,15 +407,16 @@ def _floatize_x(x, new_x): x = list(x) new_x = list(new_x) for i in range(len(x)): - if x[i].dtype.kind in 'Mm': + if _contains_datetime_like_objects(x[i]): # Scipy casts coordinates to np.float64, which is not accurate # enough for datetime64 (uses 64bit integer). # We assume that the most of the bits are used to represent the # offset (min(x)) and the variation (x - min(x)) can be # represented by float. 
- xmin = np.min(x[i]) - x[i] = (x[i] - xmin).astype(np.float64) - new_x[i] = (new_x[i] - xmin).astype(np.float64) + xmin = x[i].min() + x[i] = datetime_to_numeric(x[i], offset=xmin, dtype=np.float64) + new_x[i] = datetime_to_numeric( + new_x[i], offset=xmin, dtype=np.float64) return x, new_x diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py new file mode 100644 index 00000000000..4d3f03c899e --- /dev/null +++ b/xarray/core/nanops.py @@ -0,0 +1,207 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + +from . import dtypes, nputils +from .duck_array_ops import ( + _dask_or_eager_func, count, fillna, isnull, where_method) +from .pycompat import dask_array_type + +try: + import dask.array as dask_array +except ImportError: + dask_array = None + + +def _replace_nan(a, val): + """ + replace nan in a by val, and returns the replaced array and the nan + position + """ + mask = isnull(a) + return where_method(val, mask, a), mask + + +def _maybe_null_out(result, axis, mask, min_count=1): + """ + xarray version of pandas.core.nanops._maybe_null_out + """ + if hasattr(axis, '__len__'): # if tuple or list + raise ValueError('min_count is not available for reduction ' + 'with more than one dimensions.') + + if axis is not None and getattr(result, 'ndim', False): + null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0 + if null_mask.any(): + dtype, fill_value = dtypes.maybe_promote(result.dtype) + result = result.astype(dtype) + result[null_mask] = fill_value + + elif getattr(result, 'dtype', None) not in dtypes.NAT_TYPES: + null_mask = mask.size - mask.sum() + if null_mask < min_count: + result = np.nan + + return result + + +def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs): + """ In house nanargmin, nanargmax for object arrays. Always return integer + type + """ + valid_count = count(value, axis=axis) + value = fillna(value, fill_value) + data = _dask_or_eager_func(func)(value, axis=axis, **kwargs) + + # TODO This will evaluate dask arrays and might be costly. 
+ if (valid_count == 0).any(): + raise ValueError('All-NaN slice encountered') + + return data + + +def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs): + """ In house nanmin and nanmax for object array """ + valid_count = count(value, axis=axis) + filled_value = fillna(value, fill_value) + data = getattr(np, func)(filled_value, axis=axis, **kwargs) + if not hasattr(data, 'dtype'): # scalar case + data = dtypes.fill_value(value.dtype) if valid_count == 0 else data + return np.array(data, dtype=value.dtype) + return where_method(data, valid_count != 0) + + +def nanmin(a, axis=None, out=None): + if a.dtype.kind == 'O': + return _nan_minmax_object( + 'min', dtypes.get_pos_infinity(a.dtype), a, axis) + + module = dask_array if isinstance(a, dask_array_type) else nputils + return module.nanmin(a, axis=axis) + + +def nanmax(a, axis=None, out=None): + if a.dtype.kind == 'O': + return _nan_minmax_object( + 'max', dtypes.get_neg_infinity(a.dtype), a, axis) + + module = dask_array if isinstance(a, dask_array_type) else nputils + return module.nanmax(a, axis=axis) + + +def nanargmin(a, axis=None): + fill_value = dtypes.get_pos_infinity(a.dtype) + if a.dtype.kind == 'O': + return _nan_argminmax_object('argmin', fill_value, a, axis=axis) + a, mask = _replace_nan(a, fill_value) + if isinstance(a, dask_array_type): + res = dask_array.argmin(a, axis=axis) + else: + res = np.argmin(a, axis=axis) + + if mask is not None: + mask = mask.all(axis=axis) + if mask.any(): + raise ValueError("All-NaN slice encountered") + return res + + +def nanargmax(a, axis=None): + fill_value = dtypes.get_neg_infinity(a.dtype) + if a.dtype.kind == 'O': + return _nan_argminmax_object('argmax', fill_value, a, axis=axis) + + a, mask = _replace_nan(a, fill_value) + if isinstance(a, dask_array_type): + res = dask_array.argmax(a, axis=axis) + else: + res = np.argmax(a, axis=axis) + + if mask is not None: + mask = mask.all(axis=axis) + if mask.any(): + raise ValueError("All-NaN slice encountered") + return res + + +def nansum(a, axis=None, dtype=None, out=None, min_count=None): + a, mask = _replace_nan(a, 0) + result = _dask_or_eager_func('sum')(a, axis=axis, dtype=dtype) + if min_count is not None: + return _maybe_null_out(result, axis, mask, min_count) + else: + return result + + +def _nanmean_ddof_object(ddof, value, axis=None, **kwargs): + """ In house nanmean. 
ddof argument will be used in _nanvar method """ + from .duck_array_ops import (count, fillna, _dask_or_eager_func, + where_method) + + valid_count = count(value, axis=axis) + value = fillna(value, 0) + # As dtype inference is impossible for object dtype, we assume float + # https://github.com/dask/dask/issues/3162 + dtype = kwargs.pop('dtype', None) + if dtype is None and value.dtype.kind == 'O': + dtype = value.dtype if value.dtype.kind in ['cf'] else float + + data = _dask_or_eager_func('sum')(value, axis=axis, dtype=dtype, **kwargs) + data = data / (valid_count - ddof) + return where_method(data, valid_count != 0) + + +def nanmean(a, axis=None, dtype=None, out=None): + if a.dtype.kind == 'O': + return _nanmean_ddof_object(0, a, axis=axis, dtype=dtype) + + if isinstance(a, dask_array_type): + return dask_array.nanmean(a, axis=axis, dtype=dtype) + + return np.nanmean(a, axis=axis, dtype=dtype) + + +def nanmedian(a, axis=None, out=None): + return _dask_or_eager_func('nanmedian', eager_module=nputils)(a, axis=axis) + + +def _nanvar_object(value, axis=None, **kwargs): + ddof = kwargs.pop('ddof', 0) + kwargs_mean = kwargs.copy() + kwargs_mean.pop('keepdims', None) + value_mean = _nanmean_ddof_object(ddof=0, value=value, axis=axis, + keepdims=True, **kwargs_mean) + squared = (value.astype(value_mean.dtype) - value_mean)**2 + return _nanmean_ddof_object(ddof, squared, axis=axis, **kwargs) + + +def nanvar(a, axis=None, dtype=None, out=None, ddof=0): + if a.dtype.kind == 'O': + return _nanvar_object(a, axis=axis, dtype=dtype, ddof=ddof) + + return _dask_or_eager_func('nanvar', eager_module=nputils)( + a, axis=axis, dtype=dtype, ddof=ddof) + + +def nanstd(a, axis=None, dtype=None, out=None, ddof=0): + return _dask_or_eager_func('nanstd', eager_module=nputils)( + a, axis=axis, dtype=dtype, ddof=ddof) + + +def nanprod(a, axis=None, dtype=None, out=None, min_count=None): + a, mask = _replace_nan(a, 1) + result = _dask_or_eager_func('nanprod')(a, axis=axis, dtype=dtype, out=out) + if min_count is not None: + return _maybe_null_out(result, axis, mask, min_count) + else: + return result + + +def nancumsum(a, axis=None, dtype=None, out=None): + return _dask_or_eager_func('nancumsum', eager_module=nputils)( + a, axis=axis, dtype=dtype) + + +def nancumprod(a, axis=None, dtype=None, out=None): + return _dask_or_eager_func('nancumprod', eager_module=nputils)( + a, axis=axis, dtype=dtype) diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py index 6d4db063b98..efa68c8bad5 100644 --- a/xarray/core/npcompat.py +++ b/xarray/core/npcompat.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, division, print_function +from distutils.version import LooseVersion + import numpy as np try: @@ -97,3 +99,187 @@ def isin(element, test_elements, assume_unique=False, invert=False): element = np.asarray(element) return np.in1d(element, test_elements, assume_unique=assume_unique, invert=invert).reshape(element.shape) + + +if LooseVersion(np.__version__) >= LooseVersion('1.13'): + gradient = np.gradient +else: + def normalize_axis_tuple(axes, N): + if isinstance(axes, int): + axes = (axes, ) + return tuple([N + a if a < 0 else a for a in axes]) + + def gradient(f, *varargs, **kwargs): + f = np.asanyarray(f) + N = f.ndim # number of dimensions + + axes = kwargs.pop('axis', None) + if axes is None: + axes = tuple(range(N)) + else: + axes = normalize_axis_tuple(axes, N) + + len_axes = len(axes) + n = len(varargs) + if n == 0: + # no spacing argument - use 1 in all axes + dx = [1.0] * len_axes + elif n == 1 and 
np.ndim(varargs[0]) == 0: + # single scalar for all axes + dx = varargs * len_axes + elif n == len_axes: + # scalar or 1d array for each axis + dx = list(varargs) + for i, distances in enumerate(dx): + if np.ndim(distances) == 0: + continue + elif np.ndim(distances) != 1: + raise ValueError("distances must be either scalars or 1d") + if len(distances) != f.shape[axes[i]]: + raise ValueError("when 1d, distances must match the " + "length of the corresponding dimension") + diffx = np.diff(distances) + # if distances are constant reduce to the scalar case + # since it brings a consistent speedup + if (diffx == diffx[0]).all(): + diffx = diffx[0] + dx[i] = diffx + else: + raise TypeError("invalid number of arguments") + + edge_order = kwargs.pop('edge_order', 1) + if kwargs: + raise TypeError('"{}" are not valid keyword arguments.'.format( + '", "'.join(kwargs.keys()))) + if edge_order > 2: + raise ValueError("'edge_order' greater than 2 not supported") + + # use central differences on interior and one-sided differences on the + # endpoints. This preserves second order-accuracy over the full domain. + + outvals = [] + + # create slice objects --- initially all are [:, :, ..., :] + slice1 = [slice(None)] * N + slice2 = [slice(None)] * N + slice3 = [slice(None)] * N + slice4 = [slice(None)] * N + + otype = f.dtype.char + if otype not in ['f', 'd', 'F', 'D', 'm', 'M']: + otype = 'd' + + # Difference of datetime64 elements results in timedelta64 + if otype == 'M': + # Need to use the full dtype name because it contains unit + # information + otype = f.dtype.name.replace('datetime', 'timedelta') + elif otype == 'm': + # Needs to keep the specific units, can't be a general unit + otype = f.dtype + + # Convert datetime64 data into ints. Make dummy variable `y` + # that is a view of ints if the data is datetime64, otherwise + # just set y equal to the array `f`. + if f.dtype.char in ["M", "m"]: + y = f.view('int64') + else: + y = f + + for i, axis in enumerate(axes): + if y.shape[axis] < edge_order + 1: + raise ValueError( + "Shape of array too small to calculate a numerical " + "gradient, at least (edge_order + 1) elements are " + "required.") + # result allocation + out = np.empty_like(y, dtype=otype) + + uniform_spacing = np.ndim(dx[i]) == 0 + + # Numerical differentiation: 2nd order interior + slice1[axis] = slice(1, -1) + slice2[axis] = slice(None, -2) + slice3[axis] = slice(1, -1) + slice4[axis] = slice(2, None) + + if uniform_spacing: + out[slice1] = (f[slice4] - f[slice2]) / (2. 
* dx[i]) + else: + dx1 = dx[i][0:-1] + dx2 = dx[i][1:] + a = -(dx2) / (dx1 * (dx1 + dx2)) + b = (dx2 - dx1) / (dx1 * dx2) + c = dx1 / (dx2 * (dx1 + dx2)) + # fix the shape for broadcasting + shape = np.ones(N, dtype=int) + shape[axis] = -1 + a.shape = b.shape = c.shape = shape + # 1D equivalent -- + # out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:] + out[slice1] = a * f[slice2] + b * f[slice3] + c * f[slice4] + + # Numerical differentiation: 1st order edges + if edge_order == 1: + slice1[axis] = 0 + slice2[axis] = 1 + slice3[axis] = 0 + dx_0 = dx[i] if uniform_spacing else dx[i][0] + # 1D equivalent -- out[0] = (y[1] - y[0]) / (x[1] - x[0]) + out[slice1] = (y[slice2] - y[slice3]) / dx_0 + + slice1[axis] = -1 + slice2[axis] = -1 + slice3[axis] = -2 + dx_n = dx[i] if uniform_spacing else dx[i][-1] + # 1D equivalent -- out[-1] = (y[-1] - y[-2]) / (x[-1] - x[-2]) + out[slice1] = (y[slice2] - y[slice3]) / dx_n + + # Numerical differentiation: 2nd order edges + else: + slice1[axis] = 0 + slice2[axis] = 0 + slice3[axis] = 1 + slice4[axis] = 2 + if uniform_spacing: + a = -1.5 / dx[i] + b = 2. / dx[i] + c = -0.5 / dx[i] + else: + dx1 = dx[i][0] + dx2 = dx[i][1] + a = -(2. * dx1 + dx2) / (dx1 * (dx1 + dx2)) + b = (dx1 + dx2) / (dx1 * dx2) + c = - dx1 / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[0] = a * y[0] + b * y[1] + c * y[2] + out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4] + + slice1[axis] = -1 + slice2[axis] = -3 + slice3[axis] = -2 + slice4[axis] = -1 + if uniform_spacing: + a = 0.5 / dx[i] + b = -2. / dx[i] + c = 1.5 / dx[i] + else: + dx1 = dx[i][-2] + dx2 = dx[i][-1] + a = (dx2) / (dx1 * (dx1 + dx2)) + b = - (dx2 + dx1) / (dx1 * dx2) + c = (2. * dx2 + dx1) / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1] + out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4] + + outvals.append(out) + + # reset the slice object in this dimension to ":" + slice1[axis] = slice(None) + slice2[axis] = slice(None) + slice3[axis] = slice(None) + slice4[axis] = slice(None) + + if len_axes == 1: + return outvals[0] + else: + return outvals diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 6df2d34bfe3..a8d596abd86 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -5,6 +5,14 @@ import numpy as np import pandas as pd +try: + import bottleneck as bn + _USE_BOTTLENECK = True +except ImportError: + # use numpy methods instead + bn = np + _USE_BOTTLENECK = False + def _validate_axis(data, axis): ndim = data.ndim @@ -195,3 +203,36 @@ def _rolling_window(a, window, axis=-1): rolling = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides, writeable=False) return np.swapaxes(rolling, -2, axis) + + +def _create_bottleneck_method(name, npmodule=np): + def f(values, axis=None, **kwds): + dtype = kwds.get('dtype', None) + bn_func = getattr(bn, name, None) + + if (_USE_BOTTLENECK and bn_func is not None and + not isinstance(axis, tuple) and + values.dtype.kind in 'uifc' and + values.dtype.isnative and + (dtype is None or np.dtype(dtype) == values.dtype)): + # bottleneck does not take care dtype, min_count + kwds.pop('dtype', None) + result = bn_func(values, axis=axis, **kwds) + else: + result = getattr(npmodule, name)(values, axis=axis, **kwds) + + return result + + f.__name__ = name + return f + + +nanmin = _create_bottleneck_method('nanmin') +nanmax = _create_bottleneck_method('nanmax') +nanmean = _create_bottleneck_method('nanmean') +nanmedian = _create_bottleneck_method('nanmedian') +nanvar = 
_create_bottleneck_method('nanvar') +nanstd = _create_bottleneck_method('nanstd') +nanprod = _create_bottleneck_method('nanprod') +nancumsum = _create_bottleneck_method('nancumsum') +nancumprod = _create_bottleneck_method('nancumprod') diff --git a/xarray/core/ops.py b/xarray/core/ops.py index d9e8ceb65d5..a0dd2212a8f 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -86,7 +86,7 @@ If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been - implemented (object, datetime64 or timedelta64). + implemented (object, datetime64 or timedelta64).{min_count_docs} keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -102,6 +102,12 @@ indicated dimension(s) removed. """ +_MINCOUNT_DOCSTRING = """ +min_count : int, default None + The required number of valid values to perform the operation. + If fewer than min_count non-NA values are present the result will + be NA. New in version 0.10.8: Added with the default being None.""" + _ROLLING_REDUCE_DOCSTRING_TEMPLATE = """\ Reduce this {da_or_ds}'s data windows by applying `{name}` along its dimension. @@ -236,11 +242,15 @@ def inject_reduce_methods(cls): [('count', duck_array_ops.count, False)]) for name, f, include_skipna in methods: numeric_only = getattr(f, 'numeric_only', False) + available_min_count = getattr(f, 'available_min_count', False) + min_count_docs = _MINCOUNT_DOCSTRING if available_min_count else '' + func = cls._reduce_method(f, include_skipna, numeric_only) func.__name__ = name func.__doc__ = _REDUCE_DOCSTRING_TEMPLATE.format( name=name, cls=cls.__name__, - extra_args=cls._reduce_extra_args_docstring.format(name=name)) + extra_args=cls._reduce_extra_args_docstring.format(name=name), + min_count_docs=min_count_docs) setattr(cls, name, func) diff --git a/xarray/core/options.py b/xarray/core/options.py index 48d4567fc99..ab461ca86bc 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,12 +1,71 @@ from __future__ import absolute_import, division, print_function +import warnings + +DISPLAY_WIDTH = 'display_width' +ARITHMETIC_JOIN = 'arithmetic_join' +ENABLE_CFTIMEINDEX = 'enable_cftimeindex' +FILE_CACHE_MAXSIZE = 'file_cache_maxsize' +CMAP_SEQUENTIAL = 'cmap_sequential' +CMAP_DIVERGENT = 'cmap_divergent' +KEEP_ATTRS = 'keep_attrs' + + OPTIONS = { - 'display_width': 80, - 'arithmetic_join': 'inner', - 'enable_cftimeindex': False + DISPLAY_WIDTH: 80, + ARITHMETIC_JOIN: 'inner', + ENABLE_CFTIMEINDEX: True, + FILE_CACHE_MAXSIZE: 128, + CMAP_SEQUENTIAL: 'viridis', + CMAP_DIVERGENT: 'RdBu_r', + KEEP_ATTRS: 'default' +} + +_JOIN_OPTIONS = frozenset(['inner', 'outer', 'left', 'right', 'exact']) + + +def _positive_integer(value): + return isinstance(value, int) and value > 0 + + +_VALIDATORS = { + DISPLAY_WIDTH: _positive_integer, + ARITHMETIC_JOIN: _JOIN_OPTIONS.__contains__, + ENABLE_CFTIMEINDEX: lambda value: isinstance(value, bool), + FILE_CACHE_MAXSIZE: _positive_integer, + KEEP_ATTRS: lambda choice: choice in [True, False, 'default'] +} + + +def _set_file_cache_maxsize(value): + from ..backends.file_manager import FILE_CACHE + FILE_CACHE.maxsize = value + + +def _warn_on_setting_enable_cftimeindex(enable_cftimeindex): + warnings.warn( + 'The enable_cftimeindex option is now a no-op ' + 'and will be removed in a future version of xarray.', + FutureWarning) + + +_SETTERS = { 
+ FILE_CACHE_MAXSIZE: _set_file_cache_maxsize, + ENABLE_CFTIMEINDEX: _warn_on_setting_enable_cftimeindex } +def _get_keep_attrs(default): + global_choice = OPTIONS['keep_attrs'] + + if global_choice is 'default': + return default + elif global_choice in [True, False]: + return global_choice + else: + raise ValueError("The global option keep_attrs must be one of True, False or 'default'.") + + class set_options(object): """Set options for xarray in a controlled context. @@ -16,9 +75,21 @@ class set_options(object): Default: ``80``. - ``arithmetic_join``: DataArray/Dataset alignment in binary operations. Default: ``'inner'``. - - ``enable_cftimeindex``: flag to enable using a ``CFTimeIndex`` - for time indexes with non-standard calendars or dates outside the - Timestamp-valid range. Default: ``False``. + - ``file_cache_maxsize``: maximum number of open files to hold in xarray's + global least-recently-usage cached. This should be smaller than your + system's per-process file descriptor limit, e.g., ``ulimit -n`` on Linux. + Default: 128. + - ``cmap_sequential``: colormap to use for nondivergent data plots. + Default: ``viridis``. If string, must be matplotlib built-in colormap. + Can also be a Colormap object (e.g. mpl.cm.magma) + - ``cmap_divergent``: colormap to use for divergent data plots. + Default: ``RdBu_r``. If string, must be matplotlib built-in colormap. + Can also be a Colormap object (e.g. mpl.cm.magma) + - ``keep_attrs``: rule for whether to keep attributes on xarray + Datasets/dataarrays after operations. Either ``True`` to always keep + attrs, ``False`` to always discard them, or ``'default'`` to use original + logic that attrs should only be kept in unambiguous circumstances. + Default: ``'default'``. You can use ``set_options`` either as a context manager: @@ -38,16 +109,26 @@ class set_options(object): """ def __init__(self, **kwargs): - invalid_options = {k for k in kwargs if k not in OPTIONS} - if invalid_options: - raise ValueError('argument names %r are not in the set of valid ' - 'options %r' % (invalid_options, set(OPTIONS))) - self.old = OPTIONS.copy() - OPTIONS.update(kwargs) + self.old = {} + for k, v in kwargs.items(): + if k not in OPTIONS: + raise ValueError( + 'argument name %r is not in the set of valid options %r' + % (k, set(OPTIONS))) + if k in _VALIDATORS and not _VALIDATORS[k](v): + raise ValueError( + 'option %r given an invalid value: %r' % (k, v)) + self.old[k] = OPTIONS[k] + self._apply_update(kwargs) + + def _apply_update(self, options_dict): + for k, v in options_dict.items(): + if k in _SETTERS: + _SETTERS[k](v) + OPTIONS.update(options_dict) def __enter__(self): return def __exit__(self, type, value, traceback): - OPTIONS.clear() - OPTIONS.update(self.old) + self._apply_update(self.old) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index df7781ca9c1..b980bc279b0 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -23,10 +23,14 @@ def itervalues(d): range = range zip = zip + from itertools import zip_longest from functools import reduce import builtins from urllib.request import urlretrieve from inspect import getfullargspec as getargspec + + def move_to_end(ordered_dict, key): + ordered_dict.move_to_end(key) else: # pragma: no cover # Python 2 basestring = basestring # noqa @@ -41,12 +45,19 @@ def itervalues(d): return d.itervalues() range = xrange - from itertools import izip as zip, imap as map + from itertools import ( + izip as zip, imap as map, izip_longest as zip_longest, + ) reduce = reduce import 
__builtin__ as builtins from urllib import urlretrieve from inspect import getargspec + def move_to_end(ordered_dict, key): + value = ordered_dict[key] + del ordered_dict[key] + ordered_dict[key] = value + integer_types = native_int_types + (np.integer,) try: @@ -73,7 +84,6 @@ def itervalues(d): except ImportError as e: path_type = () - try: from contextlib import suppress except ImportError: diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 4933a09b257..edf7dfc3d41 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, division, print_function from . import ops -from .groupby import DataArrayGroupBy, DatasetGroupBy +from .groupby import DEFAULT_DIMS, DataArrayGroupBy, DatasetGroupBy from .pycompat import OrderedDict, dask_array_type RESAMPLE_DIM = '__resample_dim__' @@ -273,19 +273,18 @@ def apply(self, func, **kwargs): return combined.rename({self._resample_dim: self._dim}) - def reduce(self, func, dim=None, keep_attrs=False, **kwargs): + def reduce(self, func, dim=None, keep_attrs=None, **kwargs): """Reduce the items in this group by applying `func` along the pre-defined resampling dimension. - Note that `dim` is by default here and ignored if passed by the user; - this ensures compatibility with the existing reduce interface. - Parameters ---------- func : function Function which can be called in the form `func(x, axis=axis, **kwargs)` to return the result of collapsing an np.ndarray over an integer valued axis. + dim : str or sequence of str, optional + Dimension(s) over which to apply `func`. keep_attrs : bool, optional If True, the datasets's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new @@ -299,8 +298,11 @@ def reduce(self, func, dim=None, keep_attrs=False, **kwargs): Array with summarized data and the indicated dimension(s) removed. """ + if dim == DEFAULT_DIMS: + dim = None + return super(DatasetResample, self).reduce( - func, self._dim, keep_attrs, **kwargs) + func, dim, keep_attrs, **kwargs) def _interpolate(self, kind='linear'): """Apply scipy.interpolate.interp1d along resampling dimension.""" diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 24ed280b19e..883dbb34dff 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -44,7 +44,7 @@ class Rolling(object): _attributes = ['window', 'min_periods', 'center', 'dim'] - def __init__(self, obj, min_periods=None, center=False, **windows): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object. @@ -52,18 +52,18 @@ def __init__(self, obj, min_periods=None, center=False, **windows): ---------- obj : Dataset or DataArray Object to window. + windows : A mapping from a dimension name to window size + dim : str + Name of the dimension to create the rolling iterator + along (e.g., `time`). + window : int + Size of the moving window. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. 
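# Illustrative sketch (editorial, not part of the diff): the Rolling classes in
# this hunk now receive window sizes as an explicit ``windows`` mapping instead
# of ``**windows`` keyword arguments; the user-facing ``.rolling()`` call keeps
# working as before. Assumes this branch of xarray plus numpy are installed.
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(10.0), dims='time')
# forwarded internally as DataArrayRolling(da, {'time': 3}, min_periods=1, center=True)
r = da.rolling(time=3, min_periods=1, center=True)
print(r.mean())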
Returns ------- @@ -115,7 +115,7 @@ def __len__(self): class DataArrayRolling(Rolling): - def __init__(self, obj, min_periods=None, center=False, **windows): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object for DataArray. You should use DataArray.rolling() method to construct this object @@ -125,18 +125,18 @@ def __init__(self, obj, min_periods=None, center=False, **windows): ---------- obj : DataArray Object to window. + windows : A mapping from a dimension name to window size + dim : str + Name of the dimension to create the rolling iterator + along (e.g., `time`). + window : int + Size of the moving window. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. Returns ------- @@ -149,8 +149,8 @@ def __init__(self, obj, min_periods=None, center=False, **windows): Dataset.rolling Dataset.groupby """ - super(DataArrayRolling, self).__init__(obj, min_periods=min_periods, - center=center, **windows) + super(DataArrayRolling, self).__init__( + obj, windows, min_periods=min_periods, center=center) self.window_labels = self.obj[self.dim] @@ -321,7 +321,7 @@ def wrapped_func(self, **kwargs): class DatasetRolling(Rolling): - def __init__(self, obj, min_periods=None, center=False, **windows): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object for Dataset. You should use Dataset.rolling() method to construct this object @@ -331,18 +331,18 @@ def __init__(self, obj, min_periods=None, center=False, **windows): ---------- obj : Dataset Object to window. + windows : A mapping from a dimension name to window size + dim : str + Name of the dimension to create the rolling iterator + along (e.g., `time`). + window : int + Size of the moving window. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. 
Returns ------- @@ -355,8 +355,7 @@ def __init__(self, obj, min_periods=None, center=False, **windows): Dataset.groupby DataArray.groupby """ - super(DatasetRolling, self).__init__(obj, - min_periods, center, **windows) + super(DatasetRolling, self).__init__(obj, windows, min_periods, center) if self.dim not in self.obj.dims: raise KeyError(self.dim) # Keep each Rolling object as an OrderedDict @@ -364,8 +363,8 @@ def __init__(self, obj, min_periods=None, center=False, **windows): for key, da in self.obj.data_vars.items(): # keeps rollings only for the dataset depending on slf.dim if self.dim in da.dims: - self.rollings[key] = DataArrayRolling(da, min_periods, - center, **windows) + self.rollings[key] = DataArrayRolling( + da, windows, min_periods, center) def reduce(self, func, **kwargs): """Reduce the items in this group by applying `func` along some diff --git a/xarray/core/utils.py b/xarray/core/utils.py index c3bb747fac5..50d6ec7e05a 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -5,6 +5,7 @@ import contextlib import functools import itertools +import os.path import re import warnings from collections import Iterable, Mapping, MutableMapping, MutableSet @@ -12,11 +13,20 @@ import numpy as np import pandas as pd -from .options import OPTIONS from .pycompat import ( OrderedDict, basestring, bytes_type, dask_array_type, iteritems) +def _check_inplace(inplace, default=False): + if inplace is None: + inplace = default + else: + warnings.warn('The inplace argument has been deprecated and will be ' + 'removed in xarray 0.12.0.', FutureWarning, stacklevel=3) + + return inplace + + def alias_message(old_name, new_name): return '%s has been deprecated. Use %s instead.' % (old_name, new_name) @@ -40,16 +50,13 @@ def wrapper(*args, **kwargs): def _maybe_cast_to_cftimeindex(index): from ..coding.cftimeindex import CFTimeIndex - if not OPTIONS['enable_cftimeindex']: - return index - else: - if index.dtype == 'O': - try: - return CFTimeIndex(index) - except (ImportError, TypeError): - return index - else: + if index.dtype == 'O': + try: + return CFTimeIndex(index) + except (ImportError, TypeError): return index + else: + return index def safe_cast_to_index(array): @@ -504,6 +511,11 @@ def is_remote_uri(path): return bool(re.search('^https?\://', path)) +def is_grib_path(path): + _, ext = os.path.splitext(path) + return ext in ['.grib', '.grb', '.grib2', '.grb2'] + + def is_uniform_spaced(arr, **kwargs): """Return True if values of an array are uniformly spaced and sorted. @@ -591,3 +603,29 @@ def __iter__(self): def __len__(self): num_hidden = sum([k in self._hidden_keys for k in self._data]) return len(self._data) - num_hidden + + +def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float): + """Convert an array containing datetime-like data to an array of floats. 
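# Illustrative sketch (editorial, not part of the diff): expected behaviour of
# the ``datetime_to_numeric`` helper defined in this hunk. It lives in the
# private ``xarray.core.utils`` module, so this import only works on a branch
# that carries the change.
import numpy as np
from xarray.core.utils import datetime_to_numeric

times = np.array(['2000-01-01', '2000-01-02', '2000-01-04'],
                 dtype='datetime64[ns]')
# offset defaults to the minimum value, so the result is elapsed time in days
print(datetime_to_numeric(times, datetime_unit='D'))  # [0. 1. 3.]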
+ + Parameters + ---------- + da : array + Input data + offset: Scalar with the same type of array or None + If None, subtract minimum values to reduce round off error + datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', + 'us', 'ns', 'ps', 'fs', 'as'} + dtype: target dtype + + Returns + ------- + array + """ + if offset is None: + offset = array.min() + array = array - offset + + if datetime_unit: + return (array / np.timedelta64(1, datetime_unit)).astype(dtype) + return array.astype(dtype) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 52d470accfe..0bff06e7546 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -18,6 +18,7 @@ from .pycompat import ( OrderedDict, basestring, dask_array_type, integer_types, zip) from .utils import OrderedSet, either_dict_or_kwargs +from .options import _get_keep_attrs try: import dask.array as da @@ -64,34 +65,30 @@ def as_variable(obj, name=None): The newly created variable. """ + from .dataarray import DataArray + # TODO: consider extending this method to automatically handle Iris and - # pandas objects. - if hasattr(obj, 'variable'): + if isinstance(obj, DataArray): # extract the primary Variable from DataArrays obj = obj.variable if isinstance(obj, Variable): obj = obj.copy(deep=False) - elif hasattr(obj, 'dims') and (hasattr(obj, 'data') or - hasattr(obj, 'values')): - obj_data = getattr(obj, 'data', None) - if obj_data is None: - obj_data = getattr(obj, 'values') - obj = Variable(obj.dims, obj_data, - getattr(obj, 'attrs', None), - getattr(obj, 'encoding', None)) elif isinstance(obj, tuple): try: obj = Variable(*obj) - except TypeError: + except (TypeError, ValueError) as error: # use .format() instead of % because it handles tuples consistently - raise TypeError('tuples to convert into variables must be of the ' - 'form (dims, data[, attrs, encoding]): ' - '{}'.format(obj)) + raise error.__class__('Could not convert tuple of form ' + '(dims, data[, attrs, encoding]): ' + '{} to Variable.'.format(obj)) elif utils.is_scalar(obj): obj = Variable([], obj) elif isinstance(obj, (pd.Index, IndexVariable)) and obj.name is not None: obj = Variable(obj.name, obj) + elif isinstance(obj, (set, dict)): + raise TypeError( + "variable %r has invalid type %r" % (name, type(obj))) elif name is not None: data = as_compatible_data(obj) if data.ndim != 1: @@ -99,7 +96,7 @@ def as_variable(obj, name=None): 'cannot set variable %r with %r-dimensional data ' 'without explicit dimension names. Pass a tuple of ' '(dims, data) instead.' % (name, data.ndim)) - obj = Variable(name, obj, fastpath=True) + obj = Variable(name, data, fastpath=True) else: raise TypeError('unable to convert object into a variable without an ' 'explicit list of dimensions: %r' % obj) @@ -725,24 +722,81 @@ def encoding(self, value): except ValueError: raise ValueError('encoding must be castable to a dictionary') - def copy(self, deep=True): + def copy(self, deep=True, data=None): """Returns a copy of this object. If `deep=True`, the data array is loaded into memory and copied onto the new object. Dimensions, attributes and encodings are always copied. - """ - data = self._data - if isinstance(data, indexing.MemoryCachedArray): - # don't share caching between copies - data = indexing.MemoryCachedArray(data.array) + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, optional + Whether the data array is loaded into memory and copied onto + the new object. 
Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + When `data` is used, `deep` is ignored. + + Returns + ------- + object : Variable + New object with dimensions, attributes, encodings, and optionally + data copied from original. - if deep: - if isinstance(data, dask_array_type): - data = data.copy() - elif not isinstance(data, PandasIndexAdapter): - # pandas.Index is immutable - data = np.array(data) + Examples + -------- + + Shallow copy versus deep copy + + >>> var = xr.Variable(data=[1, 2, 3], dims='x') + >>> var.copy() + + array([1, 2, 3]) + >>> var_0 = var.copy(deep=False) + >>> var_0[0] = 7 + >>> var_0 + + array([7, 2, 3]) + >>> var + + array([7, 2, 3]) + + Changing the data using the ``data`` argument maintains the + structure of the original object, but with the new data. Original + object is unaffected. + + >>> var.copy(data=[0.1, 0.2, 0.3]) + + array([ 0.1, 0.2, 0.3]) + >>> var + + array([7, 2, 3]) + + See Also + -------- + pandas.DataFrame.copy + """ + if data is None: + data = self._data + + if isinstance(data, indexing.MemoryCachedArray): + # don't share caching between copies + data = indexing.MemoryCachedArray(data.array) + + if deep: + if isinstance(data, dask_array_type): + data = data.copy() + elif not isinstance(data, PandasIndexAdapter): + # pandas.Index is immutable + data = np.array(data) + else: + data = as_compatible_data(data) + if self.shape != data.shape: + raise ValueError("Data shape {} must match shape of object {}" + .format(data.shape, self.shape)) # note: # dims is already an immutable tuple @@ -874,7 +928,7 @@ def squeeze(self, dim=None): numpy.squeeze """ dims = common.get_squeeze_dims(self, dim) - return self.isel(**{d: 0 for d in dims}) + return self.isel({d: 0 for d in dims}) def _shift_one_dim(self, dim, count): axis = self.get_axis_num(dim) @@ -916,36 +970,46 @@ def _shift_one_dim(self, dim, count): return type(self)(self.dims, data, self._attrs, fastpath=True) - def shift(self, **shifts): + def shift(self, shifts=None, **shifts_kwargs): """ Return a new Variable with shifted data. Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : mapping of the form {dim: offset} Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. + **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. Returns ------- shifted : Variable Variable with the same dimensions and attributes but shifted data. """ + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'shift') result = self for dim, count in shifts.items(): result = result._shift_one_dim(dim, count) return result - def pad_with_fill_value(self, fill_value=dtypes.NA, **pad_widths): + def pad_with_fill_value(self, pad_widths=None, fill_value=dtypes.NA, + **pad_widths_kwargs): """ Return a new Variable with paddings. Parameters ---------- - **pad_width: keyword arguments of the form {dim: (before, after)} + pad_width: Mapping of the form {dim: (before, after)} Number of values padded to the edges of each dimension. 
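# Illustrative sketch (editorial, not part of the diff): with the
# ``either_dict_or_kwargs`` pattern used for ``shift`` and
# ``pad_with_fill_value`` in this hunk, dimension arguments may be given either
# as a mapping or as keyword arguments. Assumes this branch of xarray.
import numpy as np
import xarray as xr

v = xr.Variable(('x',), np.arange(5))
assert v.shift(x=2).equals(v.shift({'x': 2}))    # the two spellings are equivalent
padded = v.pad_with_fill_value({'x': (1, 1)})    # NaN-padded on both ends of 'x'
print(padded.shape)                              # (7,)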
+ **pad_widths_kwargs: + Keyword argument for pad_widths """ + pad_widths = either_dict_or_kwargs(pad_widths, pad_widths_kwargs, + 'pad') + if fill_value is dtypes.NA: # np.nan is passed dtype, fill_value = dtypes.maybe_promote(self.dtype) else: @@ -1006,22 +1070,27 @@ def _roll_one_dim(self, dim, count): return type(self)(self.dims, data, self._attrs, fastpath=True) - def roll(self, **shifts): + def roll(self, shifts=None, **shifts_kwargs): """ Return a new Variable with rolld data. Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : mapping of the form {dim: offset} Integer offset to roll along each of the given dimensions. Positive offsets roll to the right; negative offsets roll to the left. + **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. Returns ------- shifted : Variable Variable with the same dimensions and attributes but rolled data. """ + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'roll') + result = self for dim, count in shifts.items(): result = result._roll_one_dim(dim, count) @@ -1139,7 +1208,7 @@ def _stack_once(self, dims, new_dim): return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) - def stack(self, **dimensions): + def stack(self, dimensions=None, **dimensions_kwargs): """ Stack any number of existing dimensions into a single new dimension. @@ -1148,9 +1217,12 @@ def stack(self, **dimensions): Parameters ---------- - **dimensions : keyword arguments of the form new_name=(dim1, dim2, ...) + dimensions : Mapping of form new_name=(dim1, dim2, ...) Names of new dimensions, and the existing dimensions that they replace. + **dimensions_kwargs: + The keyword arguments form of ``dimensions``. + One of dimensions or dimensions_kwargs must be provided. Returns ------- @@ -1161,6 +1233,8 @@ def stack(self, **dimensions): -------- Variable.unstack """ + dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, + 'stack') result = self for new_dim, dims in dimensions.items(): result = result._stack_once(dims, new_dim) @@ -1192,7 +1266,7 @@ def _unstack_once(self, dims, old_dim): return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) - def unstack(self, **dimensions): + def unstack(self, dimensions=None, **dimensions_kwargs): """ Unstack an existing dimension into multiple new dimensions. @@ -1201,9 +1275,12 @@ def unstack(self, **dimensions): Parameters ---------- - **dimensions : keyword arguments of the form old_dim={dim1: size1, ...} + dimensions : mapping of the form old_dim={dim1: size1, ...} Names of existing dimensions, and the new dimensions and sizes that they map to. + **dimensions_kwargs: + The keyword arguments form of ``dimensions``. + One of dimensions or dimensions_kwargs must be provided. Returns ------- @@ -1214,6 +1291,8 @@ def unstack(self, **dimensions): -------- Variable.stack """ + dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, + 'unstack') result = self for old_dim, dims in dimensions.items(): result = result._unstack_once(dims, old_dim) @@ -1225,8 +1304,8 @@ def fillna(self, value): def where(self, cond, other=dtypes.NA): return ops.where_method(self, cond, other) - def reduce(self, func, dim=None, axis=None, keep_attrs=False, - allow_lazy=False, **kwargs): + def reduce(self, func, dim=None, axis=None, + keep_attrs=None, allow_lazy=False, **kwargs): """Reduce this array by applying `func` along some dimension(s). 
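# Illustrative sketch (editorial, not part of the diff): ``Variable.reduce`` in
# this hunk now defaults ``keep_attrs`` to None and falls back to the global
# ``keep_attrs`` option introduced in ``xarray.core.options``. Assumes this
# branch of xarray is installed.
import numpy as np
import xarray as xr

v = xr.Variable('x', np.arange(4.0), {'units': 'K'})
print(v.reduce(np.mean).attrs)        # attrs dropped with the default setting
with xr.set_options(keep_attrs=True):
    print(v.reduce(np.mean).attrs)    # attrs kept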
Parameters @@ -1255,6 +1334,8 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False, Array with summarized data and the indicated dimension(s) removed. """ + if dim is common.ALL_DIMS: + dim = None if dim is not None and axis is not None: raise ValueError("cannot supply both 'axis' and 'dim' arguments") @@ -1271,6 +1352,8 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False, dims = [adim for n, adim in enumerate(self.dims) if n not in removed_axes] + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) attrs = self._attrs if keep_attrs else None return Variable(dims, data, attrs=attrs) @@ -1688,14 +1771,37 @@ def concat(cls, variables, dim='concat_dim', positions=None, return cls(first_var.dims, data, attrs) - def copy(self, deep=True): + def copy(self, deep=True, data=None): """Returns a copy of this object. - `deep` is ignored since data is stored in the form of pandas.Index, - which is already immutable. Dimensions, attributes and encodings are - always copied. + `deep` is ignored since data is stored in the form of + pandas.Index, which is already immutable. Dimensions, attributes + and encodings are always copied. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, optional + Deep is always ignored. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + + Returns + ------- + object : Variable + New object with dimensions, attributes, encodings, and optionally + data copied from original. """ - return type(self)(self.dims, self._data, self._attrs, + if data is None: + data = self._data + else: + data = as_compatible_data(data) + if self.shape != data.shape: + raise ValueError("Data shape {} must match shape of object {}" + .format(data.shape, self.shape)) + return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) def equals(self, other, equiv=None): @@ -1873,12 +1979,15 @@ def assert_unique_multiindex_level_names(variables): objects. """ level_names = defaultdict(list) + all_level_names = set() for var_name, var in variables.items(): if isinstance(var._data, PandasIndexAdapter): idx_level_names = var.to_index_variable().level_names if idx_level_names is not None: for n in idx_level_names: level_names[n].append('%r (%s)' % (n, var_name)) + if idx_level_names: + all_level_names.update(idx_level_names) for k, v in level_names.items(): if k in variables: @@ -1889,3 +1998,9 @@ def assert_unique_multiindex_level_names(variables): conflict_str = '\n'.join([', '.join(v) for v in duplicate_names]) raise ValueError('conflicting MultiIndex level name(s):\n%s' % conflict_str) + # Check confliction between level names and dimensions GH:2299 + for k, v in variables.items(): + for d in v.dims: + if d in all_level_names: + raise ValueError('conflicting level / dimension names. 
{} ' + 'already exists as a level name.'.format(d)) diff --git a/xarray/plot/__init__.py b/xarray/plot/__init__.py index fe2c604a89e..4b53b22243c 100644 --- a/xarray/plot/__init__.py +++ b/xarray/plot/__init__.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from .plot import (plot, line, contourf, contour, +from .plot import (plot, line, step, contourf, contour, hist, imshow, pcolormesh) from .facetgrid import FacetGrid @@ -9,6 +9,7 @@ __all__ = [ 'plot', 'line', + 'step', 'contour', 'contourf', 'hist', diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index d2016e30679..48a3e090aa3 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -5,6 +5,7 @@ import warnings import numpy as np + from ..core.formatting import format_item from ..core.pycompat import getargspec from .utils import ( @@ -188,6 +189,7 @@ def __init__(self, data, col=None, row=None, col_wrap=None, self._y_var = None self._cmap_extend = None self._mappables = [] + self._finalized = False @property def _left_axes(self): @@ -221,6 +223,11 @@ def map_dataarray(self, func, x, y, **kwargs): cmapkw = kwargs.get('cmap') colorskw = kwargs.get('colors') + cbar_kwargs = kwargs.pop('cbar_kwargs', {}) + cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs) + + if kwargs.get('cbar_ax', None) is not None: + raise ValueError('cbar_ax not supported by FacetGrid.') # colors is mutually exclusive with cmap if cmapkw and colorskw: @@ -262,7 +269,7 @@ def map_dataarray(self, func, x, y, **kwargs): self._finalize_grid(x, y) if kwargs.get('add_colorbar', True): - self.add_colorbar() + self.add_colorbar(**cbar_kwargs) return self @@ -338,13 +345,16 @@ def map_scatter(self, x=None, y=None, hue=None, discrete_legend=False, def _finalize_grid(self, *axlabels): """Finalize the annotations and layout.""" - self.set_axis_labels(*axlabels) - self.set_titles() - self.fig.tight_layout() + if not self._finalized: + self.set_axis_labels(*axlabels) + self.set_titles() + self.fig.tight_layout() - for ax, namedict in zip(self.axes.flat, self.name_dicts.flat): - if namedict is None: - ax.set_visible(False) + for ax, namedict in zip(self.axes.flat, self.name_dicts.flat): + if namedict is None: + ax.set_visible(False) + + self._finalized = True def add_legend(self, **kwargs): figlegend = self.fig.legend( @@ -532,9 +542,12 @@ def map(self, func, *args, **kwargs): data = self.data.loc[namedict] plt.sca(ax) innerargs = [data[a].values for a in args] - # TODO: is it possible to verify that an artist is mappable? - mappable = func(*innerargs, **kwargs) - self._mappables.append(mappable) + maybe_mappable = func(*innerargs, **kwargs) + # TODO: better way to verify that an artist is mappable? 
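# Illustrative sketch (editorial, not part of the diff): ``cbar_kwargs`` are
# now forwarded by ``FacetGrid.map_dataarray`` to the shared colorbar.
# Assumes this branch of xarray plus matplotlib are installed.
import numpy as np
import xarray as xr

da = xr.DataArray(np.random.rand(3, 4, 5), dims=('time', 'y', 'x'))
da.plot(col='time', cbar_kwargs={'label': 'example data'})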
+ # https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522 + if (maybe_mappable and + hasattr(maybe_mappable, 'autoscale_None')): + self._mappables.append(maybe_mappable) self._finalize_grid(*args[:2]) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 9c2ffbde048..2cf19f4fb03 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -14,12 +14,15 @@ import numpy as np import pandas as pd +from xarray.core.alignment import align from xarray.core.common import contains_cftime_datetimes from xarray.core.pycompat import basestring from .facetgrid import FacetGrid from .utils import ( - ROBUST_PERCENTILE, _determine_cmap_params, _infer_xy_labels, get_axis, + ROBUST_PERCENTILE, _determine_cmap_params, _infer_xy_labels, + _interval_to_double_bound_points, _interval_to_mid_points, + _resolve_intervals_2dplot, _valid_other_type, get_axis, import_matplotlib_pyplot, label_from_attrs) @@ -35,27 +38,20 @@ def _valid_numpy_subdtype(x, numpy_types): return any(np.issubdtype(x.dtype, t) for t in numpy_types) -def _valid_other_type(x, types): - """ - Do all elements of x have a type from types? - """ - return all(any(isinstance(el, t) for t in types) for el in np.ravel(x)) - - def _ensure_plottable(*args): """ Raise exception if there is anything in args that can't be plotted on an - axis. + axis by matplotlib. """ numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64] other_types = [datetime] for x in args: - if not (_valid_numpy_subdtype(np.array(x), numpy_types) or - _valid_other_type(np.array(x), other_types)): + if not (_valid_numpy_subdtype(np.array(x), numpy_types) + or _valid_other_type(np.array(x), other_types)): raise TypeError('Plotting requires coordinates to be numeric ' 'or dates of type np.datetime64 or ' - 'datetime.datetime.') + 'datetime.datetime or pd.Interval.') def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, @@ -149,8 +145,13 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None, darray = darray.squeeze() if contains_cftime_datetimes(darray): - raise NotImplementedError('Plotting arrays of cftime.datetime objects ' - 'is currently not possible.') + raise NotImplementedError( + 'Built-in plotting of arrays of cftime.datetime objects or arrays ' + 'indexed by cftime.datetime objects is currently not implemented ' + 'within xarray. 
A possible workaround is to use the ' + 'nc-time-axis package ' + '(https://github.com/SciTools/nc-time-axis) to convert the dates ' + 'to a plottable type and plot your data directly with matplotlib.') plot_dims = set(darray.dims) plot_dims.discard(row) @@ -173,8 +174,10 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None, kwargs['hue'] = hue elif ndims == 2: if hue: - raise ValueError('hue is not compatible with 2d data') - plotfunc = pcolormesh + plotfunc = line + kwargs['hue'] = hue + else: + plotfunc = pcolormesh else: if row or col or hue: raise ValueError(error_msg) @@ -190,10 +193,10 @@ def _infer_line_data(darray, x, y, hue): .format(', '.join([repr(dd) for dd in darray.dims]))) ndims = len(darray.dims) - if x is not None and x not in darray.dims: + if x is not None and x not in darray.dims and x not in darray.coords: raise ValueError('x ' + error_msg) - if y is not None and y not in darray.dims: + if y is not None and y not in darray.dims and y not in darray.coords: raise ValueError('y ' + error_msg) if x is not None and y is not None: @@ -207,11 +210,11 @@ def _infer_line_data(darray, x, y, hue): hue_label = '' if (x is None and y is None) or x == dim: - xplt = darray.coords[dim] + xplt = darray[dim] yplt = darray else: - yplt = darray.coords[dim] + yplt = darray[dim] xplt = darray else: @@ -221,18 +224,36 @@ def _infer_line_data(darray, x, y, hue): if y is None: xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) - yname = darray.name - xplt = darray.coords[xname] - yplt = darray.transpose(xname, huename) + xplt = darray[xname] + if xplt.ndim > 1: + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + otherdim = darray.dims[otherindex] + yplt = darray.transpose(otherdim, huename) + xplt = xplt.transpose(otherdim, huename) + else: + raise ValueError('For 2D inputs, hue must be a dimension' + + ' i.e. one of ' + repr(darray.dims)) + + else: + yplt = darray.transpose(xname, huename) else: yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) - xname = darray.name - xplt = darray.transpose(yname, huename) - yplt = darray.coords[yname] + yplt = darray[yname] + if yplt.ndim > 1: + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + xplt = darray.transpose(otherdim, huename) + else: + raise ValueError('For 2D inputs, hue must be a dimension' + + ' i.e. one of ' + repr(darray.dims)) + + else: + xplt = darray.transpose(yname, huename) - hueplt = darray.coords[huename] hue_label = label_from_attrs(darray[huename]) + hueplt = darray[huename] xlabel = label_from_attrs(xplt) ylabel = label_from_attrs(yplt) @@ -337,11 +358,17 @@ def line(darray, *args, **kwargs): Axis on which to plot this figure. By default, use the current axis. Mutually exclusive with ``size`` and ``figsize``. hue : string, optional - Coordinate for which you want multiple lines plotted. + Dimension or coordinate for which you want multiple lines plotted. + If plotting against a 2D coordinate, ``hue`` must be a dimension. x, y : string, optional - Coordinates for x, y axis. Only one of these may be specified. + Dimensions or coordinates for x, y axis. + Only one of these may be specified. The other coordinate plots values from the DataArray on which this plot method is called. 
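# Illustrative sketch (editorial, not part of the diff): with the dispatch and
# ``_infer_line_data`` changes above, a 2D DataArray can be drawn as a family
# of lines by passing the extra dimension as ``hue``. Assumes this branch of
# xarray plus matplotlib are installed.
import numpy as np
import xarray as xr

da = xr.DataArray(np.random.randn(10, 3), dims=('time', 'station'))
da.plot.line(x='time', hue='station')   # one line per station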
+ xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional + Specifies scaling for the x- and y-axes respectively + xticks, yticks : Specify tick locations for x- and y-axes + xlim, ylim : Specify x- and y-axes limits xincrease : None, True, or False, optional Should the values on the x axes be increasing from left to right? if None, use the default for the matplotlib function. @@ -377,8 +404,14 @@ def line(darray, *args, **kwargs): hue = kwargs.pop('hue', None) x = kwargs.pop('x', None) y = kwargs.pop('y', None) - xincrease = kwargs.pop('xincrease', True) - yincrease = kwargs.pop('yincrease', True) + xincrease = kwargs.pop('xincrease', None) # default needs to be None + yincrease = kwargs.pop('yincrease', None) + xscale = kwargs.pop('xscale', None) # default needs to be None + yscale = kwargs.pop('yscale', None) + xticks = kwargs.pop('xticks', None) + yticks = kwargs.pop('yticks', None) + xlim = kwargs.pop('xlim', None) + ylim = kwargs.pop('ylim', None) add_legend = kwargs.pop('add_legend', True) _labels = kwargs.pop('_labels', True) if args is (): @@ -388,9 +421,30 @@ def line(darray, *args, **kwargs): xplt, yplt, hueplt, xlabel, ylabel, hue_label = \ _infer_line_data(darray, x, y, hue) - _ensure_plottable(xplt) + # Remove pd.Intervals if contained in xplt.values. + if _valid_other_type(xplt.values, [pd.Interval]): + # Is it a step plot? (see matplotlib.Axes.step) + if kwargs.get('linestyle', '').startswith('steps-'): + xplt_val, yplt_val = _interval_to_double_bound_points(xplt.values, + yplt.values) + # Remove steps-* to be sure that matplotlib is not confused + kwargs['linestyle'] = (kwargs['linestyle'] + .replace('steps-pre', '') + .replace('steps-post', '') + .replace('steps-mid', '')) + if kwargs['linestyle'] == '': + kwargs.pop('linestyle') + else: + xplt_val = _interval_to_mid_points(xplt.values) + yplt_val = yplt.values + xlabel += '_center' + else: + xplt_val = xplt.values + yplt_val = yplt.values - primitive = ax.plot(xplt, yplt, *args, **kwargs) + _ensure_plottable(xplt_val, yplt_val) + + primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) if _labels: if xlabel is not None: @@ -415,11 +469,52 @@ def line(darray, *args, **kwargs): xlabels.set_rotation(30) xlabels.set_ha('right') - _update_axes_limits(ax, xincrease, yincrease) + _update_axes(ax, xincrease, yincrease, xscale, yscale, + xticks, yticks, xlim, ylim) return primitive +def step(darray, *args, **kwargs): + """ + Step plot of DataArray index against values + + Similar to :func:`matplotlib:matplotlib.pyplot.step` + + Parameters + ---------- + where : {'pre', 'post', 'mid'}, optional, default 'pre' + Define where the steps should be placed: + - 'pre': The y value is continued constantly to the left from + every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the + value ``y[i]``. + - 'post': The y value is continued constantly to the right from + every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the + value ``y[i]``. + - 'mid': Steps occur half-way between the *x* positions. + Note that this parameter is ignored if the x coordinate consists of + :py:func:`pandas.Interval` values, e.g. as a result of + :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual + boundaries of the interval are used. 
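# Illustrative sketch (editorial, not part of the diff): the pd.Interval
# handling described above targets results of ``groupby_bins``, whose
# coordinate consists of interval objects. Assumes this branch of xarray plus
# matplotlib are installed.
import numpy as np
import xarray as xr

da = xr.DataArray(np.random.randn(100), dims='x',
                  coords={'x': np.linspace(0.0, 1.0, 100)})
binned = da.groupby_bins('x', bins=5).mean()
binned.plot.step()   # steps drawn on the actual bin boundaries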
+ + *args, **kwargs : optional + Additional arguments following :py:func:`xarray.plot.line` + + """ + if ('ls' in kwargs.keys()) and ('linestyle' not in kwargs.keys()): + kwargs['linestyle'] = kwargs.pop('ls') + + where = kwargs.pop('where', 'pre') + + if where not in ('pre', 'post', 'mid'): + raise ValueError("'where' argument to step must be " + "'pre', 'post' or 'mid'") + + kwargs['linestyle'] = 'steps-' + where + kwargs.get('linestyle', '') + + return line(darray, *args, **kwargs) + + def hist(darray, figsize=None, size=None, aspect=None, ax=None, **kwargs): """ Histogram of DataArray @@ -450,37 +545,69 @@ def hist(darray, figsize=None, size=None, aspect=None, ax=None, **kwargs): """ ax = get_axis(figsize, size, aspect, ax) + xincrease = kwargs.pop('xincrease', None) # default needs to be None + yincrease = kwargs.pop('yincrease', None) + xscale = kwargs.pop('xscale', None) # default needs to be None + yscale = kwargs.pop('yscale', None) + xticks = kwargs.pop('xticks', None) + yticks = kwargs.pop('yticks', None) + xlim = kwargs.pop('xlim', None) + ylim = kwargs.pop('ylim', None) + no_nan = np.ravel(darray.values) no_nan = no_nan[pd.notnull(no_nan)] primitive = ax.hist(no_nan, **kwargs) - ax.set_ylabel('Count') - ax.set_title('Histogram') ax.set_xlabel(label_from_attrs(darray)) + _update_axes(ax, xincrease, yincrease, xscale, yscale, + xticks, yticks, xlim, ylim) + return primitive -def _update_axes_limits(ax, xincrease, yincrease): +def _update_axes(ax, xincrease, yincrease, + xscale=None, yscale=None, + xticks=None, yticks=None, + xlim=None, ylim=None): """ - Update axes in place to increase or decrease - For use in _plot2d + Update axes with provided parameters """ if xincrease is None: pass - elif xincrease: - ax.set_xlim(sorted(ax.get_xlim())) - elif not xincrease: - ax.set_xlim(sorted(ax.get_xlim(), reverse=True)) + elif xincrease and ax.xaxis_inverted(): + ax.invert_xaxis() + elif not xincrease and not ax.xaxis_inverted(): + ax.invert_xaxis() if yincrease is None: pass - elif yincrease: - ax.set_ylim(sorted(ax.get_ylim())) - elif not yincrease: - ax.set_ylim(sorted(ax.get_ylim(), reverse=True)) + elif yincrease and ax.yaxis_inverted(): + ax.invert_yaxis() + elif not yincrease and not ax.yaxis_inverted(): + ax.invert_yaxis() + + # The default xscale, yscale needs to be None. + # If we set a scale it resets the axes formatters, + # This means that set_xscale('linear') on a datetime axis + # will remove the date labels. So only set the scale when explicitly + # asked to. https://github.com/matplotlib/matplotlib/issues/8740 + if xscale is not None: + ax.set_xscale(xscale) + if yscale is not None: + ax.set_yscale(yscale) + + if xticks is not None: + ax.set_xticks(xticks) + if yticks is not None: + ax.set_yticks(yticks) + + if xlim is not None: + ax.set_xlim(xlim) + if ylim is not None: + ax.set_ylim(ylim) # MUST run before any 2d plotting functions are defined since @@ -505,12 +632,18 @@ def hist(self, ax=None, **kwargs): def line(self, *args, **kwargs): return line(self._da, *args, **kwargs) + @functools.wraps(step) + def step(self, *args, **kwargs): + return step(self._da, *args, **kwargs) + def _rescale_imshow_rgb(darray, vmin, vmax, robust): assert robust or vmin is not None or vmax is not None + # TODO: remove when min numpy version is bumped to 1.13 # There's a cyclic dependency via DataArray, so we can't import from # xarray.ufuncs in global scope. 
from xarray.ufuncs import maximum, minimum + # Calculate vmin and vmax automatically for `robust=True` if robust: if vmax is None: @@ -536,7 +669,10 @@ def _rescale_imshow_rgb(darray, vmin, vmax, robust): # After scaling, downcast to 32-bit float. This substantially reduces # memory usage after we hand `darray` off to matplotlib. darray = ((darray.astype('f8') - vmin) / (vmax - vmin)).astype('f4') - return minimum(maximum(darray, 0), 1) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'xarray.ufuncs', + PendingDeprecationWarning) + return minimum(maximum(darray, 0), 1) def _plot2d(plotfunc): @@ -572,6 +708,10 @@ def _plot2d(plotfunc): If passed, make column faceted plots on this dimension name col_wrap : integer, optional Use together with ``col`` to wrap faceted plots + xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional + Specifies scaling for the x- and y-axes respectively + xticks, yticks : Specify tick locations for x- and y-axes + xlim, ylim : Specify x- and y-axes limits xincrease : None, True, or False, optional Should the values on the x axes be increasing from left to right? if None, use the default for the matplotlib function. @@ -582,6 +722,9 @@ def _plot2d(plotfunc): Adds colorbar to axis add_labels : Boolean, optional Use xarray metadata to label axes + norm : ``matplotlib.colors.Normalize`` instance, optional + If the ``norm`` has vmin or vmax specified, the corresponding kwarg + must be None. vmin, vmax : floats, optional Values to anchor the colormap, otherwise they are inferred from the data and other keyword arguments. When a diverging dataset is inferred, @@ -649,7 +792,8 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, cmap=None, center=None, robust=False, extend=None, levels=None, infer_intervals=None, colors=None, subplot_kws=None, cbar_ax=None, cbar_kwargs=None, - **kwargs): + xscale=None, yscale=None, xticks=None, yticks=None, + xlim=None, ylim=None, norm=None, **kwargs): # All 2d plots in xarray share this function signature. # Method signature below should be consistent. @@ -732,7 +876,11 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, # Pass the data as a masked ndarray too zval = darray.to_masked_array(copy=False) - _ensure_plottable(xval, yval) + # Replace pd.Intervals if contained in xval or yval. + xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__) + yplt, ylab_extra = _resolve_intervals_2dplot(yval, plotfunc.__name__) + + _ensure_plottable(xplt, yplt) if 'contour' in plotfunc.__name__ and levels is None: levels = 7 # this is the matplotlib default @@ -746,6 +894,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, 'extend': extend, 'levels': levels, 'filled': plotfunc.__name__ != 'contour', + 'norm': norm, } cmap_params = _determine_cmap_params(**cmap_kwargs) @@ -756,28 +905,31 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, # pcolormesh kwargs['extend'] = cmap_params['extend'] kwargs['levels'] = cmap_params['levels'] + # if colors == a single color, matplotlib draws dashed negative + # contours. 
we lose this feature if we pass cmap and not colors + if isinstance(colors, basestring): + cmap_params['cmap'] = None + kwargs['colors'] = colors if 'pcolormesh' == plotfunc.__name__: kwargs['infer_intervals'] = infer_intervals - # This allows the user to pass in a custom norm coming via kwargs - kwargs.setdefault('norm', cmap_params['norm']) - if 'imshow' == plotfunc.__name__ and isinstance(aspect, basestring): # forbid usage of mpl strings raise ValueError("plt.imshow's `aspect` kwarg is not available " "in xarray") ax = get_axis(figsize, size, aspect, ax) - primitive = plotfunc(xval, yval, zval, ax=ax, cmap=cmap_params['cmap'], + primitive = plotfunc(xplt, yplt, zval, ax=ax, cmap=cmap_params['cmap'], vmin=cmap_params['vmin'], vmax=cmap_params['vmax'], + norm=cmap_params['norm'], **kwargs) # Label the plot with metadata if add_labels: - ax.set_xlabel(label_from_attrs(darray[xlab])) - ax.set_ylabel(label_from_attrs(darray[ylab])) + ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra)) + ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) ax.set_title(darray._title_for_slice()) if add_colorbar: @@ -789,17 +941,27 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, cbar_kwargs.setdefault('cax', cbar_ax) cbar = plt.colorbar(primitive, **cbar_kwargs) if add_labels and 'label' not in cbar_kwargs: - cbar.set_label(label_from_attrs(darray), rotation=90) + cbar.set_label(label_from_attrs(darray)) elif cbar_ax is not None or cbar_kwargs is not None: # inform the user about keywords which aren't used raise ValueError("cbar_ax and cbar_kwargs can't be used with " "add_colorbar=False.") - _update_axes_limits(ax, xincrease, yincrease) + # origin kwarg overrides yincrease + if 'origin' in kwargs: + yincrease = None + + _update_axes(ax, xincrease, yincrease, xscale, yscale, + xticks, yticks, xlim, ylim) # Rotate dates on xlabels - if np.issubdtype(xval.dtype, np.datetime64): - ax.get_figure().autofmt_xdate() + # Do this without calling autofmt_xdate so that x-axes ticks + # on other subplots (if any) are not deleted. + # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots + if np.issubdtype(xplt.dtype, np.datetime64): + for xlabels in ax.get_xticklabels(): + xlabels.set_rotation(30) + xlabels.set_ha('right') return primitive @@ -811,7 +973,9 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None, add_labels=True, vmin=None, vmax=None, cmap=None, colors=None, center=None, robust=False, extend=None, levels=None, infer_intervals=None, subplot_kws=None, - cbar_ax=None, cbar_kwargs=None, **kwargs): + cbar_ax=None, cbar_kwargs=None, + xscale=None, yscale=None, xticks=None, yticks=None, + xlim=None, ylim=None, norm=None, **kwargs): """ The method should have the same signature as the function. 
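# Illustrative sketch (editorial, not part of the diff): the 2D plot functions
# in this file now accept a matplotlib ``norm`` together with
# xscale/yscale/xticks/yticks/xlim/ylim. Assumes this branch of xarray plus
# matplotlib are installed.
import matplotlib.colors as mcolors
import numpy as np
import xarray as xr

da = xr.DataArray(np.random.rand(10, 20) + 0.01, dims=('y', 'x'))
# the norm carries its own vmin/vmax; supplying a conflicting vmin/vmax raises
da.plot.pcolormesh(norm=mcolors.LogNorm(vmin=0.01, vmax=1.0), yincrease=False)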
@@ -873,10 +1037,8 @@ def imshow(x, y, z, ax, **kwargs): left, right = x[0] - xstep, x[-1] + xstep bottom, top = y[-1] + ystep, y[0] - ystep - defaults = {'extent': [left, right, bottom, top], - 'origin': 'upper', - 'interpolation': 'nearest', - } + defaults = {'origin': 'upper', + 'interpolation': 'nearest'} if not hasattr(ax, 'projection'): # not for cartopy geoaxes @@ -885,6 +1047,11 @@ def imshow(x, y, z, ax, **kwargs): # Allow user to override these defaults defaults.update(kwargs) + if defaults['origin'] == 'upper': + defaults['extent'] = [left, right, bottom, top] + else: + defaults['extent'] = [left, right, top, bottom] + if z.ndim == 3: # matplotlib imshow uses black for missing data, but Xarray makes # missing data transparent. We therefore add an alpha channel if @@ -992,14 +1159,22 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): else: infer_intervals = True - if infer_intervals: + if (infer_intervals and + ((np.shape(x)[0] == np.shape(z)[1]) or + ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])))): if len(x.shape) == 1: x = _infer_interval_breaks(x, check_monotonic=True) - y = _infer_interval_breaks(y, check_monotonic=True) else: # we have to infer the intervals on both axes x = _infer_interval_breaks(x, axis=1) x = _infer_interval_breaks(x, axis=0) + + if (infer_intervals and + (np.shape(y)[0] == np.shape(z)[0])): + if len(y.shape) == 1: + y = _infer_interval_breaks(y, check_monotonic=True) + else: + # we have to infer the intervals on both axes y = _infer_interval_breaks(y, axis=1) y = _infer_interval_breaks(y, axis=0) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 4b9645e02d5..41f61554739 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -1,12 +1,13 @@ from __future__ import absolute_import, division, print_function +import itertools +import textwrap import warnings import numpy as np import pandas as pd -import pkg_resources -import textwrap +from ..core.options import OPTIONS from ..core.pycompat import basestring from ..core.utils import is_scalar @@ -173,6 +174,10 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, # vlim might be computed below vlim = None + # save state; needed later + vmin_was_none = vmin is None + vmax_was_none = vmax is None + if vmin is None: if robust: vmin = np.percentile(calc_data, ROBUST_PERCENTILE) @@ -205,18 +210,42 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, vmin += center vmax += center + # now check norm and harmonize with vmin, vmax + if norm is not None: + if norm.vmin is None: + norm.vmin = vmin + else: + if not vmin_was_none and vmin != norm.vmin: + raise ValueError('Cannot supply vmin and a norm' + + ' with a different vmin.') + vmin = norm.vmin + + if norm.vmax is None: + norm.vmax = vmax + else: + if not vmax_was_none and vmax != norm.vmax: + raise ValueError('Cannot supply vmax and a norm' + + ' with a different vmax.') + vmax = norm.vmax + + # if BoundaryNorm, then set levels + if isinstance(norm, mpl.colors.BoundaryNorm): + levels = norm.boundaries + # Choose default colormaps if not provided if cmap is None: if divergent: - cmap = "RdBu_r" + cmap = OPTIONS['cmap_divergent'] else: - cmap = "viridis" + cmap = OPTIONS['cmap_sequential'] # Handle discrete levels - if levels is not None: + if levels is not None and norm is None: if is_scalar(levels): - if user_minmax or levels == 1: + if user_minmax: levels = np.linspace(vmin, vmax, levels) + elif levels == 1: + levels = np.asarray([(vmin + vmax) / 2]) else: # N in MaxNLocator refers 
to bins, not ticks ticker = mpl.ticker.MaxNLocator(levels - 1) @@ -226,8 +255,9 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, if extend is None: extend = _determine_extend(calc_data, vmin, vmax) - if levels is not None: - cmap, norm = _build_discrete_cmap(cmap, levels, extend, filled) + if levels is not None or isinstance(norm, mpl.colors.BoundaryNorm): + cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled) + norm = newnorm if norm is None else norm return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend, levels=levels, norm=norm) @@ -298,11 +328,11 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): raise ValueError('DataArray must be 2d') y, x = darray.dims elif x is None: - if y not in darray.dims: + if y not in darray.dims and y not in darray.coords: raise ValueError('y must be a dimension name if x is not supplied') x = darray.dims[0] if y == darray.dims[1] else darray.dims[1] elif y is None: - if x not in darray.dims: + if x not in darray.dims and x not in darray.coords: raise ValueError('x must be a dimension name if y is not supplied') y = darray.dims[0] if x == darray.dims[1] else darray.dims[1] elif any(k not in darray.coords and k not in darray.dims for k in (x, y)): @@ -339,7 +369,7 @@ def get_axis(figsize, size, aspect, ax): return ax -def label_from_attrs(da): +def label_from_attrs(da, extra=''): ''' Makes informative labels if variable metadata (attrs) follows CF conventions. ''' @@ -357,4 +387,66 @@ def label_from_attrs(da): else: units = '' - return '\n'.join(textwrap.wrap(name + units, 30)) + return '\n'.join(textwrap.wrap(name + extra + units, 30)) + + +def _interval_to_mid_points(array): + """ + Helper function which returns an array + with the Intervals' mid points. + """ + + return np.array([x.mid for x in array]) + + +def _interval_to_bound_points(array): + """ + Helper function which returns an array + with the Intervals' boundaries. + """ + + array_boundaries = np.array([x.left for x in array]) + array_boundaries = np.concatenate( + (array_boundaries, np.array([array[-1].right]))) + + return array_boundaries + + +def _interval_to_double_bound_points(xarray, yarray): + """ + Helper function to deal with a xarray consisting of pd.Intervals. Each + interval is replaced with both boundaries. I.e. the length of xarray + doubles. yarray is modified so it matches the new shape of xarray. + """ + + xarray1 = np.array([x.left for x in xarray]) + xarray2 = np.array([x.right for x in xarray]) + + xarray = list(itertools.chain.from_iterable(zip(xarray1, xarray2))) + yarray = list(itertools.chain.from_iterable(zip(yarray, yarray))) + + return xarray, yarray + + +def _resolve_intervals_2dplot(val, func_name): + """ + Helper function to replace the values of a coordinate array containing + pd.Interval with their mid-points or - for pcolormesh - boundaries which + increases length by 1. + """ + label_extra = '' + if _valid_other_type(val, [pd.Interval]): + if func_name == 'pcolormesh': + val = _interval_to_bound_points(val) + else: + val = _interval_to_mid_points(val) + label_extra = '_center' + + return val, label_extra + + +def _valid_other_type(x, types): + """ + Do all elements of x have a type from types? 
+ """ + return all(any(isinstance(el, t) for t in types) for el in np.ravel(x)) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 3b4d69a35f7..a45f71bbc3b 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -3,17 +3,16 @@ from __future__ import print_function import warnings from contextlib import contextmanager -from distutils.version import LooseVersion +from distutils import version import re import importlib import numpy as np from numpy.testing import assert_array_equal # noqa: F401 -from xarray.core.duck_array_ops import allclose_or_equiv +from xarray.core.duck_array_ops import allclose_or_equiv # noqa import pytest from xarray.core import utils -from xarray.core.pycompat import PY3 from xarray.core.indexing import ExplicitlyIndexed from xarray.testing import (assert_equal, assert_identical, # noqa: F401 assert_allclose) @@ -25,10 +24,6 @@ # old location, for pandas < 0.20 from pandas.util.testing import assert_frame_equal # noqa: F401 -try: - import unittest2 as unittest -except ImportError: - import unittest try: from unittest import mock @@ -58,6 +53,13 @@ def _importorskip(modname, minversion=None): return has, func +def LooseVersion(vstring): + # Our development version is something like '0.10.9+aac7bfc' + # This function just ignored the git commit id. + vstring = vstring.split('+')[0] + return version.LooseVersion(vstring) + + has_matplotlib, requires_matplotlib = _importorskip('matplotlib') has_matplotlib2, requires_matplotlib2 = _importorskip('matplotlib', minversion='2') @@ -74,6 +76,8 @@ def _importorskip(modname, minversion=None): has_pathlib, requires_pathlib = _importorskip('pathlib') has_zarr, requires_zarr = _importorskip('zarr', minversion='2.2') has_np113, requires_np113 = _importorskip('numpy', minversion='1.13.0') +has_iris, requires_iris = _importorskip('iris') +has_cfgrib, requires_cfgrib = _importorskip('cfgrib') # some special cases has_scipy_or_netCDF4 = has_scipy or has_netCDF4 @@ -89,7 +93,7 @@ def _importorskip(modname, minversion=None): if LooseVersion(dask.__version__) < '0.18': dask.set_options(get=dask.get) else: - dask.config.set(scheduler='sync') + dask.config.set(scheduler='single-threaded') try: import_seaborn() has_seaborn = True @@ -116,39 +120,6 @@ def _importorskip(modname, minversion=None): "internet connection") -class TestCase(unittest.TestCase): - """ - These functions are all deprecated. 
Instead, use functions in xr.testing - """ - if PY3: - # Python 3 assertCountEqual is roughly equivalent to Python 2 - # assertItemsEqual - def assertItemsEqual(self, first, second, msg=None): - __tracebackhide__ = True # noqa: F841 - return self.assertCountEqual(first, second, msg) - - @contextmanager - def assertWarns(self, message): - __tracebackhide__ = True # noqa: F841 - with warnings.catch_warnings(record=True) as w: - warnings.filterwarnings('always', message) - yield - assert len(w) > 0 - assert any(message in str(wi.message) for wi in w) - - def assertVariableNotEqual(self, v1, v2): - __tracebackhide__ = True # noqa: F841 - assert not v1.equals(v2) - - def assertEqual(self, a1, a2): - __tracebackhide__ = True # noqa: F841 - assert a1 == a2 or (a1 != a1 and a2 != a2) - - def assertAllClose(self, a1, a2, rtol=1e-05, atol=1e-8): - __tracebackhide__ = True # noqa: F841 - assert allclose_or_equiv(a1, a2, rtol=rtol, atol=atol) - - @contextmanager def raises_regex(error, pattern): __tracebackhide__ = True # noqa: F841 diff --git a/xarray/tests/data/example.grib b/xarray/tests/data/example.grib new file mode 100644 index 00000000000..596a54d98a0 Binary files /dev/null and b/xarray/tests/data/example.grib differ diff --git a/xarray/tests/test_accessors.py b/xarray/tests/test_accessors.py index e1b3a95b942..38038fc8f65 100644 --- a/xarray/tests/test_accessors.py +++ b/xarray/tests/test_accessors.py @@ -7,12 +7,13 @@ import xarray as xr from . import ( - TestCase, assert_array_equal, assert_equal, raises_regex, requires_dask, - has_cftime, has_dask, has_cftime_or_netCDF4) + assert_array_equal, assert_equal, has_cftime, has_cftime_or_netCDF4, + has_dask, raises_regex, requires_dask) -class TestDatetimeAccessor(TestCase): - def setUp(self): +class TestDatetimeAccessor(object): + @pytest.fixture(autouse=True) + def setup(self): nt = 100 data = np.random.rand(10, 10, nt) lons = np.linspace(0, 11, 10) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 9ec68bb0846..fb9c43c0165 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2,12 +2,12 @@ import contextlib import itertools +import math import os.path import pickle import shutil import sys import tempfile -import unittest import warnings from io import BytesIO @@ -19,22 +19,21 @@ from xarray import ( DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset, save_mfdataset) -from xarray.backends.common import (robust_getitem, - PickleByReconstructionWrapper) +from xarray.backends.common import robust_getitem from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding from xarray.backends.pydap_ import PydapDataStore from xarray.core import indexing from xarray.core.pycompat import ( - PY2, ExitStack, basestring, dask_array_type, iteritems) + ExitStack, basestring, dask_array_type, iteritems) +from xarray.core.options import set_options from xarray.tests import mock from . 
import ( - TestCase, assert_allclose, assert_array_equal, assert_equal, - assert_identical, has_dask, has_netCDF4, has_scipy, network, raises_regex, + assert_allclose, assert_array_equal, assert_equal, assert_identical, + has_dask, has_netCDF4, has_scipy, network, raises_regex, requires_cftime, requires_dask, requires_h5netcdf, requires_netCDF4, requires_pathlib, - requires_pydap, requires_pynio, requires_rasterio, requires_scipy, - requires_scipy_or_netCDF4, requires_zarr, requires_pseudonetcdf, - requires_cftime) + requires_pseudonetcdf, requires_pydap, requires_pynio, requires_rasterio, + requires_scipy, requires_scipy_or_netCDF4, requires_zarr, requires_cfgrib) from .test_dataset import create_test_data try: @@ -106,7 +105,7 @@ def create_boolean_data(): return Dataset({'x': ('t', [True, False, False, True], attributes)}) -class TestCommon(TestCase): +class TestCommon(object): def test_robust_getitem(self): class UnreliableArrayFailure(Exception): @@ -126,19 +125,18 @@ def __getitem__(self, key): array = UnreliableArray([0]) with pytest.raises(UnreliableArrayFailure): array[0] - self.assertEqual(array[0], 0) + assert array[0] == 0 actual = robust_getitem(array, 0, catch=UnreliableArrayFailure, initial_delay=0) - self.assertEqual(actual, 0) + assert actual == 0 class NetCDF3Only(object): pass -class DatasetIOTestCases(object): - autoclose = False +class DatasetIOBase(object): engine = None file_format = None @@ -172,8 +170,7 @@ def save(self, dataset, path, **kwargs): @contextlib.contextmanager def open(self, path, **kwargs): - with open_dataset(path, engine=self.engine, autoclose=self.autoclose, - **kwargs) as ds: + with open_dataset(path, engine=self.engine, **kwargs) as ds: yield ds def test_zero_dimensional_variable(self): @@ -222,11 +219,11 @@ def assert_loads(vars=None): with self.roundtrip(expected) as actual: for k, v in actual.variables.items(): # IndexVariables are eagerly loaded into memory - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) yield actual for k, v in actual.variables.items(): if k in vars: - self.assertTrue(v._in_memory) + assert v._in_memory assert_identical(expected, actual) with pytest.raises(AssertionError): @@ -252,14 +249,14 @@ def test_dataset_compute(self): # Test Dataset.compute() for k, v in actual.variables.items(): # IndexVariables are eagerly cached - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) computed = actual.compute() for k, v in actual.variables.items(): - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) for v in computed.variables.values(): - self.assertTrue(v._in_memory) + assert v._in_memory assert_identical(expected, actual) assert_identical(expected, computed) @@ -343,12 +340,12 @@ def test_roundtrip_string_encoded_characters(self): expected['x'].encoding['dtype'] = 'S1' with self.roundtrip(expected) as actual: assert_identical(expected, actual) - self.assertEqual(actual['x'].encoding['_Encoding'], 'utf-8') + assert actual['x'].encoding['_Encoding'] == 'utf-8' expected['x'].encoding['_Encoding'] = 'ascii' with self.roundtrip(expected) as actual: assert_identical(expected, actual) - self.assertEqual(actual['x'].encoding['_Encoding'], 'ascii') + assert actual['x'].encoding['_Encoding'] == 'ascii' def test_roundtrip_numpy_datetime_data(self): times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT']) @@ -359,7 +356,7 @@ def test_roundtrip_numpy_datetime_data(self): assert actual.t0.encoding['units'] == 
'days since 1950-01-01' @requires_cftime - def test_roundtrip_cftime_datetime_data_enable_cftimeindex(self): + def test_roundtrip_cftime_datetime_data(self): from .test_coding_times import _all_cftime_date_types date_types = _all_cftime_date_types() @@ -376,21 +373,20 @@ def test_roundtrip_cftime_datetime_data_enable_cftimeindex(self): warnings.filterwarnings( 'ignore', 'Unable to decode time axis') - with xr.set_options(enable_cftimeindex=True): - with self.roundtrip(expected, save_kwargs=kwds) as actual: - abs_diff = abs(actual.t.values - expected_decoded_t) - assert (abs_diff <= np.timedelta64(1, 's')).all() - assert (actual.t.encoding['units'] == - 'days since 0001-01-01 00:00:00.000000') - assert (actual.t.encoding['calendar'] == - expected_calendar) - - abs_diff = abs(actual.t0.values - expected_decoded_t0) - assert (abs_diff <= np.timedelta64(1, 's')).all() - assert (actual.t0.encoding['units'] == - 'days since 0001-01-01') - assert (actual.t.encoding['calendar'] == - expected_calendar) + with self.roundtrip(expected, save_kwargs=kwds) as actual: + abs_diff = abs(actual.t.values - expected_decoded_t) + assert (abs_diff <= np.timedelta64(1, 's')).all() + assert (actual.t.encoding['units'] == + 'days since 0001-01-01 00:00:00.000000') + assert (actual.t.encoding['calendar'] == + expected_calendar) + + abs_diff = abs(actual.t0.values - expected_decoded_t0) + assert (abs_diff <= np.timedelta64(1, 's')).all() + assert (actual.t0.encoding['units'] == + 'days since 0001-01-01') + assert (actual.t.encoding['calendar'] == + expected_calendar) def test_roundtrip_timedelta_data(self): time_deltas = pd.to_timedelta(['1h', '2h', 'NaT']) @@ -434,10 +430,10 @@ def test_roundtrip_coordinates_with_space(self): def test_roundtrip_boolean_dtype(self): original = create_boolean_data() - self.assertEqual(original['x'].dtype, 'bool') + assert original['x'].dtype == 'bool' with self.roundtrip(original) as actual: assert_identical(original, actual) - self.assertEqual(actual['x'].dtype, 'bool') + assert actual['x'].dtype == 'bool' def test_orthogonal_indexing(self): in_memory = create_test_data() @@ -596,7 +592,7 @@ def test_ondisk_after_print(self): assert not on_disk['var1']._in_memory -class CFEncodedDataTest(DatasetIOTestCases): +class CFEncodedBase(DatasetIOBase): def test_roundtrip_bytes_with_fill_value(self): values = np.array([b'ab', b'cdef', np.nan], dtype=object) @@ -626,20 +622,20 @@ def test_unsigned_roundtrip_mask_and_scale(self): encoded = create_encoded_unsigned_masked_scaled_data() with self.roundtrip(decoded) as actual: for k in decoded.variables: - self.assertEqual(decoded.variables[k].dtype, - actual.variables[k].dtype) + assert (decoded.variables[k].dtype == + actual.variables[k].dtype) assert_allclose(decoded, actual, decode_bytes=False) with self.roundtrip(decoded, open_kwargs=dict(decode_cf=False)) as actual: for k in encoded.variables: - self.assertEqual(encoded.variables[k].dtype, - actual.variables[k].dtype) + assert (encoded.variables[k].dtype == + actual.variables[k].dtype) assert_allclose(encoded, actual, decode_bytes=False) with self.roundtrip(encoded, open_kwargs=dict(decode_cf=False)) as actual: for k in encoded.variables: - self.assertEqual(encoded.variables[k].dtype, - actual.variables[k].dtype) + assert (encoded.variables[k].dtype == + actual.variables[k].dtype) assert_allclose(encoded, actual, decode_bytes=False) # make sure roundtrip encoding didn't change the # original dataset. 
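The pattern running through these test hunks is a mechanical port from ``unittest.TestCase`` helper methods to plain ``assert`` statements that pytest can rewrite and report on. A minimal, self-contained sketch of the before/after style (the dataset below is illustrative, not taken from the test suite)::

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({'x': ('y', np.arange(10.0))})

    # unittest style removed by this change:
    #     self.assertEqual(ds.x.dtype, np.dtype('float64'))
    #     self.assertTrue('x' in ds.variables)
    # pytest style used instead, with richer failure output:
    assert ds.x.dtype == np.dtype('float64')
    assert 'x' in ds.variables
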
@@ -647,14 +643,14 @@ def test_unsigned_roundtrip_mask_and_scale(self): encoded, create_encoded_unsigned_masked_scaled_data()) with self.roundtrip(encoded) as actual: for k in decoded.variables: - self.assertEqual(decoded.variables[k].dtype, - actual.variables[k].dtype) + assert decoded.variables[k].dtype == \ + actual.variables[k].dtype assert_allclose(decoded, actual, decode_bytes=False) with self.roundtrip(encoded, open_kwargs=dict(decode_cf=False)) as actual: for k in encoded.variables: - self.assertEqual(encoded.variables[k].dtype, - actual.variables[k].dtype) + assert encoded.variables[k].dtype == \ + actual.variables[k].dtype assert_allclose(encoded, actual, decode_bytes=False) def test_roundtrip_mask_and_scale(self): @@ -692,12 +688,11 @@ def equals_latlon(obj): with create_tmp_file() as tmp_file: original.to_netcdf(tmp_file) with open_dataset(tmp_file, decode_coords=False) as ds: - self.assertTrue(equals_latlon(ds['temp'].attrs['coordinates'])) - self.assertTrue( - equals_latlon(ds['precip'].attrs['coordinates'])) - self.assertNotIn('coordinates', ds.attrs) - self.assertNotIn('coordinates', ds['lat'].attrs) - self.assertNotIn('coordinates', ds['lon'].attrs) + assert equals_latlon(ds['temp'].attrs['coordinates']) + assert equals_latlon(ds['precip'].attrs['coordinates']) + assert 'coordinates' not in ds.attrs + assert 'coordinates' not in ds['lat'].attrs + assert 'coordinates' not in ds['lon'].attrs modified = original.drop(['temp', 'precip']) with self.roundtrip(modified) as actual: @@ -705,9 +700,9 @@ def equals_latlon(obj): with create_tmp_file() as tmp_file: modified.to_netcdf(tmp_file) with open_dataset(tmp_file, decode_coords=False) as ds: - self.assertTrue(equals_latlon(ds.attrs['coordinates'])) - self.assertNotIn('coordinates', ds['lat'].attrs) - self.assertNotIn('coordinates', ds['lon'].attrs) + assert equals_latlon(ds.attrs['coordinates']) + assert 'coordinates' not in ds['lat'].attrs + assert 'coordinates' not in ds['lon'].attrs def test_roundtrip_endian(self): ds = Dataset({'x': np.arange(3, 10, dtype='>i2'), @@ -743,8 +738,8 @@ def test_encoding_kwarg(self): ds = Dataset({'x': ('y', np.arange(10.0))}) kwargs = dict(encoding={'x': {'dtype': 'f4'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['dtype'], 'f4') - self.assertEqual(ds.x.encoding, {}) + assert actual.x.encoding['dtype'] == 'f4' + assert ds.x.encoding == {} kwargs = dict(encoding={'x': {'foo': 'bar'}}) with raises_regex(ValueError, 'unexpected encoding'): @@ -766,7 +761,7 @@ def test_encoding_kwarg_dates(self): units = 'days since 1900-01-01' kwargs = dict(encoding={'t': {'units': units}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.t.encoding['units'], units) + assert actual.t.encoding['units'] == units assert_identical(actual, ds) def test_encoding_kwarg_fixed_width_string(self): @@ -778,7 +773,7 @@ def test_encoding_kwarg_fixed_width_string(self): ds = Dataset({'x': strings}) kwargs = dict(encoding={'x': {'dtype': 'S1'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual['x'].encoding['dtype'], 'S1') + assert actual['x'].encoding['dtype'] == 'S1' assert_identical(actual, ds) def test_default_fill_value(self): @@ -786,9 +781,8 @@ def test_default_fill_value(self): ds = Dataset({'x': ('y', np.arange(10.0))}) kwargs = dict(encoding={'x': {'dtype': 'f4'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['_FillValue'], - np.nan) - 
self.assertEqual(ds.x.encoding, {}) + assert math.isnan(actual.x.encoding['_FillValue']) + assert ds.x.encoding == {} # Test default encoding for int: ds = Dataset({'x': ('y', np.arange(10.0))}) @@ -797,14 +791,14 @@ def test_default_fill_value(self): warnings.filterwarnings( 'ignore', '.*floating point data as an integer') with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertTrue('_FillValue' not in actual.x.encoding) - self.assertEqual(ds.x.encoding, {}) + assert '_FillValue' not in actual.x.encoding + assert ds.x.encoding == {} # Test default encoding for implicit int: ds = Dataset({'x': ('y', np.arange(10, dtype='int16'))}) with self.roundtrip(ds) as actual: - self.assertTrue('_FillValue' not in actual.x.encoding) - self.assertEqual(ds.x.encoding, {}) + assert '_FillValue' not in actual.x.encoding + assert ds.x.encoding == {} def test_explicitly_omit_fill_value(self): ds = Dataset({'x': ('y', [np.pi, -np.pi])}) @@ -817,7 +811,7 @@ def test_explicitly_omit_fill_value_via_encoding_kwarg(self): kwargs = dict(encoding={'x': {'_FillValue': None}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: assert '_FillValue' not in actual.x.encoding - self.assertEqual(ds.y.encoding, {}) + assert ds.y.encoding == {} def test_explicitly_omit_fill_value_in_coord(self): ds = Dataset({'x': ('y', [np.pi, -np.pi])}, coords={'y': [0.0, 1.0]}) @@ -830,14 +824,14 @@ def test_explicitly_omit_fill_value_in_coord_via_encoding_kwarg(self): kwargs = dict(encoding={'y': {'_FillValue': None}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: assert '_FillValue' not in actual.y.encoding - self.assertEqual(ds.y.encoding, {}) + assert ds.y.encoding == {} def test_encoding_same_dtype(self): ds = Dataset({'x': ('y', np.arange(10.0, dtype='f4'))}) kwargs = dict(encoding={'x': {'dtype': 'f4'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['dtype'], 'f4') - self.assertEqual(ds.x.encoding, {}) + assert actual.x.encoding['dtype'] == 'f4' + assert ds.x.encoding == {} def test_append_write(self): # regression for GH1215 @@ -900,7 +894,7 @@ def create_tmp_files(nfiles, suffix='.nc', allow_cleanup_failure=False): yield files -class BaseNetCDF4Test(CFEncodedDataTest): +class NetCDF4Base(CFEncodedBase): """Tests for both netCDF4-python and h5netcdf.""" engine = 'netcdf4' @@ -1015,7 +1009,7 @@ def test_default_to_char_arrays(self): data = Dataset({'x': np.array(['foo', 'zzzz'], dtype='S')}) with self.roundtrip(data) as actual: assert_identical(data, actual) - self.assertEqual(actual['x'].dtype, np.dtype('S4')) + assert actual['x'].dtype == np.dtype('S4') def test_open_encodings(self): # Create a netCDF file with explicit time units @@ -1040,15 +1034,15 @@ def test_open_encodings(self): actual_encoding = dict((k, v) for k, v in iteritems(actual['time'].encoding) if k in expected['time'].encoding) - self.assertDictEqual(actual_encoding, - expected['time'].encoding) + assert actual_encoding == \ + expected['time'].encoding def test_dump_encodings(self): # regression test for #709 ds = Dataset({'x': ('y', np.arange(10.0))}) kwargs = dict(encoding={'x': {'zlib': True}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertTrue(actual.x.encoding['zlib']) + assert actual.x.encoding['zlib'] def test_dump_and_open_encodings(self): # Create a netCDF file with explicit time units @@ -1066,8 +1060,7 @@ def test_dump_and_open_encodings(self): with create_tmp_file() as tmp_file2: xarray_dataset.to_netcdf(tmp_file2) with nc4.Dataset(tmp_file2, 'r') as ds: - 
self.assertEqual( - ds.variables['time'].getncattr('units'), units) + assert ds.variables['time'].getncattr('units') == units assert_array_equal( ds.variables['time'], np.arange(10) + 4) @@ -1080,7 +1073,7 @@ def test_compression_encoding(self): 'original_shape': data.var2.shape}) with self.roundtrip(data) as actual: for k, v in iteritems(data['var2'].encoding): - self.assertEqual(v, actual['var2'].encoding[k]) + assert v == actual['var2'].encoding[k] # regression test for #156 expected = data.isel(dim1=0) @@ -1095,14 +1088,14 @@ def test_encoding_kwarg_compression(self): with self.roundtrip(ds, save_kwargs=kwargs) as actual: assert_equal(actual, ds) - self.assertEqual(actual.x.encoding['dtype'], 'f4') - self.assertEqual(actual.x.encoding['zlib'], True) - self.assertEqual(actual.x.encoding['complevel'], 9) - self.assertEqual(actual.x.encoding['fletcher32'], True) - self.assertEqual(actual.x.encoding['chunksizes'], (5,)) - self.assertEqual(actual.x.encoding['shuffle'], True) + assert actual.x.encoding['dtype'] == 'f4' + assert actual.x.encoding['zlib'] + assert actual.x.encoding['complevel'] == 9 + assert actual.x.encoding['fletcher32'] + assert actual.x.encoding['chunksizes'] == (5,) + assert actual.x.encoding['shuffle'] - self.assertEqual(ds.x.encoding, {}) + assert ds.x.encoding == {} def test_encoding_chunksizes_unlimited(self): # regression test for GH1225 @@ -1162,10 +1155,10 @@ def test_already_open_dataset(self): v[...] = 42 nc = nc4.Dataset(tmp_file, mode='r') - with backends.NetCDF4DataStore(nc, autoclose=False) as store: - with open_dataset(store) as ds: - expected = Dataset({'x': ((), 42)}) - assert_identical(expected, ds) + store = backends.NetCDF4DataStore(nc) + with open_dataset(store) as ds: + expected = Dataset({'x': ((), 42)}) + assert_identical(expected, ds) def test_read_variable_len_strings(self): with create_tmp_file() as tmp_file: @@ -1183,8 +1176,7 @@ def test_read_variable_len_strings(self): @requires_netCDF4 -class NetCDF4DataTest(BaseNetCDF4Test, TestCase): - autoclose = False +class TestNetCDF4Data(NetCDF4Base): @contextlib.contextmanager def create_store(self): @@ -1201,7 +1193,7 @@ def test_variable_order(self): ds.coords['c'] = 4 with self.roundtrip(ds) as actual: - self.assertEqual(list(ds.variables), list(actual.variables)) + assert list(ds.variables) == list(actual.variables) def test_unsorted_index_raises(self): # should be fixed in netcdf4 v1.2.1 @@ -1220,7 +1212,7 @@ def test_unsorted_index_raises(self): try: ds2.randovar.values except IndexError as err: - self.assertIn('first by calling .load', str(err)) + assert 'first by calling .load' in str(err) def test_88_character_filename_segmentation_fault(self): # should be fixed in netcdf4 v1.3.1 @@ -1250,18 +1242,22 @@ def test_setncattr_string(self): totest.attrs['bar']) assert one_string == totest.attrs['baz'] - -class NetCDF4DataStoreAutocloseTrue(NetCDF4DataTest): - autoclose = True + def test_autoclose_future_warning(self): + data = create_test_data() + with create_tmp_file() as tmp_file: + self.save(data, tmp_file) + with pytest.warns(FutureWarning): + with self.open(tmp_file, autoclose=True) as actual: + assert_identical(data, actual) @requires_netCDF4 @requires_dask -class NetCDF4ViaDaskDataTest(NetCDF4DataTest): +class TestNetCDF4ViaDaskData(TestNetCDF4Data): @contextlib.contextmanager def roundtrip(self, data, save_kwargs={}, open_kwargs={}, allow_cleanup_failure=False): - with NetCDF4DataTest.roundtrip( + with TestNetCDF4Data.roundtrip( self, data, save_kwargs, open_kwargs, allow_cleanup_failure) 
as ds: yield ds.chunk() @@ -1293,12 +1289,8 @@ def test_write_inconsistent_chunks(self): assert actual['y'].encoding['chunksizes'] == (100, 50) -class NetCDF4ViaDaskDataTestAutocloseTrue(NetCDF4ViaDaskDataTest): - autoclose = True - - @requires_zarr -class BaseZarrTest(CFEncodedDataTest): +class ZarrBase(CFEncodedBase): DIMENSION_KEY = '_ARRAY_DIMENSIONS' @@ -1335,17 +1327,17 @@ def test_auto_chunk(self): original, open_kwargs={'auto_chunk': False}) as actual: for k, v in actual.variables.items(): # only index variables should be in memory - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) # there should be no chunks - self.assertEqual(v.chunks, None) + assert v.chunks is None with self.roundtrip( original, open_kwargs={'auto_chunk': True}) as actual: for k, v in actual.variables.items(): # only index variables should be in memory - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) # chunk size should be the same as original - self.assertEqual(v.chunks, original[k].chunks) + assert v.chunks == original[k].chunks def test_write_uneven_dask_chunks(self): # regression for GH#2225 @@ -1365,7 +1357,7 @@ def test_chunk_encoding(self): data['var2'].encoding.update({'chunks': chunks}) with self.roundtrip(data) as actual: - self.assertEqual(chunks, actual['var2'].encoding['chunks']) + assert chunks == actual['var2'].encoding['chunks'] # expect an error with non-integer chunks data['var2'].encoding.update({'chunks': (5, 4.5)}) @@ -1382,7 +1374,7 @@ def test_chunk_encoding_with_dask(self): # zarr automatically gets chunk information from dask chunks ds_chunk4 = ds.chunk({'x': 4}) with self.roundtrip(ds_chunk4) as actual: - self.assertEqual((4,), actual['var1'].encoding['chunks']) + assert (4,) == actual['var1'].encoding['chunks'] # should fail if dask_chunks are irregular... ds_chunk_irreg = ds.chunk({'x': (5, 4, 3)}) @@ -1395,21 +1387,18 @@ def test_chunk_encoding_with_dask(self): # ... 
except if the last chunk is smaller than the first ds_chunk_irreg = ds.chunk({'x': (5, 5, 2)}) with self.roundtrip(ds_chunk_irreg) as actual: - self.assertEqual((5,), actual['var1'].encoding['chunks']) + assert (5,) == actual['var1'].encoding['chunks'] + # re-save Zarr arrays + with self.roundtrip(ds_chunk_irreg) as original: + with self.roundtrip(original) as actual: + assert_identical(original, actual) # - encoding specified - # specify compatible encodings for chunk_enc in 4, (4, ): ds_chunk4['var1'].encoding.update({'chunks': chunk_enc}) with self.roundtrip(ds_chunk4) as actual: - self.assertEqual((4,), actual['var1'].encoding['chunks']) - - # specify incompatible encoding - ds_chunk4['var1'].encoding.update({'chunks': (5, 5)}) - with pytest.raises(ValueError) as e_info: - with self.roundtrip(ds_chunk4) as actual: - pass - assert e_info.match('chunks') + assert (4,) == actual['var1'].encoding['chunks'] # TODO: remove this failure once syncronized overlapping writes are # supported by xarray @@ -1495,19 +1484,19 @@ def test_encoding_kwarg_fixed_width_string(self): # makes sense for Zarr backend @pytest.mark.xfail(reason="Zarr caching not implemented") def test_dataset_caching(self): - super(CFEncodedDataTest, self).test_dataset_caching() + super(CFEncodedBase, self).test_dataset_caching() @pytest.mark.xfail(reason="Zarr stores can not be appended to") def test_append_write(self): - super(CFEncodedDataTest, self).test_append_write() + super(CFEncodedBase, self).test_append_write() @pytest.mark.xfail(reason="Zarr stores can not be appended to") def test_append_overwrite_values(self): - super(CFEncodedDataTest, self).test_append_overwrite_values() + super(CFEncodedBase, self).test_append_overwrite_values() @pytest.mark.xfail(reason="Zarr stores can not be appended to") def test_append_with_invalid_dim_raises(self): - super(CFEncodedDataTest, self).test_append_with_invalid_dim_raises() + super(CFEncodedBase, self).test_append_with_invalid_dim_raises() def test_to_zarr_compute_false_roundtrip(self): from dask.delayed import Delayed @@ -1522,39 +1511,54 @@ def test_to_zarr_compute_false_roundtrip(self): with self.open(store) as actual: assert_identical(original, actual) + def test_encoding_chunksizes(self): + # regression test for GH2278 + # see also test_encoding_chunksizes_unlimited + nx, ny, nt = 4, 4, 5 + original = xr.Dataset({}, coords={'x': np.arange(nx), + 'y': np.arange(ny), + 't': np.arange(nt)}) + original['v'] = xr.Variable(('x', 'y', 't'), np.zeros((nx, ny, nt))) + original = original.chunk({'t': 1, 'x': 2, 'y': 2}) + + with self.roundtrip(original) as ds1: + assert_equal(ds1, original) + with self.roundtrip(ds1.isel(t=0)) as ds2: + assert_equal(ds2, original.isel(t=0)) + @requires_zarr -class ZarrDictStoreTest(BaseZarrTest, TestCase): +class TestZarrDictStore(ZarrBase): @contextlib.contextmanager def create_zarr_target(self): yield {} @requires_zarr -class ZarrDirectoryStoreTest(BaseZarrTest, TestCase): +class TestZarrDirectoryStore(ZarrBase): @contextlib.contextmanager def create_zarr_target(self): with create_tmp_file(suffix='.zarr') as tmp: yield tmp -class ScipyWriteTest(CFEncodedDataTest, NetCDF3Only): +class ScipyWriteBase(CFEncodedBase, NetCDF3Only): def test_append_write(self): import scipy if scipy.__version__ == '1.0.1': pytest.xfail('https://github.com/scipy/scipy/issues/8625') - super(ScipyWriteTest, self).test_append_write() + super(ScipyWriteBase, self).test_append_write() def test_append_overwrite_values(self): import scipy if scipy.__version__ == '1.0.1': 
pytest.xfail('https://github.com/scipy/scipy/issues/8625') - super(ScipyWriteTest, self).test_append_overwrite_values() + super(ScipyWriteBase, self).test_append_overwrite_values() @requires_scipy -class ScipyInMemoryDataTest(ScipyWriteTest, TestCase): +class TestScipyInMemoryData(ScipyWriteBase): engine = 'scipy' @contextlib.contextmanager @@ -1566,21 +1570,16 @@ def test_to_netcdf_explicit_engine(self): # regression test for GH1321 Dataset({'foo': 42}).to_netcdf(engine='scipy') - @pytest.mark.skipif(PY2, reason='cannot pickle BytesIO on Python 2') - def test_bytesio_pickle(self): + def test_bytes_pickle(self): data = Dataset({'foo': ('x', [1, 2, 3])}) - fobj = BytesIO(data.to_netcdf()) - with open_dataset(fobj, autoclose=self.autoclose) as ds: + fobj = data.to_netcdf() + with self.open(fobj) as ds: unpickled = pickle.loads(pickle.dumps(ds)) assert_identical(unpickled, data) -class ScipyInMemoryDataTestAutocloseTrue(ScipyInMemoryDataTest): - autoclose = True - - @requires_scipy -class ScipyFileObjectTest(ScipyWriteTest, TestCase): +class TestScipyFileObject(ScipyWriteBase): engine = 'scipy' @contextlib.contextmanager @@ -1608,7 +1607,7 @@ def test_pickle_dataarray(self): @requires_scipy -class ScipyFilePathTest(ScipyWriteTest, TestCase): +class TestScipyFilePath(ScipyWriteBase): engine = 'scipy' @contextlib.contextmanager @@ -1632,7 +1631,7 @@ def test_netcdf3_endianness(self): # regression test for GH416 expected = open_example_dataset('bears.nc', engine='scipy') for var in expected.variables.values(): - self.assertTrue(var.dtype.isnative) + assert var.dtype.isnative @requires_netCDF4 def test_nc4_scipy(self): @@ -1644,12 +1643,8 @@ def test_nc4_scipy(self): open_dataset(tmp_file, engine='scipy') -class ScipyFilePathTestAutocloseTrue(ScipyFilePathTest): - autoclose = True - - @requires_netCDF4 -class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only, TestCase): +class TestNetCDF3ViaNetCDF4Data(CFEncodedBase, NetCDF3Only): engine = 'netcdf4' file_format = 'NETCDF3_CLASSIC' @@ -1668,13 +1663,8 @@ def test_encoding_kwarg_vlen_string(self): pass -class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest): - autoclose = True - - @requires_netCDF4 -class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only, - TestCase): +class TestNetCDF4ClassicViaNetCDF4Data(CFEncodedBase, NetCDF3Only): engine = 'netcdf4' file_format = 'NETCDF4_CLASSIC' @@ -1686,13 +1676,8 @@ def create_store(self): yield store -class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue( - NetCDF4ClassicViaNetCDF4DataTest): - autoclose = True - - @requires_scipy_or_netCDF4 -class GenericNetCDFDataTest(CFEncodedDataTest, NetCDF3Only, TestCase): +class TestGenericNetCDFData(CFEncodedBase, NetCDF3Only): # verify that we can read and write netCDF3 files as long as we have scipy # or netCDF4-python installed file_format = 'netcdf3_64bit' @@ -1746,34 +1731,30 @@ def test_encoding_unlimited_dims(self): ds = Dataset({'x': ('y', np.arange(10.0))}) with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=['y'])) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) # Regression test for https://github.com/pydata/xarray/issues/2134 with self.roundtrip(ds, save_kwargs=dict(unlimited_dims='y')) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) ds.encoding = {'unlimited_dims': ['y']} with self.roundtrip(ds) as 
actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) # Regression test for https://github.com/pydata/xarray/issues/2134 ds.encoding = {'unlimited_dims': 'y'} with self.roundtrip(ds) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) -class GenericNetCDFDataTestAutocloseTrue(GenericNetCDFDataTest): - autoclose = True - - @requires_h5netcdf @requires_netCDF4 -class H5NetCDFDataTest(BaseNetCDF4Test, TestCase): +class TestH5NetCDFData(NetCDF4Base): engine = 'h5netcdf' @contextlib.contextmanager @@ -1781,10 +1762,14 @@ def create_store(self): with create_tmp_file() as tmp_file: yield backends.H5NetCDFStore(tmp_file, 'w') + @pytest.mark.filterwarnings('ignore:complex dtypes are supported by h5py') def test_complex(self): expected = Dataset({'x': ('y', np.ones(5) + 1j * np.ones(5))}) - with self.roundtrip(expected) as actual: - assert_equal(expected, actual) + with pytest.warns(FutureWarning): + # TODO: make it possible to write invalid netCDF files from xarray + # without a warning + with self.roundtrip(expected) as actual: + assert_equal(expected, actual) @pytest.mark.xfail(reason='https://github.com/pydata/xarray/issues/535') def test_cross_engine_read_write_netcdf4(self): @@ -1813,11 +1798,11 @@ def test_encoding_unlimited_dims(self): ds = Dataset({'x': ('y', np.arange(10.0))}) with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=['y'])) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) ds.encoding = {'unlimited_dims': ['y']} with self.roundtrip(ds) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) def test_compression_encoding_h5py(self): @@ -1848,7 +1833,7 @@ def test_compression_encoding_h5py(self): compr_out.update(compr_common) with self.roundtrip(data) as actual: for k, v in compr_out.items(): - self.assertEqual(v, actual['var2'].encoding[k]) + assert v == actual['var2'].encoding[k] def test_compression_check_encoding_h5py(self): """When mismatched h5py and NetCDF4-Python encodings are expressed @@ -1889,20 +1874,14 @@ def test_dump_encodings_h5py(self): kwargs = {'encoding': {'x': { 'compression': 'gzip', 'compression_opts': 9}}} with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['zlib'], True) - self.assertEqual(actual.x.encoding['complevel'], 9) + assert actual.x.encoding['zlib'] + assert actual.x.encoding['complevel'] == 9 kwargs = {'encoding': {'x': { 'compression': 'lzf', 'compression_opts': None}}} with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['compression'], 'lzf') - self.assertEqual(actual.x.encoding['compression_opts'], None) - - -# tests pending h5netcdf fix -@unittest.skip -class H5NetCDFDataTestAutocloseTrue(H5NetCDFDataTest): - autoclose = True + assert actual.x.encoding['compression'] == 'lzf' + assert actual.x.encoding['compression_opts'] is None @pytest.fixture(params=['scipy', 'netcdf4', 'h5netcdf', 'pynio']) @@ -1910,14 +1889,19 @@ def readengine(request): return request.param -@pytest.fixture(params=[1, 100]) +@pytest.fixture(params=[1, 20]) def nfiles(request): return request.param -@pytest.fixture(params=[True, False]) -def autoclose(request): - return request.param 
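The removed ``autoclose`` fixture and the new ``file_cache_maxsize`` fixture reflect the user-facing change exercised by these tests: passing ``autoclose=True`` per call now emits a ``FutureWarning`` (see ``test_autoclose_future_warning`` above), and the number of simultaneously open files is instead bounded by a global cache configured through ``set_options``. A hedged sketch of the intended usage, with a hypothetical file pattern::

    import xarray as xr

    # Previously: xr.open_mfdataset('data/*.nc', autoclose=True)
    # Now: cap the global file-handle cache once via set_options.
    with xr.set_options(file_cache_maxsize=5):
        ds = xr.open_mfdataset('data/*.nc')  # 'data/*.nc' is a placeholder path
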
+@pytest.fixture(params=[5, None]) +def file_cache_maxsize(request): + maxsize = request.param + if maxsize is not None: + with set_options(file_cache_maxsize=maxsize): + yield maxsize + else: + yield maxsize @pytest.fixture(params=[True, False]) @@ -1940,8 +1924,8 @@ def skip_if_not_engine(engine): pytest.importorskip(engine) -def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, - chunks): +def test_open_mfdataset_manyfiles(readengine, nfiles, parallel, chunks, + file_cache_maxsize): # skip certain combinations skip_if_not_engine(readengine) @@ -1949,9 +1933,6 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, if not has_dask and parallel: pytest.skip('parallel requires dask') - if readengine == 'h5netcdf' and autoclose: - pytest.skip('h5netcdf does not support autoclose yet') - if ON_WINDOWS: pytest.skip('Skipping on Windows') @@ -1967,7 +1948,7 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, # check that calculation on opened datasets works properly actual = open_mfdataset(tmpfiles, engine=readengine, parallel=parallel, - autoclose=autoclose, chunks=chunks) + chunks=chunks) # check that using open_mfdataset returns dask arrays for variables assert isinstance(actual['foo'].data, dask_array_type) @@ -1976,7 +1957,7 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, @requires_scipy_or_netCDF4 -class OpenMFDatasetWithDataVarsAndCoordsKwTest(TestCase): +class TestOpenMFDatasetWithDataVarsAndCoordsKw(object): coord_name = 'lon' var_name = 'v1' @@ -2047,9 +2028,9 @@ def test_common_coord_when_datavars_all(self): var_shape = ds[self.var_name].shape - self.assertEqual(var_shape, coord_shape) - self.assertNotEqual(coord_shape1, coord_shape) - self.assertNotEqual(coord_shape2, coord_shape) + assert var_shape == coord_shape + assert coord_shape1 != coord_shape + assert coord_shape2 != coord_shape def test_common_coord_when_datavars_minimal(self): opt = 'minimal' @@ -2064,9 +2045,9 @@ def test_common_coord_when_datavars_minimal(self): var_shape = ds[self.var_name].shape - self.assertNotEqual(var_shape, coord_shape) - self.assertEqual(coord_shape1, coord_shape) - self.assertEqual(coord_shape2, coord_shape) + assert var_shape != coord_shape + assert coord_shape1 == coord_shape + assert coord_shape2 == coord_shape def test_invalid_data_vars_value_should_fail(self): @@ -2084,7 +2065,7 @@ def test_invalid_data_vars_value_should_fail(self): @requires_dask @requires_scipy @requires_netCDF4 -class DaskTest(TestCase, DatasetIOTestCases): +class TestDask(DatasetIOBase): @contextlib.contextmanager def create_store(self): yield Dataset() @@ -2094,7 +2075,7 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}, allow_cleanup_failure=False): yield data.chunk() - # Override methods in DatasetIOTestCases - not applicable to dask + # Override methods in DatasetIOBase - not applicable to dask def test_roundtrip_string_encoded_characters(self): pass @@ -2102,35 +2083,15 @@ def test_roundtrip_coordinates_with_space(self): pass def test_roundtrip_numpy_datetime_data(self): - # Override method in DatasetIOTestCases - remove not applicable + # Override method in DatasetIOBase - remove not applicable # save_kwds times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT']) expected = Dataset({'t': ('t', times), 't0': times[0]}) with self.roundtrip(expected) as actual: assert_identical(expected, actual) - def test_roundtrip_cftime_datetime_data_enable_cftimeindex(self): - # Override method in DatasetIOTestCases - 
remove not applicable - # save_kwds - from .test_coding_times import _all_cftime_date_types - - date_types = _all_cftime_date_types() - for date_type in date_types.values(): - times = [date_type(1, 1, 1), date_type(1, 1, 2)] - expected = Dataset({'t': ('t', times), 't0': times[0]}) - expected_decoded_t = np.array(times) - expected_decoded_t0 = np.array([date_type(1, 1, 1)]) - - with xr.set_options(enable_cftimeindex=True): - with self.roundtrip(expected) as actual: - abs_diff = abs(actual.t.values - expected_decoded_t) - self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) - - abs_diff = abs(actual.t0.values - expected_decoded_t0) - self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) - - def test_roundtrip_cftime_datetime_data_disable_cftimeindex(self): - # Override method in DatasetIOTestCases - remove not applicable + def test_roundtrip_cftime_datetime_data(self): + # Override method in DatasetIOBase - remove not applicable # save_kwds from .test_coding_times import _all_cftime_date_types @@ -2141,16 +2102,15 @@ def test_roundtrip_cftime_datetime_data_disable_cftimeindex(self): expected_decoded_t = np.array(times) expected_decoded_t0 = np.array([date_type(1, 1, 1)]) - with xr.set_options(enable_cftimeindex=False): - with self.roundtrip(expected) as actual: - abs_diff = abs(actual.t.values - expected_decoded_t) - self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + with self.roundtrip(expected) as actual: + abs_diff = abs(actual.t.values - expected_decoded_t) + assert (abs_diff <= np.timedelta64(1, 's')).all() - abs_diff = abs(actual.t0.values - expected_decoded_t0) - self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + abs_diff = abs(actual.t0.values - expected_decoded_t0) + assert (abs_diff <= np.timedelta64(1, 's')).all() def test_write_store(self): - # Override method in DatasetIOTestCases - not applicable to dask + # Override method in DatasetIOBase - not applicable to dask pass def test_dataset_caching(self): @@ -2166,22 +2126,20 @@ def test_open_mfdataset(self): with create_tmp_file() as tmp2: original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 10)).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: - self.assertIsInstance(actual.foo.variable.data, da.Array) - self.assertEqual(actual.foo.variable.data.chunks, - ((5, 5),)) + with open_mfdataset([tmp1, tmp2]) as actual: + assert isinstance(actual.foo.variable.data, da.Array) + assert actual.foo.variable.data.chunks == \ + ((5, 5),) assert_identical(original, actual) - with open_mfdataset([tmp1, tmp2], chunks={'x': 3}, - autoclose=self.autoclose) as actual: - self.assertEqual(actual.foo.variable.data.chunks, - ((3, 2, 3, 2),)) + with open_mfdataset([tmp1, tmp2], chunks={'x': 3}) as actual: + assert actual.foo.variable.data.chunks == \ + ((3, 2, 3, 2),) with raises_regex(IOError, 'no files to open'): - open_mfdataset('foo-bar-baz-*.nc', autoclose=self.autoclose) + open_mfdataset('foo-bar-baz-*.nc') with raises_regex(ValueError, 'wild-card'): - open_mfdataset('http://some/remote/uri', autoclose=self.autoclose) + open_mfdataset('http://some/remote/uri') @requires_pathlib def test_open_mfdataset_pathlib(self): @@ -2192,8 +2150,7 @@ def test_open_mfdataset_pathlib(self): tmp2 = Path(tmp2) original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 10)).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(original, actual) def 
test_attrs_mfdataset(self): @@ -2209,7 +2166,7 @@ def test_attrs_mfdataset(self): with open_mfdataset([tmp1, tmp2]) as actual: # presumes that attributes inherited from # first dataset loaded - self.assertEqual(actual.test1, ds1.test1) + assert actual.test1 == ds1.test1 # attributes from ds2 are not retained, e.g., with raises_regex(AttributeError, 'no attribute'): @@ -2224,8 +2181,7 @@ def preprocess(ds): return ds.assign_coords(z=0) expected = preprocess(original) - with open_mfdataset(tmp, preprocess=preprocess, - autoclose=self.autoclose) as actual: + with open_mfdataset(tmp, preprocess=preprocess) as actual: assert_identical(expected, actual) def test_save_mfdataset_roundtrip(self): @@ -2235,8 +2191,7 @@ def test_save_mfdataset_roundtrip(self): with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: save_mfdataset(datasets, [tmp1, tmp2]) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) def test_save_mfdataset_invalid(self): @@ -2262,15 +2217,14 @@ def test_save_mfdataset_pathlib_roundtrip(self): tmp1 = Path(tmp1) tmp2 = Path(tmp2) save_mfdataset(datasets, [tmp1, tmp2]) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) def test_open_and_do_math(self): original = Dataset({'foo': ('x', np.random.randn(10))}) with create_tmp_file() as tmp: original.to_netcdf(tmp) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: actual = 1.0 * ds assert_allclose(original, actual, decode_bytes=False) @@ -2280,8 +2234,7 @@ def test_open_mfdataset_concat_dim_none(self): data = Dataset({'x': 0}) data.to_netcdf(tmp1) Dataset({'x': np.nan}).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], concat_dim=None, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2], concat_dim=None) as actual: assert_identical(data, actual) def test_open_dataset(self): @@ -2289,13 +2242,13 @@ def test_open_dataset(self): with create_tmp_file() as tmp: original.to_netcdf(tmp) with open_dataset(tmp, chunks={'x': 5}) as actual: - self.assertIsInstance(actual.foo.variable.data, da.Array) - self.assertEqual(actual.foo.variable.data.chunks, ((5, 5),)) + assert isinstance(actual.foo.variable.data, da.Array) + assert actual.foo.variable.data.chunks == ((5, 5),) assert_identical(original, actual) with open_dataset(tmp, chunks=5) as actual: assert_identical(original, actual) with open_dataset(tmp) as actual: - self.assertIsInstance(actual.foo.variable.data, np.ndarray) + assert isinstance(actual.foo.variable.data, np.ndarray) assert_identical(original, actual) def test_open_single_dataset(self): @@ -2308,8 +2261,7 @@ def test_open_single_dataset(self): {'baz': [100]}) with create_tmp_file() as tmp: original.to_netcdf(tmp) - with open_mfdataset([tmp], concat_dim=dim, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp], concat_dim=dim) as actual: assert_identical(expected, actual) def test_dask_roundtrip(self): @@ -2328,65 +2280,46 @@ def test_deterministic_names(self): with create_tmp_file() as tmp: data = create_test_data() data.to_netcdf(tmp) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: original_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: repeat_names = dict((k, v.data.name) for k, v 
in ds.data_vars.items()) for var_name, dask_name in original_names.items(): - self.assertIn(var_name, dask_name) - self.assertEqual(dask_name[:13], 'open_dataset-') - self.assertEqual(original_names, repeat_names) + assert var_name in dask_name + assert dask_name[:13] == 'open_dataset-' + assert original_names == repeat_names def test_dataarray_compute(self): # Test DataArray.compute() on dask backend. - # The test for Dataset.compute() is already in DatasetIOTestCases; + # The test for Dataset.compute() is already in DatasetIOBase; # however dask is the only tested backend which supports DataArrays actual = DataArray([1, 2]).chunk() computed = actual.compute() - self.assertFalse(actual._in_memory) - self.assertTrue(computed._in_memory) + assert not actual._in_memory + assert computed._in_memory assert_allclose(actual, computed, decode_bytes=False) - def test_to_netcdf_compute_false_roundtrip(self): - from dask.delayed import Delayed - - original = create_test_data().chunk() - - with create_tmp_file() as tmp_file: - # dataset, path, **kwargs): - delayed_obj = self.save(original, tmp_file, compute=False) - assert isinstance(delayed_obj, Delayed) - delayed_obj.compute() - - with self.open(tmp_file) as actual: - assert_identical(original, actual) - def test_save_mfdataset_compute_false_roundtrip(self): from dask.delayed import Delayed original = Dataset({'foo': ('x', np.random.randn(10))}).chunk() datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))] - with create_tmp_file() as tmp1: - with create_tmp_file() as tmp2: + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp1: + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp2: delayed_obj = save_mfdataset(datasets, [tmp1, tmp2], engine=self.engine, compute=False) assert isinstance(delayed_obj, Delayed) delayed_obj.compute() - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) -class DaskTestAutocloseTrue(DaskTest): - autoclose = True - - @requires_scipy_or_netCDF4 @requires_pydap -class PydapTest(TestCase): +class TestPydap(object): def convert_to_pydap_dataset(self, original): from pydap.model import GridType, BaseType, DatasetType ds = DatasetType('bears', **original.attrs) @@ -2418,8 +2351,8 @@ def test_cmp_local_file(self): assert_equal(actual, expected) # global attributes should be global attributes on the dataset - self.assertNotIn('NC_GLOBAL', actual.attrs) - self.assertIn('history', actual.attrs) + assert 'NC_GLOBAL' not in actual.attrs + assert 'history' in actual.attrs # we don't check attributes exactly with assertDatasetIdentical() # because the test DAP server seems to insert some extra @@ -2427,8 +2360,7 @@ def test_cmp_local_file(self): assert actual.attrs.keys() == expected.attrs.keys() with self.create_datasets() as (actual, expected): - assert_equal( - actual.isel(l=2), expected.isel(l=2)) # noqa: E741 + assert_equal(actual.isel(l=2), expected.isel(l=2)) # noqa with self.create_datasets() as (actual, expected): assert_equal(actual.isel(i=0, j=-1), @@ -2467,7 +2399,7 @@ def test_dask(self): @network @requires_scipy_or_netCDF4 @requires_pydap -class PydapOnlineTest(PydapTest): +class TestPydapOnline(TestPydap): @contextlib.contextmanager def create_datasets(self, **kwargs): url = 'http://test.opendap.org/opendap/hyrax/data/nc/bears.nc' @@ -2488,15 +2420,14 @@ def test_session(self): @requires_scipy @requires_pynio -class PyNioTest(ScipyWriteTest, TestCase): +class 
TestPyNio(ScipyWriteBase): def test_write_store(self): # pynio is read-only for now pass @contextlib.contextmanager def open(self, path, **kwargs): - with open_dataset(path, engine='pynio', autoclose=self.autoclose, - **kwargs) as ds: + with open_dataset(path, engine='pynio', **kwargs) as ds: yield ds def save(self, dataset, path, **kwargs): @@ -2514,18 +2445,34 @@ def test_weakrefs(self): assert_identical(actual, expected) -class PyNioTestAutocloseTrue(PyNioTest): - autoclose = True +@requires_cfgrib +class TestCfGrib(object): + + def test_read(self): + expected = {'number': 2, 'time': 3, 'isobaricInhPa': 2, 'latitude': 3, + 'longitude': 4} + with open_example_dataset('example.grib', engine='cfgrib') as ds: + assert ds.dims == expected + assert list(ds.data_vars) == ['z', 't'] + assert ds['z'].min() == 12660. + + def test_read_filter_by_keys(self): + kwargs = {'filter_by_keys': {'shortName': 't'}} + expected = {'number': 2, 'time': 3, 'isobaricInhPa': 2, 'latitude': 3, + 'longitude': 4} + with open_example_dataset('example.grib', engine='cfgrib', + backend_kwargs=kwargs) as ds: + assert ds.dims == expected + assert list(ds.data_vars) == ['t'] + assert ds['t'].min() == 231. @requires_pseudonetcdf -class PseudoNetCDFFormatTest(TestCase): - autoclose = True +@pytest.mark.filterwarnings('ignore:IOAPI_ISPH is assumed to be 6370000') +class TestPseudoNetCDFFormat(object): def open(self, path, **kwargs): - return open_dataset(path, engine='pseudonetcdf', - autoclose=self.autoclose, - **kwargs) + return open_dataset(path, engine='pseudonetcdf', **kwargs) @contextlib.contextmanager def roundtrip(self, data, save_kwargs={}, open_kwargs={}, @@ -2542,7 +2489,6 @@ def test_ict_format(self): """ ictfile = open_example_dataset('example.ict', engine='pseudonetcdf', - autoclose=False, backend_kwargs={'format': 'ffi1001'}) stdattr = { 'fill_value': -9999.0, @@ -2640,7 +2586,6 @@ def test_ict_format_write(self): fmtkw = {'format': 'ffi1001'} expected = open_example_dataset('example.ict', engine='pseudonetcdf', - autoclose=False, backend_kwargs=fmtkw) with self.roundtrip(expected, save_kwargs=fmtkw, open_kwargs={'backend_kwargs': fmtkw}) as actual: @@ -2650,14 +2595,10 @@ def test_uamiv_format_read(self): """ Open a CAMx file and test data variables """ - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning, - message=('IOAPI_ISPH is assumed to be ' + - '6370000.; consistent with WRF')) - camxfile = open_example_dataset('example.uamiv', - engine='pseudonetcdf', - autoclose=True, - backend_kwargs={'format': 'uamiv'}) + + camxfile = open_example_dataset('example.uamiv', + engine='pseudonetcdf', + backend_kwargs={'format': 'uamiv'}) data = np.arange(20, dtype='f').reshape(1, 1, 4, 5) expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, dict(units='ppm', long_name='O3'.ljust(16), @@ -2679,17 +2620,13 @@ def test_uamiv_format_mfread(self): """ Open a CAMx file and test data variables """ - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning, - message=('IOAPI_ISPH is assumed to be ' + - '6370000.; consistent with WRF')) - camxfile = open_example_mfdataset( - ['example.uamiv', - 'example.uamiv'], - engine='pseudonetcdf', - autoclose=True, - concat_dim='TSTEP', - backend_kwargs={'format': 'uamiv'}) + + camxfile = open_example_mfdataset( + ['example.uamiv', + 'example.uamiv'], + engine='pseudonetcdf', + concat_dim='TSTEP', + backend_kwargs={'format': 'uamiv'}) data1 = np.arange(20, dtype='f').reshape(1, 1, 4, 5) data = 
np.concatenate([data1] * 2, axis=0) @@ -2701,30 +2638,28 @@ def test_uamiv_format_mfread(self): data1 = np.array(['2002-06-03'], 'datetime64[ns]') data = np.concatenate([data1] * 2, axis=0) - expected = xr.Variable(('TSTEP',), data, - dict(bounds='time_bounds', - long_name=('synthesized time coordinate ' + - 'from SDATE, STIME, STEP ' + - 'global attributes'))) + attrs = dict(bounds='time_bounds', + long_name=('synthesized time coordinate ' + + 'from SDATE, STIME, STEP ' + + 'global attributes')) + expected = xr.Variable(('TSTEP',), data, attrs) actual = camxfile.variables['time'] assert_allclose(expected, actual) camxfile.close() def test_uamiv_format_write(self): fmtkw = {'format': 'uamiv'} - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning, - message=('IOAPI_ISPH is assumed to be ' + - '6370000.; consistent with WRF')) - expected = open_example_dataset('example.uamiv', - engine='pseudonetcdf', - autoclose=False, - backend_kwargs=fmtkw) + + expected = open_example_dataset('example.uamiv', + engine='pseudonetcdf', + backend_kwargs=fmtkw) with self.roundtrip(expected, save_kwargs=fmtkw, open_kwargs={'backend_kwargs': fmtkw}) as actual: assert_identical(expected, actual) + expected.close() + def save(self, dataset, path, **save_kwargs): import PseudoNetCDF as pnc pncf = pnc.PseudoNetCDFFile() @@ -2789,7 +2724,7 @@ def create_tmp_geotiff(nx=4, ny=3, nz=3, @requires_rasterio -class TestRasterio(TestCase): +class TestRasterio(object): @requires_scipy_or_netCDF4 def test_serialization(self): @@ -2809,6 +2744,7 @@ def test_utm(self): assert isinstance(rioda.attrs['res'], tuple) assert isinstance(rioda.attrs['is_tiled'], np.uint8) assert isinstance(rioda.attrs['transform'], tuple) + assert len(rioda.attrs['transform']) == 6 np.testing.assert_array_equal(rioda.attrs['nodatavals'], [np.NaN, np.NaN, np.NaN]) @@ -2830,9 +2766,11 @@ def test_non_rectilinear(self): assert isinstance(rioda.attrs['res'], tuple) assert isinstance(rioda.attrs['is_tiled'], np.uint8) assert isinstance(rioda.attrs['transform'], tuple) + assert len(rioda.attrs['transform']) == 6 # See if a warning is raised if we force it - with self.assertWarns("transformation isn't rectilinear"): + with pytest.warns(Warning, + match="transformation isn't rectilinear"): with xr.open_rasterio(tmp_file, parse_coordinates=True) as rioda: assert 'x' not in rioda.coords @@ -2849,6 +2787,7 @@ def test_platecarree(self): assert isinstance(rioda.attrs['res'], tuple) assert isinstance(rioda.attrs['is_tiled'], np.uint8) assert isinstance(rioda.attrs['transform'], tuple) + assert len(rioda.attrs['transform']) == 6 np.testing.assert_array_equal(rioda.attrs['nodatavals'], [-9765.]) @@ -2886,6 +2825,7 @@ def test_notransform(self): assert isinstance(rioda.attrs['res'], tuple) assert isinstance(rioda.attrs['is_tiled'], np.uint8) assert isinstance(rioda.attrs['transform'], tuple) + assert len(rioda.attrs['transform']) == 6 def test_indexing(self): with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.], @@ -2921,6 +2861,10 @@ def test_indexing(self): assert_allclose(expected.isel(**ind), actual.isel(**ind)) assert not actual.variable._in_memory + ind = {'band': 0, 'x': np.array([0, 0]), 'y': np.array([1, 1, 1])} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + # minus-stepped slice ind = {'band': np.array([2, 1, 0]), 'x': slice(-1, None, -1), 'y': 0} @@ -2932,12 +2876,16 @@ def test_indexing(self): assert_allclose(expected.isel(**ind), actual.isel(**ind)) 
assert not actual.variable._in_memory - # None is selected + # empty selection ind = {'band': np.array([2, 1, 0]), 'x': 1, 'y': slice(2, 2, 1)} assert_allclose(expected.isel(**ind), actual.isel(**ind)) assert not actual.variable._in_memory + ind = {'band': slice(0, 0), 'x': 1, 'y': 2} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + # vectorized indexer ind = {'band': DataArray([2, 1, 0], dims='a'), 'x': DataArray([1, 0, 0], dims='a'), @@ -3013,7 +2961,7 @@ def test_chunks(self): with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual: import dask.array as da - self.assertIsInstance(actual.data, da.Array) + assert isinstance(actual.data, da.Array) assert 'open_rasterio' in actual.data.name # do some arithmetic @@ -3076,6 +3024,7 @@ def test_ENVI_tags(self): assert isinstance(rioda.attrs['res'], tuple) assert isinstance(rioda.attrs['is_tiled'], np.uint8) assert isinstance(rioda.attrs['transform'], tuple) + assert len(rioda.attrs['transform']) == 6 # from ENVI tags assert isinstance(rioda.attrs['description'], basestring) assert isinstance(rioda.attrs['map_info'], basestring) @@ -3093,7 +3042,7 @@ def test_no_mftime(self): with mock.patch('os.path.getmtime', side_effect=OSError): with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual: import dask.array as da - self.assertIsInstance(actual.data, da.Array) + assert isinstance(actual.data, da.Array) assert_allclose(actual, expected) @network @@ -3106,10 +3055,10 @@ def test_http_url(self): # make sure chunking works with xr.open_rasterio(url, chunks=(1, 256, 256)) as actual: import dask.array as da - self.assertIsInstance(actual.data, da.Array) + assert isinstance(actual.data, da.Array) -class TestEncodingInvalid(TestCase): +class TestEncodingInvalid(object): def test_extract_nc4_variable_encoding(self): var = xr.Variable(('x',), [1, 2, 3], {}, {'foo': 'bar'}) @@ -3118,12 +3067,12 @@ def test_extract_nc4_variable_encoding(self): var = xr.Variable(('x',), [1, 2, 3], {}, {'chunking': (2, 1)}) encoding = _extract_nc4_variable_encoding(var) - self.assertEqual({}, encoding) + assert {} == encoding # regression test var = xr.Variable(('x',), [1, 2, 3], {}, {'shuffle': True}) encoding = _extract_nc4_variable_encoding(var, raise_on_invalid=True) - self.assertEqual({'shuffle': True}, encoding) + assert {'shuffle': True} == encoding def test_extract_h5nc_encoding(self): # not supported with h5netcdf (yet) @@ -3138,7 +3087,7 @@ class MiscObject: @requires_netCDF4 -class TestValidateAttrs(TestCase): +class TestValidateAttrs(object): def test_validating_attrs(self): def new_dataset(): return Dataset({'data': ('y', np.arange(10.0))}, @@ -3238,7 +3187,7 @@ def new_dataset_and_coord_attrs(): @requires_scipy_or_netCDF4 -class TestDataArrayToNetCDF(TestCase): +class TestDataArrayToNetCDF(object): def test_dataarray_to_netcdf_no_name(self): original_da = DataArray(np.arange(12).reshape((3, 4))) @@ -3299,27 +3248,10 @@ def test_dataarray_to_netcdf_no_name_pathlib(self): assert_identical(original_da, loaded_da) -def test_pickle_reconstructor(): - - lines = ['foo bar spam eggs'] - - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp: - with open(tmp, 'w') as f: - f.writelines(lines) - - obj = PickleByReconstructionWrapper(open, tmp) - - assert obj.value.readlines() == lines - - p_obj = pickle.dumps(obj) - obj.value.close() # for windows - obj2 = pickle.loads(p_obj) - - assert obj2.value.readlines() == lines - - # roundtrip again to make sure we can fully restore the state - p_obj2 = 
pickle.dumps(obj2) - obj2.value.close() # for windows - obj3 = pickle.loads(p_obj2) - - assert obj3.value.readlines() == lines +@requires_scipy_or_netCDF4 +def test_no_warning_from_dask_effective_get(): + with create_tmp_file() as tmpfile: + with pytest.warns(None) as record: + ds = Dataset() + ds.to_netcdf(tmpfile) + assert len(record) == 0 diff --git a/xarray/tests/test_backends_api.py b/xarray/tests/test_backends_api.py new file mode 100644 index 00000000000..ed49dd721d2 --- /dev/null +++ b/xarray/tests/test_backends_api.py @@ -0,0 +1,22 @@ + +import pytest + +from xarray.backends.api import _get_default_engine +from . import requires_netCDF4, requires_scipy + + +@requires_netCDF4 +@requires_scipy +def test__get_default_engine(): + engine_remote = _get_default_engine('http://example.org/test.nc', + allow_remote=True) + assert engine_remote == 'netcdf4' + + engine_gz = _get_default_engine('/example.gz') + assert engine_gz == 'scipy' + + with pytest.raises(ValueError): + _get_default_engine('/example.grib') + + engine_default = _get_default_engine('/example') + assert engine_default == 'netcdf4' diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py new file mode 100644 index 00000000000..591c981cd45 --- /dev/null +++ b/xarray/tests/test_backends_file_manager.py @@ -0,0 +1,114 @@ +import pickle +import threading +try: + from unittest import mock +except ImportError: + import mock # noqa: F401 + +import pytest + +from xarray.backends.file_manager import CachingFileManager +from xarray.backends.lru_cache import LRUCache + + +@pytest.fixture(params=[1, 2, 3, None]) +def file_cache(request): + maxsize = request.param + if maxsize is None: + yield {} + else: + yield LRUCache(maxsize) + + +def test_file_manager_mock_write(file_cache): + mock_file = mock.Mock() + opener = mock.Mock(spec=open, return_value=mock_file) + lock = mock.MagicMock(spec=threading.Lock()) + + manager = CachingFileManager( + opener, 'filename', lock=lock, cache=file_cache) + f = manager.acquire() + f.write('contents') + manager.close() + + assert not file_cache + opener.assert_called_once_with('filename') + mock_file.write.assert_called_once_with('contents') + mock_file.close.assert_called_once_with() + lock.__enter__.assert_has_calls([mock.call(), mock.call()]) + + +def test_file_manager_write_consecutive(tmpdir, file_cache): + path1 = str(tmpdir.join('testing1.txt')) + path2 = str(tmpdir.join('testing2.txt')) + manager1 = CachingFileManager(open, path1, mode='w', cache=file_cache) + manager2 = CachingFileManager(open, path2, mode='w', cache=file_cache) + f1a = manager1.acquire() + f1a.write('foo') + f1a.flush() + f2 = manager2.acquire() + f2.write('bar') + f2.flush() + f1b = manager1.acquire() + f1b.write('baz') + assert (getattr(file_cache, 'maxsize', float('inf')) > 1) == (f1a is f1b) + manager1.close() + manager2.close() + + with open(path1, 'r') as f: + assert f.read() == 'foobaz' + with open(path2, 'r') as f: + assert f.read() == 'bar' + + +def test_file_manager_write_concurrent(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + manager = CachingFileManager(open, path, mode='w', cache=file_cache) + f1 = manager.acquire() + f2 = manager.acquire() + f3 = manager.acquire() + assert f1 is f2 + assert f2 is f3 + f1.write('foo') + f1.flush() + f2.write('bar') + f2.flush() + f3.write('baz') + f3.flush() + manager.close() + + with open(path, 'r') as f: + assert f.read() == 'foobarbaz' + + +def test_file_manager_write_pickle(tmpdir, file_cache): + path = 
str(tmpdir.join('testing.txt')) + manager = CachingFileManager(open, path, mode='w', cache=file_cache) + f = manager.acquire() + f.write('foo') + f.flush() + manager2 = pickle.loads(pickle.dumps(manager)) + f2 = manager2.acquire() + f2.write('bar') + manager2.close() + manager.close() + + with open(path, 'r') as f: + assert f.read() == 'foobar' + + +def test_file_manager_read(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + + with open(path, 'w') as f: + f.write('foobar') + + manager = CachingFileManager(open, path, cache=file_cache) + f = manager.acquire() + assert f.read() == 'foobar' + manager.close() + + +def test_file_manager_invalid_kwargs(): + with pytest.raises(TypeError): + CachingFileManager(open, 'dummy', mode='w', invalid=True) diff --git a/xarray/tests/test_backends_locks.py b/xarray/tests/test_backends_locks.py new file mode 100644 index 00000000000..5f83321802e --- /dev/null +++ b/xarray/tests/test_backends_locks.py @@ -0,0 +1,13 @@ +import threading + +from xarray.backends import locks + + +def test_threaded_lock(): + lock1 = locks._get_threaded_lock('foo') + assert isinstance(lock1, type(threading.Lock())) + lock2 = locks._get_threaded_lock('foo') + assert lock1 is lock2 + + lock3 = locks._get_threaded_lock('bar') + assert lock1 is not lock3 diff --git a/xarray/tests/test_backends_lru_cache.py b/xarray/tests/test_backends_lru_cache.py new file mode 100644 index 00000000000..03eb6dcf208 --- /dev/null +++ b/xarray/tests/test_backends_lru_cache.py @@ -0,0 +1,91 @@ +try: + from unittest import mock +except ImportError: + import mock # noqa: F401 + +import pytest + +from xarray.backends.lru_cache import LRUCache + + +def test_simple(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + + assert cache['x'] == 1 + assert cache['y'] == 2 + assert len(cache) == 2 + assert dict(cache) == {'x': 1, 'y': 2} + assert list(cache.keys()) == ['x', 'y'] + assert list(cache.items()) == [('x', 1), ('y', 2)] + + cache['z'] = 3 + assert len(cache) == 2 + assert list(cache.items()) == [('y', 2), ('z', 3)] + + +def test_trivial(): + cache = LRUCache(maxsize=0) + cache['x'] = 1 + assert len(cache) == 0 + + +def test_invalid(): + with pytest.raises(TypeError): + LRUCache(maxsize=None) + with pytest.raises(ValueError): + LRUCache(maxsize=-1) + + +def test_update_priority(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + assert list(cache) == ['x', 'y'] + assert 'x' in cache # contains + assert list(cache) == ['y', 'x'] + assert cache['y'] == 2 # getitem + assert list(cache) == ['x', 'y'] + cache['x'] = 3 # setitem + assert list(cache.items()) == [('y', 2), ('x', 3)] + + +def test_del(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + del cache['x'] + assert dict(cache) == {'y': 2} + + +def test_on_evict(): + on_evict = mock.Mock() + cache = LRUCache(maxsize=1, on_evict=on_evict) + cache['x'] = 1 + cache['y'] = 2 + on_evict.assert_called_once_with('x', 1) + + +def test_on_evict_trivial(): + on_evict = mock.Mock() + cache = LRUCache(maxsize=0, on_evict=on_evict) + cache['x'] = 1 + on_evict.assert_called_once_with('x', 1) + + +def test_resize(): + cache = LRUCache(maxsize=2) + assert cache.maxsize == 2 + cache['w'] = 0 + cache['x'] = 1 + cache['y'] = 2 + assert list(cache.items()) == [('x', 1), ('y', 2)] + cache.maxsize = 10 + cache['z'] = 3 + assert list(cache.items()) == [('x', 1), ('y', 2), ('z', 3)] + cache.maxsize = 1 + assert list(cache.items()) == [('z', 3)] + + with pytest.raises(ValueError): + cache.maxsize = -1 diff --git 
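[Editor's note: illustrative sketch, not part of the patch.] The new file-manager and LRU-cache tests above pin down a small internal API: CachingFileManager opens (or re-uses) a file on acquire(), and an LRUCache with an on_evict hook bounds how many handles stay cached at once. A sketch of how the two fit together, under the assumption that the evict hook simply closes whatever falls out of the cache; 'scratch.txt' is a hypothetical path:

    from xarray.backends.lru_cache import LRUCache
    from xarray.backends.file_manager import CachingFileManager

    # files evicted from the cache are closed via the hook
    cache = LRUCache(maxsize=2, on_evict=lambda key, value: value.close())

    manager = CachingFileManager(open, 'scratch.txt', mode='w', cache=cache)
    f = manager.acquire()        # opens the file via the opener and caches the handle
    f.write('hello')
    manager.close()              # closes the file and drops it from the cache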
a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py new file mode 100644 index 00000000000..7acd764cab3 --- /dev/null +++ b/xarray/tests/test_cftime_offsets.py @@ -0,0 +1,799 @@ +from itertools import product + +import numpy as np +import pytest + +from xarray import CFTimeIndex +from xarray.coding.cftime_offsets import ( + _MONTH_ABBREVIATIONS, BaseCFTimeOffset, Day, Hour, Minute, MonthBegin, + MonthEnd, Second, YearBegin, YearEnd, _days_in_month, cftime_range, + get_date_type, to_cftime_datetime, to_offset) + +cftime = pytest.importorskip('cftime') + + +_CFTIME_CALENDARS = ['365_day', '360_day', 'julian', 'all_leap', + '366_day', 'gregorian', 'proleptic_gregorian', 'standard'] + + +def _id_func(param): + """Called on each parameter passed to pytest.mark.parametrize""" + return str(param) + + +@pytest.fixture(params=_CFTIME_CALENDARS) +def calendar(request): + return request.param + + +@pytest.mark.parametrize( + ('offset', 'expected_n'), + [(BaseCFTimeOffset(), 1), + (YearBegin(), 1), + (YearEnd(), 1), + (BaseCFTimeOffset(n=2), 2), + (YearBegin(n=2), 2), + (YearEnd(n=2), 2)], + ids=_id_func +) +def test_cftime_offset_constructor_valid_n(offset, expected_n): + assert offset.n == expected_n + + +@pytest.mark.parametrize( + ('offset', 'invalid_n'), + [(BaseCFTimeOffset, 1.5), + (YearBegin, 1.5), + (YearEnd, 1.5)], + ids=_id_func +) +def test_cftime_offset_constructor_invalid_n(offset, invalid_n): + with pytest.raises(TypeError): + offset(n=invalid_n) + + +@pytest.mark.parametrize( + ('offset', 'expected_month'), + [(YearBegin(), 1), + (YearEnd(), 12), + (YearBegin(month=5), 5), + (YearEnd(month=5), 5)], + ids=_id_func +) +def test_year_offset_constructor_valid_month(offset, expected_month): + assert offset.month == expected_month + + +@pytest.mark.parametrize( + ('offset', 'invalid_month', 'exception'), + [(YearBegin, 0, ValueError), + (YearEnd, 0, ValueError), + (YearBegin, 13, ValueError,), + (YearEnd, 13, ValueError), + (YearBegin, 1.5, TypeError), + (YearEnd, 1.5, TypeError)], + ids=_id_func +) +def test_year_offset_constructor_invalid_month( + offset, invalid_month, exception): + with pytest.raises(exception): + offset(month=invalid_month) + + +@pytest.mark.parametrize( + ('offset', 'expected'), + [(BaseCFTimeOffset(), None), + (MonthBegin(), 'MS'), + (YearBegin(), 'AS-JAN')], + ids=_id_func +) +def test_rule_code(offset, expected): + assert offset.rule_code() == expected + + +@pytest.mark.parametrize( + ('offset', 'expected'), + [(BaseCFTimeOffset(), ''), + (YearBegin(), '')], + ids=_id_func +) +def test_str_and_repr(offset, expected): + assert str(offset) == expected + assert repr(offset) == expected + + +@pytest.mark.parametrize( + 'offset', + [BaseCFTimeOffset(), MonthBegin(), YearBegin()], + ids=_id_func +) +def test_to_offset_offset_input(offset): + assert to_offset(offset) == offset + + +@pytest.mark.parametrize( + ('freq', 'expected'), + [('M', MonthEnd()), + ('2M', MonthEnd(n=2)), + ('MS', MonthBegin()), + ('2MS', MonthBegin(n=2)), + ('D', Day()), + ('2D', Day(n=2)), + ('H', Hour()), + ('2H', Hour(n=2)), + ('T', Minute()), + ('2T', Minute(n=2)), + ('min', Minute()), + ('2min', Minute(n=2)), + ('S', Second()), + ('2S', Second(n=2))], + ids=_id_func +) +def test_to_offset_sub_annual(freq, expected): + assert to_offset(freq) == expected + + +_ANNUAL_OFFSET_TYPES = { + 'A': YearEnd, + 'AS': YearBegin +} + + +@pytest.mark.parametrize(('month_int', 'month_label'), + list(_MONTH_ABBREVIATIONS.items()) + [('', '')]) +@pytest.mark.parametrize('multiple', [None, 
2]) +@pytest.mark.parametrize('offset_str', ['AS', 'A']) +def test_to_offset_annual(month_label, month_int, multiple, offset_str): + freq = offset_str + offset_type = _ANNUAL_OFFSET_TYPES[offset_str] + if month_label: + freq = '-'.join([freq, month_label]) + if multiple: + freq = '{}'.format(multiple) + freq + result = to_offset(freq) + + if multiple and month_int: + expected = offset_type(n=multiple, month=month_int) + elif multiple: + expected = offset_type(n=multiple) + elif month_int: + expected = offset_type(month=month_int) + else: + expected = offset_type() + assert result == expected + + +@pytest.mark.parametrize('freq', ['Z', '7min2', 'AM', 'M-', 'AS-', '1H1min']) +def test_invalid_to_offset_str(freq): + with pytest.raises(ValueError): + to_offset(freq) + + +@pytest.mark.parametrize( + ('argument', 'expected_date_args'), + [('2000-01-01', (2000, 1, 1)), + ((2000, 1, 1), (2000, 1, 1))], + ids=_id_func +) +def test_to_cftime_datetime(calendar, argument, expected_date_args): + date_type = get_date_type(calendar) + expected = date_type(*expected_date_args) + if isinstance(argument, tuple): + argument = date_type(*argument) + result = to_cftime_datetime(argument, calendar=calendar) + assert result == expected + + +def test_to_cftime_datetime_error_no_calendar(): + with pytest.raises(ValueError): + to_cftime_datetime('2000') + + +def test_to_cftime_datetime_error_type_error(): + with pytest.raises(TypeError): + to_cftime_datetime(1) + + +_EQ_TESTS_A = [ + BaseCFTimeOffset(), YearBegin(), YearEnd(), YearBegin(month=2), + YearEnd(month=2), MonthBegin(), MonthEnd(), Day(), Hour(), Minute(), + Second() +] +_EQ_TESTS_B = [ + BaseCFTimeOffset(n=2), YearBegin(n=2), YearEnd(n=2), + YearBegin(n=2, month=2), YearEnd(n=2, month=2), MonthBegin(n=2), + MonthEnd(n=2), Day(n=2), Hour(n=2), Minute(n=2), Second(n=2) +] + + +@pytest.mark.parametrize( + ('a', 'b'), product(_EQ_TESTS_A, _EQ_TESTS_B), ids=_id_func +) +def test_neq(a, b): + assert a != b + + +_EQ_TESTS_B_COPY = [ + BaseCFTimeOffset(n=2), YearBegin(n=2), YearEnd(n=2), + YearBegin(n=2, month=2), YearEnd(n=2, month=2), MonthBegin(n=2), + MonthEnd(n=2), Day(n=2), Hour(n=2), Minute(n=2), Second(n=2) +] + + +@pytest.mark.parametrize( + ('a', 'b'), zip(_EQ_TESTS_B, _EQ_TESTS_B_COPY), ids=_id_func +) +def test_eq(a, b): + assert a == b + + +_MUL_TESTS = [ + (BaseCFTimeOffset(), BaseCFTimeOffset(n=3)), + (YearEnd(), YearEnd(n=3)), + (YearBegin(), YearBegin(n=3)), + (MonthEnd(), MonthEnd(n=3)), + (MonthBegin(), MonthBegin(n=3)), + (Day(), Day(n=3)), + (Hour(), Hour(n=3)), + (Minute(), Minute(n=3)), + (Second(), Second(n=3)) +] + + +@pytest.mark.parametrize(('offset', 'expected'), _MUL_TESTS, ids=_id_func) +def test_mul(offset, expected): + assert offset * 3 == expected + + +@pytest.mark.parametrize(('offset', 'expected'), _MUL_TESTS, ids=_id_func) +def test_rmul(offset, expected): + assert 3 * offset == expected + + +@pytest.mark.parametrize( + ('offset', 'expected'), + [(BaseCFTimeOffset(), BaseCFTimeOffset(n=-1)), + (YearEnd(), YearEnd(n=-1)), + (YearBegin(), YearBegin(n=-1)), + (MonthEnd(), MonthEnd(n=-1)), + (MonthBegin(), MonthBegin(n=-1)), + (Day(), Day(n=-1)), + (Hour(), Hour(n=-1)), + (Minute(), Minute(n=-1)), + (Second(), Second(n=-1))], + ids=_id_func) +def test_neg(offset, expected): + assert -offset == expected + + +_ADD_TESTS = [ + (Day(n=2), (1, 1, 3)), + (Hour(n=2), (1, 1, 1, 2)), + (Minute(n=2), (1, 1, 1, 0, 2)), + (Second(n=2), (1, 1, 1, 0, 0, 2)) +] + + +@pytest.mark.parametrize( + ('offset', 'expected_date_args'), + _ADD_TESTS, + 
ids=_id_func +) +def test_add_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + expected = date_type(*expected_date_args) + result = offset + initial + assert result == expected + + +@pytest.mark.parametrize( + ('offset', 'expected_date_args'), + _ADD_TESTS, + ids=_id_func +) +def test_radd_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + expected = date_type(*expected_date_args) + result = initial + offset + assert result == expected + + +@pytest.mark.parametrize( + ('offset', 'expected_date_args'), + [(Day(n=2), (1, 1, 1)), + (Hour(n=2), (1, 1, 2, 22)), + (Minute(n=2), (1, 1, 2, 23, 58)), + (Second(n=2), (1, 1, 2, 23, 59, 58))], + ids=_id_func +) +def test_rsub_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 3) + expected = date_type(*expected_date_args) + result = initial - offset + assert result == expected + + +@pytest.mark.parametrize('offset', _EQ_TESTS_A, ids=_id_func) +def test_sub_error(offset, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + with pytest.raises(TypeError): + offset - initial + + +@pytest.mark.parametrize( + ('a', 'b'), + zip(_EQ_TESTS_A, _EQ_TESTS_B), + ids=_id_func +) +def test_minus_offset(a, b): + result = b - a + expected = a + assert result == expected + + +@pytest.mark.parametrize( + ('a', 'b'), + list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B)) + + [(YearEnd(month=1), YearEnd(month=2))], + ids=_id_func +) +def test_minus_offset_error(a, b): + with pytest.raises(TypeError): + b - a + + +def test_days_in_month_non_december(calendar): + date_type = get_date_type(calendar) + reference = date_type(1, 4, 1) + assert _days_in_month(reference) == 30 + + +def test_days_in_month_december(calendar): + if calendar == '360_day': + expected = 30 + else: + expected = 31 + date_type = get_date_type(calendar) + reference = date_type(1, 12, 5) + assert _days_in_month(reference) == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_date_args'), + [((1, 1, 1), MonthBegin(), (1, 2, 1)), + ((1, 1, 1), MonthBegin(n=2), (1, 3, 1)), + ((1, 1, 7), MonthBegin(), (1, 2, 1)), + ((1, 1, 7), MonthBegin(n=2), (1, 3, 1)), + ((1, 3, 1), MonthBegin(n=-1), (1, 2, 1)), + ((1, 3, 1), MonthBegin(n=-2), (1, 1, 1)), + ((1, 3, 3), MonthBegin(n=-1), (1, 3, 1)), + ((1, 3, 3), MonthBegin(n=-2), (1, 2, 1)), + ((1, 2, 1), MonthBegin(n=14), (2, 4, 1)), + ((2, 4, 1), MonthBegin(n=-14), (1, 2, 1)), + ((1, 1, 1, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), + ((1, 1, 3, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), + ((1, 1, 3, 5, 5, 5, 5), MonthBegin(n=-1), (1, 1, 1, 5, 5, 5, 5))], + ids=_id_func +) +def test_add_month_begin( + calendar, initial_date_args, offset, expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 1, 1), MonthEnd(), (1, 1), ()), + ((1, 1, 1), MonthEnd(n=2), (1, 2), ()), + ((1, 3, 1), MonthEnd(n=-1), (1, 2), ()), + ((1, 3, 1), MonthEnd(n=-2), (1, 1), ()), + ((1, 2, 1), MonthEnd(n=14), (2, 3), ()), + ((2, 4, 1), MonthEnd(n=-14), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), MonthEnd(), (1, 1), (5, 5, 5, 5)), + ((1, 2, 1, 5, 5, 5, 5), MonthEnd(n=-1), 
(1, 1), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_month_end( + calendar, initial_date_args, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_year_month', 'initial_sub_day', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 1), (), MonthEnd(), (1, 2), ()), + ((1, 1), (), MonthEnd(n=2), (1, 3), ()), + ((1, 3), (), MonthEnd(n=-1), (1, 2), ()), + ((1, 3), (), MonthEnd(n=-2), (1, 1), ()), + ((1, 2), (), MonthEnd(n=14), (2, 4), ()), + ((2, 4), (), MonthEnd(n=-14), (1, 2), ()), + ((1, 1), (5, 5, 5, 5), MonthEnd(), (1, 2), (5, 5, 5, 5)), + ((1, 2), (5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_month_end_onOffset( + calendar, initial_year_month, initial_sub_day, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + reference_args = initial_year_month + (1,) + reference = date_type(*reference_args) + initial_date_args = (initial_year_month + (_days_in_month(reference),) + + initial_sub_day) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_date_args'), + [((1, 1, 1), YearBegin(), (2, 1, 1)), + ((1, 1, 1), YearBegin(n=2), (3, 1, 1)), + ((1, 1, 1), YearBegin(month=2), (1, 2, 1)), + ((1, 1, 7), YearBegin(n=2), (3, 1, 1)), + ((2, 2, 1), YearBegin(n=-1), (2, 1, 1)), + ((1, 1, 2), YearBegin(n=-1), (1, 1, 1)), + ((1, 1, 1, 5, 5, 5, 5), YearBegin(), (2, 1, 1, 5, 5, 5, 5)), + ((2, 1, 1, 5, 5, 5, 5), YearBegin(n=-1), (1, 1, 1, 5, 5, 5, 5))], + ids=_id_func +) +def test_add_year_begin(calendar, initial_date_args, offset, + expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 1, 1), YearEnd(), (1, 12), ()), + ((1, 1, 1), YearEnd(n=2), (2, 12), ()), + ((1, 1, 1), YearEnd(month=1), (1, 1), ()), + ((2, 3, 1), YearEnd(n=-1), (1, 12), ()), + ((1, 3, 1), YearEnd(n=-1, month=2), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), YearEnd(), (1, 12), (5, 5, 5, 5)), + ((1, 1, 1, 5, 5, 5, 5), YearEnd(n=2), (2, 12), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_year_end( + calendar, initial_date_args, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + 
(_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_year_month', 'initial_sub_day', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 12), (), YearEnd(), (2, 12), ()), + ((1, 12), (), YearEnd(n=2), (3, 12), ()), + ((2, 12), (), YearEnd(n=-1), (1, 12), ()), + ((3, 12), (), YearEnd(n=-2), (1, 12), ()), + ((1, 1), (), YearEnd(month=2), (1, 2), ()), + ((1, 12), (5, 5, 5, 5), YearEnd(), (2, 12), (5, 5, 5, 5)), + ((2, 12), (5, 5, 5, 5), YearEnd(n=-1), (1, 12), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_year_end_onOffset( + calendar, initial_year_month, initial_sub_day, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + reference_args = initial_year_month + (1,) + reference = date_type(*reference_args) + initial_date_args = (initial_year_month + (_days_in_month(reference),) + + initial_sub_day) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +# Note for all sub-monthly offsets, pandas always returns True for onOffset +@pytest.mark.parametrize( + ('date_args', 'offset', 'expected'), + [((1, 1, 1), MonthBegin(), True), + ((1, 1, 1, 1), MonthBegin(), True), + ((1, 1, 5), MonthBegin(), False), + ((1, 1, 5), MonthEnd(), False), + ((1, 1, 1), YearBegin(), True), + ((1, 1, 1, 1), YearBegin(), True), + ((1, 1, 5), YearBegin(), False), + ((1, 12, 1), YearEnd(), False), + ((1, 1, 1), Day(), True), + ((1, 1, 1, 1), Day(), True), + ((1, 1, 1), Hour(), True), + ((1, 1, 1), Minute(), True), + ((1, 1, 1), Second(), True)], + ids=_id_func +) +def test_onOffset(calendar, date_args, offset, expected): + date_type = get_date_type(calendar) + date = date_type(*date_args) + result = offset.onOffset(date) + assert result == expected + + +@pytest.mark.parametrize( + ('year_month_args', 'sub_day_args', 'offset'), + [((1, 1), (), MonthEnd()), + ((1, 1), (1,), MonthEnd()), + ((1, 12), (), YearEnd()), + ((1, 1), (), YearEnd(month=1))], + ids=_id_func +) +def test_onOffset_month_or_year_end( + calendar, year_month_args, sub_day_args, offset): + date_type = get_date_type(calendar) + reference_args = year_month_args + (1,) + reference = date_type(*reference_args) + date_args = year_month_args + (_days_in_month(reference),) + sub_day_args + date = date_type(*date_args) + result = offset.onOffset(date) + assert result + + +@pytest.mark.parametrize( + ('offset', 'initial_date_args', 'partial_expected_date_args'), + [(YearBegin(), (1, 3, 1), (2, 1)), + (YearBegin(), (1, 1, 1), (1, 1)), + (YearBegin(n=2), (1, 3, 1), (2, 1)), + (YearBegin(n=2, month=2), (1, 3, 1), (2, 2)), + (YearEnd(), (1, 3, 1), (1, 12)), + (YearEnd(n=2), (1, 3, 1), (1, 12)), + (YearEnd(n=2, month=2), (1, 3, 1), (2, 2)), + (YearEnd(n=2, month=4), (1, 4, 30), (1, 4)), + (MonthBegin(), (1, 3, 2), (1, 4)), + (MonthBegin(), (1, 3, 1), (1, 3)), + (MonthBegin(n=2), (1, 3, 2), (1, 4)), + (MonthEnd(), (1, 3, 2), (1, 3)), + (MonthEnd(), (1, 4, 30), (1, 4)), + (MonthEnd(n=2), (1, 3, 2), (1, 3)), + (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), + (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), + (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), + (Second(), (1, 
3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1))], + ids=_id_func +) +def test_rollforward(calendar, offset, initial_date_args, + partial_expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + if isinstance(offset, (MonthBegin, YearBegin)): + expected_date_args = partial_expected_date_args + (1,) + elif isinstance(offset, (MonthEnd, YearEnd)): + reference_args = partial_expected_date_args + (1,) + reference = date_type(*reference_args) + expected_date_args = (partial_expected_date_args + + (_days_in_month(reference),)) + else: + expected_date_args = partial_expected_date_args + expected = date_type(*expected_date_args) + result = offset.rollforward(initial) + assert result == expected + + +@pytest.mark.parametrize( + ('offset', 'initial_date_args', 'partial_expected_date_args'), + [(YearBegin(), (1, 3, 1), (1, 1)), + (YearBegin(n=2), (1, 3, 1), (1, 1)), + (YearBegin(n=2, month=2), (1, 3, 1), (1, 2)), + (YearBegin(), (1, 1, 1), (1, 1)), + (YearBegin(n=2, month=2), (1, 2, 1), (1, 2)), + (YearEnd(), (2, 3, 1), (1, 12)), + (YearEnd(n=2), (2, 3, 1), (1, 12)), + (YearEnd(n=2, month=2), (2, 3, 1), (2, 2)), + (YearEnd(month=4), (1, 4, 30), (1, 4)), + (MonthBegin(), (1, 3, 2), (1, 3)), + (MonthBegin(n=2), (1, 3, 2), (1, 3)), + (MonthBegin(), (1, 3, 1), (1, 3)), + (MonthEnd(), (1, 3, 2), (1, 2)), + (MonthEnd(n=2), (1, 3, 2), (1, 2)), + (MonthEnd(), (1, 4, 30), (1, 4)), + (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), + (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), + (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), + (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1))], + ids=_id_func +) +def test_rollback(calendar, offset, initial_date_args, + partial_expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + if isinstance(offset, (MonthBegin, YearBegin)): + expected_date_args = partial_expected_date_args + (1,) + elif isinstance(offset, (MonthEnd, YearEnd)): + reference_args = partial_expected_date_args + (1,) + reference = date_type(*reference_args) + expected_date_args = (partial_expected_date_args + + (_days_in_month(reference),)) + else: + expected_date_args = partial_expected_date_args + expected = date_type(*expected_date_args) + result = offset.rollback(initial) + assert result == expected + + +_CFTIME_RANGE_TESTS = [ + ('0001-01-01', '0001-01-04', None, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-01', '0001-01-04', None, 'D', 'left', False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3)]), + ('0001-01-01', '0001-01-04', None, 'D', 'right', False, + [(1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-01T01:00:00', '0001-01-04', None, 'D', None, False, + [(1, 1, 1, 1), (1, 1, 2, 1), (1, 1, 3, 1)]), + ('0001-01-01T01:00:00', '0001-01-04', None, 'D', None, True, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-01', None, 4, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + (None, '0001-01-04', 4, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ((1, 1, 1), '0001-01-04', None, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ((1, 1, 1), (1, 1, 4), None, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-30', '0011-02-01', None, '3AS-JUN', None, False, + [(1, 6, 1), (4, 6, 1), (7, 6, 1), (10, 6, 1)]), + ('0001-01-04', '0001-01-01', None, 'D', None, False, + []), + ('0010', None, 4, YearBegin(n=-2), None, False, + [(10, 1, 1), (8, 1, 1), (6, 1, 1), (4, 1, 1)]), + ('0001-01-01', '0001-01-04', 4, None, 
None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]) +] + + +@pytest.mark.parametrize( + ('start', 'end', 'periods', 'freq', 'closed', 'normalize', + 'expected_date_args'), + _CFTIME_RANGE_TESTS, ids=_id_func +) +def test_cftime_range( + start, end, periods, freq, closed, normalize, calendar, + expected_date_args): + date_type = get_date_type(calendar) + expected_dates = [date_type(*args) for args in expected_date_args] + + if isinstance(start, tuple): + start = date_type(*start) + if isinstance(end, tuple): + end = date_type(*end) + + result = cftime_range( + start=start, end=end, periods=periods, freq=freq, closed=closed, + normalize=normalize, calendar=calendar) + resulting_dates = result.values + + assert isinstance(result, CFTimeIndex) + + if freq is not None: + np.testing.assert_equal(resulting_dates, expected_dates) + else: + # If we create a linear range of dates using cftime.num2date + # we will not get exact round number dates. This is because + # datetime arithmetic in cftime is accurate approximately to + # 1 millisecond (see https://unidata.github.io/cftime/api.html). + deltas = resulting_dates - expected_dates + deltas = np.array([delta.total_seconds() for delta in deltas]) + assert np.max(np.abs(deltas)) < 0.001 + + +def test_cftime_range_name(): + result = cftime_range(start='2000', periods=4, name='foo') + assert result.name == 'foo' + + result = cftime_range(start='2000', periods=4) + assert result.name is None + + +@pytest.mark.parametrize( + ('start', 'end', 'periods', 'freq', 'closed'), + [(None, None, 5, 'A', None), + ('2000', None, None, 'A', None), + (None, '2000', None, 'A', None), + ('2000', '2001', None, None, None), + (None, None, None, None, None), + ('2000', '2001', None, 'A', 'up'), + ('2000', '2001', 5, 'A', None)] +) +def test_invalid_cftime_range_inputs(start, end, periods, freq, closed): + with pytest.raises(ValueError): + cftime_range(start, end, periods, freq, closed=closed) + + +_CALENDAR_SPECIFIC_MONTH_END_TESTS = [ + ('2M', 'noleap', + [(2, 28), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', 'all_leap', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', '360_day', + [(2, 30), (4, 30), (6, 30), (8, 30), (10, 30), (12, 30)]), + ('2M', 'standard', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', 'gregorian', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', 'julian', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]) +] + + +@pytest.mark.parametrize( + ('freq', 'calendar', 'expected_month_day'), + _CALENDAR_SPECIFIC_MONTH_END_TESTS, ids=_id_func +) +def test_calendar_specific_month_end(freq, calendar, expected_month_day): + year = 2000 # Use a leap-year to highlight calendar differences + result = cftime_range( + start='2000-02', end='2001', freq=freq, calendar=calendar).values + date_type = get_date_type(calendar) + expected = [date_type(year, *args) for args in expected_month_day] + np.testing.assert_equal(result, expected) + + +@pytest.mark.parametrize( + ('calendar', 'start', 'end', 'expected_number_of_days'), + [('noleap', '2000', '2001', 365), + ('all_leap', '2000', '2001', 366), + ('360_day', '2000', '2001', 360), + ('standard', '2000', '2001', 366), + ('gregorian', '2000', '2001', 366), + ('julian', '2000', '2001', 366), + ('noleap', '2001', '2002', 365), + ('all_leap', '2001', '2002', 366), + ('360_day', '2001', '2002', 360), + ('standard', '2001', '2002', 365), + ('gregorian', '2001', '2002', 365), + ('julian', '2001', '2002', 365)] +) +def 
test_calendar_year_length( + calendar, start, end, expected_number_of_days): + result = cftime_range(start, end, freq='D', closed='left', + calendar=calendar) + assert len(result) == expected_number_of_days diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 6f102b60b9d..5e710827ff8 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -1,18 +1,20 @@ from __future__ import absolute_import -import pytest +from datetime import timedelta +import numpy as np import pandas as pd -import xarray as xr +import pytest -from datetime import timedelta +import xarray as xr from xarray.coding.cftimeindex import ( - parse_iso8601, CFTimeIndex, assert_all_valid_date_type, - _parsed_string_to_bounds, _parse_iso8601_with_reso) + CFTimeIndex, _parse_array_of_cftime_strings, _parse_iso8601_with_reso, + _parsed_string_to_bounds, assert_all_valid_date_type, parse_iso8601) from xarray.tests import assert_array_equal, assert_identical -from . import has_cftime, has_cftime_or_netCDF4 -from .test_coding_times import _all_cftime_date_types +from . import has_cftime, has_cftime_or_netCDF4, requires_cftime +from .test_coding_times import (_all_cftime_date_types, _ALL_CALENDARS, + _NON_STANDARD_CALENDARS) def date_dict(year=None, month=None, day=None, @@ -121,22 +123,42 @@ def dec_days(date_type): return 31 +@pytest.fixture +def index_with_name(date_type): + dates = [date_type(1, 1, 1), date_type(1, 2, 1), + date_type(2, 1, 1), date_type(2, 2, 1)] + return CFTimeIndex(dates, name='foo') + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize( + ('name', 'expected_name'), + [('bar', 'bar'), + (None, 'foo')]) +def test_constructor_with_name(index_with_name, name, expected_name): + result = CFTimeIndex(index_with_name, name=name).name + assert result == expected_name + + @pytest.mark.skipif(not has_cftime, reason='cftime not installed') def test_assert_all_valid_date_type(date_type, index): import cftime if date_type is cftime.DatetimeNoLeap: - mixed_date_types = [date_type(1, 1, 1), - cftime.DatetimeAllLeap(1, 2, 1)] + mixed_date_types = np.array( + [date_type(1, 1, 1), + cftime.DatetimeAllLeap(1, 2, 1)]) else: - mixed_date_types = [date_type(1, 1, 1), - cftime.DatetimeNoLeap(1, 2, 1)] + mixed_date_types = np.array( + [date_type(1, 1, 1), + cftime.DatetimeNoLeap(1, 2, 1)]) with pytest.raises(TypeError): assert_all_valid_date_type(mixed_date_types) with pytest.raises(TypeError): - assert_all_valid_date_type([1, date_type(1, 1, 1)]) + assert_all_valid_date_type(np.array([1, date_type(1, 1, 1)])) - assert_all_valid_date_type([date_type(1, 1, 1), date_type(1, 2, 1)]) + assert_all_valid_date_type( + np.array([date_type(1, 1, 1), date_type(1, 2, 1)])) @pytest.mark.skipif(not has_cftime, reason='cftime not installed') @@ -339,7 +361,7 @@ def test_groupby(da): @pytest.mark.skipif(not has_cftime, reason='cftime not installed') def test_resample_error(da): - with pytest.raises(TypeError): + with pytest.raises(NotImplementedError, match='to_datetimeindex'): da.resample(time='Y') @@ -573,19 +595,187 @@ def test_indexing_in_dataframe_iloc(df, index): @pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('enable_cftimeindex', [False, True]) -def test_concat_cftimeindex(date_type, enable_cftimeindex): - with xr.set_options(enable_cftimeindex=enable_cftimeindex): - da1 = xr.DataArray( - [1., 2.], coords=[[date_type(1, 1, 1), date_type(1, 2, 1)]], - dims=['time']) - da2 = 
xr.DataArray( - [3., 4.], coords=[[date_type(1, 3, 1), date_type(1, 4, 1)]], - dims=['time']) - da = xr.concat([da1, da2], dim='time') - - if enable_cftimeindex and has_cftime: +def test_concat_cftimeindex(date_type): + da1 = xr.DataArray( + [1., 2.], coords=[[date_type(1, 1, 1), date_type(1, 2, 1)]], + dims=['time']) + da2 = xr.DataArray( + [3., 4.], coords=[[date_type(1, 3, 1), date_type(1, 4, 1)]], + dims=['time']) + da = xr.concat([da1, da2], dim='time') + + if has_cftime: assert isinstance(da.indexes['time'], CFTimeIndex) else: assert isinstance(da.indexes['time'], pd.Index) assert not isinstance(da.indexes['time'], CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_empty_cftimeindex(): + index = CFTimeIndex([]) + assert index.date_type is None + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_add(index): + date_type = index.date_type + expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2), + date_type(2, 1, 2), date_type(2, 2, 2)] + expected = CFTimeIndex(expected_dates) + result = index + timedelta(days=1) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS) +def test_cftimeindex_add_timedeltaindex(calendar): + a = xr.cftime_range('2000', periods=5, calendar=calendar) + deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) + result = a + deltas + expected = a.shift(2, 'D') + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_radd(index): + date_type = index.date_type + expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2), + date_type(2, 1, 2), date_type(2, 2, 2)] + expected = CFTimeIndex(expected_dates) + result = timedelta(days=1) + index + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS) +def test_timedeltaindex_add_cftimeindex(calendar): + a = xr.cftime_range('2000', periods=5, calendar=calendar) + deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) + result = deltas + a + expected = a.shift(2, 'D') + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_sub(index): + date_type = index.date_type + expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2), + date_type(2, 1, 2), date_type(2, 2, 2)] + expected = CFTimeIndex(expected_dates) + result = index + timedelta(days=2) + result = result - timedelta(days=1) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS) +def test_cftimeindex_sub_cftimeindex(calendar): + a = xr.cftime_range('2000', periods=5, calendar=calendar) + b = a.shift(2, 'D') + result = b - a + expected = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) + assert result.equals(expected) + assert isinstance(result, pd.TimedeltaIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _CFTIME_CALENDARS) +def test_cftimeindex_sub_timedeltaindex(calendar): + a = xr.cftime_range('2000', periods=5, calendar=calendar) + 
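[Editor's note: illustrative sketch, not part of the patch.] The arithmetic tests around this point establish that a CFTimeIndex supports the same timedelta algebra as a DatetimeIndex: adding a timedelta (or a pandas TimedeltaIndex) shifts the dates, and subtracting two CFTimeIndexes yields a pandas TimedeltaIndex. A short sketch of that behaviour, using the 'noleap' calendar as an arbitrary choice:

    from datetime import timedelta

    import pandas as pd
    import xarray as xr

    idx = xr.cftime_range('2000', periods=5, calendar='noleap')
    shifted = idx + timedelta(days=2)      # equivalent to idx.shift(2, 'D')
    deltas = shifted - idx                 # a pd.TimedeltaIndex of 2-day steps
    assert deltas.equals(pd.TimedeltaIndex([timedelta(days=2)] * 5))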
deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) + result = a - deltas + expected = a.shift(-2, 'D') + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_rsub(index): + with pytest.raises(TypeError): + timedelta(days=1) - index + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('freq', ['D', timedelta(days=1)]) +def test_cftimeindex_shift(index, freq): + date_type = index.date_type + expected_dates = [date_type(1, 1, 3), date_type(1, 2, 3), + date_type(2, 1, 3), date_type(2, 2, 3)] + expected = CFTimeIndex(expected_dates) + result = index.shift(2, freq) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_shift_invalid_n(): + index = xr.cftime_range('2000', periods=3) + with pytest.raises(TypeError): + index.shift('a', 'D') + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_shift_invalid_freq(): + index = xr.cftime_range('2000', periods=3) + with pytest.raises(TypeError): + index.shift(1, 1) + + +@requires_cftime +def test_parse_array_of_cftime_strings(): + from cftime import DatetimeNoLeap + + strings = np.array([['2000-01-01', '2000-01-02'], + ['2000-01-03', '2000-01-04']]) + expected = np.array( + [[DatetimeNoLeap(2000, 1, 1), DatetimeNoLeap(2000, 1, 2)], + [DatetimeNoLeap(2000, 1, 3), DatetimeNoLeap(2000, 1, 4)]]) + + result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap) + np.testing.assert_array_equal(result, expected) + + # Test scalar array case + strings = np.array('2000-01-01') + expected = np.array(DatetimeNoLeap(2000, 1, 1)) + result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap) + np.testing.assert_array_equal(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _ALL_CALENDARS) +@pytest.mark.parametrize('unsafe', [False, True]) +def test_to_datetimeindex(calendar, unsafe): + index = xr.cftime_range('2000', periods=5, calendar=calendar) + expected = pd.date_range('2000', periods=5) + + if calendar in _NON_STANDARD_CALENDARS and not unsafe: + with pytest.warns(RuntimeWarning, match='non-standard'): + result = index.to_datetimeindex() + else: + result = index.to_datetimeindex() + + assert result.equals(expected) + np.testing.assert_array_equal(result, expected) + assert isinstance(result, pd.DatetimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('calendar', _ALL_CALENDARS) +def test_to_datetimeindex_out_of_range(calendar): + index = xr.cftime_range('0001', periods=5, calendar=calendar) + with pytest.raises(ValueError, match='0001'): + index.to_datetimeindex() + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('calendar', ['all_leap', '360_day']) +def test_to_datetimeindex_feb_29(calendar): + index = xr.cftime_range('2001-02-28', periods=2, calendar=calendar) + with pytest.raises(ValueError, match='29'): + index.to_datetimeindex() diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py index 53d028e164b..ca138ca8362 100644 --- a/xarray/tests/test_coding_strings.py +++ b/xarray/tests/test_coding_strings.py @@ -5,13 +5,13 @@ import pytest from xarray import Variable -from xarray.core.pycompat import bytes_type, unicode_type, 
suppress from xarray.coding import strings from xarray.core import indexing +from xarray.core.pycompat import bytes_type, suppress, unicode_type -from . import (IndexerMaker, assert_array_equal, assert_identical, - raises_regex, requires_dask) - +from . import ( + IndexerMaker, assert_array_equal, assert_identical, raises_regex, + requires_dask) with suppress(ImportError): import dask.array as da diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 4d6ca731bb2..0ca57f98a6d 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -1,25 +1,27 @@ from __future__ import absolute_import, division, print_function -from itertools import product import warnings +from itertools import product import numpy as np import pandas as pd import pytest -from xarray import Variable, coding, set_options, DataArray, decode_cf -from xarray.coding.times import _import_cftime -from xarray.coding.variables import SerializationWarning +from xarray import DataArray, Variable, coding, decode_cf +from xarray.coding.times import (_import_cftime, cftime_to_nptime, + decode_cf_datetime, encode_cf_datetime) from xarray.core.common import contains_cftime_datetimes -from . import (assert_array_equal, has_cftime_or_netCDF4, - requires_cftime_or_netCDF4, has_cftime, has_dask) - +from . import ( + assert_array_equal, has_cftime, has_cftime_or_netCDF4, has_dask, + requires_cftime_or_netCDF4) -_NON_STANDARD_CALENDARS = {'noleap', '365_day', '360_day', - 'julian', 'all_leap', '366_day'} -_ALL_CALENDARS = _NON_STANDARD_CALENDARS.union( - coding.times._STANDARD_CALENDARS) +_NON_STANDARD_CALENDARS_SET = {'noleap', '365_day', '360_day', + 'julian', 'all_leap', '366_day'} +_ALL_CALENDARS = sorted(_NON_STANDARD_CALENDARS_SET.union( + coding.times._STANDARD_CALENDARS)) +_NON_STANDARD_CALENDARS = sorted(_NON_STANDARD_CALENDARS_SET) +_STANDARD_CALENDARS = sorted(coding.times._STANDARD_CALENDARS) _CF_DATETIME_NUM_DATES_UNITS = [ (np.arange(10), 'days since 2000-01-01'), (np.arange(10).astype('float64'), 'days since 2000-01-01'), @@ -45,19 +47,12 @@ ([0.5, 1.5], 'hours since 1900-01-01T00:00:00'), (0, 'milliseconds since 2000-01-01T00:00:00'), (0, 'microseconds since 2000-01-01T00:00:00'), - (np.int32(788961600), 'seconds since 1981-01-01') # GH2002 + (np.int32(788961600), 'seconds since 1981-01-01'), # GH2002 + (12300 + np.arange(5), 'hour since 1680-01-01 00:00:00.500000') ] _CF_DATETIME_TESTS = [num_dates_units + (calendar,) for num_dates_units, calendar in product(_CF_DATETIME_NUM_DATES_UNITS, - coding.times._STANDARD_CALENDARS)] - - -@np.vectorize -def _ensure_naive_tz(dt): - if hasattr(dt, 'tzinfo'): - return dt.replace(tzinfo=None) - else: - return dt + _STANDARD_CALENDARS)] def _all_cftime_date_types(): @@ -80,24 +75,27 @@ def _all_cftime_date_types(): _CF_DATETIME_TESTS) def test_cf_datetime(num_dates, units, calendar): cftime = _import_cftime() - expected = _ensure_naive_tz( - cftime.num2date(num_dates, units, calendar)) + if cftime.__name__ == 'cftime': + expected = cftime.num2date(num_dates, units, calendar, + only_use_cftime_datetimes=True) + else: + expected = cftime.num2date(num_dates, units, calendar) + min_y = np.ravel(np.atleast_1d(expected))[np.nanargmin(num_dates)].year + max_y = np.ravel(np.atleast_1d(expected))[np.nanargmax(num_dates)].year + if min_y >= 1678 and max_y < 2262: + expected = cftime_to_nptime(expected) + with warnings.catch_warnings(): warnings.filterwarnings('ignore', 'Unable to decode time axis') actual = 
coding.times.decode_cf_datetime(num_dates, units, calendar) - if (isinstance(actual, np.ndarray) and - np.issubdtype(actual.dtype, np.datetime64)): - # self.assertEqual(actual.dtype.kind, 'M') - # For some reason, numpy 1.8 does not compare ns precision - # datetime64 arrays as equal to arrays of datetime objects, - # but it works for us precision. Thus, convert to us - # precision for the actual array equal comparison... - actual_cmp = actual.astype('M8[us]') - else: - actual_cmp = actual - assert_array_equal(expected, actual_cmp) + + abs_diff = np.atleast_1d(abs(actual - expected)).astype(np.timedelta64) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, 's')).all() encoded, _, _ = coding.times.encode_cf_datetime(actual, units, calendar) if '1-1-1' not in units: @@ -121,8 +119,12 @@ def test_decode_cf_datetime_overflow(): # checks for # https://github.com/pydata/pandas/issues/14068 # https://github.com/pydata/xarray/issues/975 + try: + from cftime import DatetimeGregorian + except ImportError: + from netcdftime import DatetimeGregorian - from datetime import datetime + datetime = DatetimeGregorian units = 'days since 2000-01-01 00:00:00' # date after 2262 and before 1678 @@ -148,39 +150,32 @@ def test_decode_cf_datetime_non_standard_units(): @requires_cftime_or_netCDF4 def test_decode_cf_datetime_non_iso_strings(): # datetime strings that are _almost_ ISO compliant but not quite, - # but which netCDF4.num2date can still parse correctly + # but which cftime.num2date can still parse correctly expected = pd.date_range(periods=100, start='2000-01-01', freq='h') cases = [(np.arange(100), 'hours since 2000-01-01 0'), (np.arange(100), 'hours since 2000-1-1 0'), (np.arange(100), 'hours since 2000-01-01 0:00')] for num_dates, units in cases: actual = coding.times.decode_cf_datetime(num_dates, units) - assert_array_equal(actual, expected) + abs_diff = abs(actual - expected.values) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, 's')).all() @pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize( - ['calendar', 'enable_cftimeindex'], - product(coding.times._STANDARD_CALENDARS, [False, True])) -def test_decode_standard_calendar_inside_timestamp_range( - calendar, enable_cftimeindex): - if enable_cftimeindex: - pytest.importorskip('cftime') - +@pytest.mark.parametrize('calendar', _STANDARD_CALENDARS) +def test_decode_standard_calendar_inside_timestamp_range(calendar): cftime = _import_cftime() + units = 'days since 0001-01-01' - times = pd.date_range('2001-04-01-00', end='2001-04-30-23', - freq='H') - noleap_time = cftime.date2num(times.to_pydatetime(), units, - calendar=calendar) + times = pd.date_range('2001-04-01-00', end='2001-04-30-23', freq='H') + time = cftime.date2num(times.to_pydatetime(), units, calendar=calendar) expected = times.values expected_dtype = np.dtype('M8[ns]') - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime( - noleap_time, units, calendar=calendar, - enable_cftimeindex=enable_cftimeindex) + actual = coding.times.decode_cf_datetime(time, units, calendar=calendar) assert actual.dtype == 
expected_dtype
     abs_diff = abs(actual - expected)
     # once we no longer support versions of netCDF4 older than 1.1.5,
@@ -190,32 +185,28 @@ def test_decode_standard_calendar_inside_timestamp_range(

 @pytest.mark.skipif(not has_cftime_or_netCDF4,
                     reason='cftime not installed')
-@pytest.mark.parametrize(
-    ['calendar', 'enable_cftimeindex'],
-    product(_NON_STANDARD_CALENDARS, [False, True]))
+@pytest.mark.parametrize('calendar', _NON_STANDARD_CALENDARS)
 def test_decode_non_standard_calendar_inside_timestamp_range(
-        calendar, enable_cftimeindex):
-    if enable_cftimeindex:
-        pytest.importorskip('cftime')
-
+        calendar):
     cftime = _import_cftime()
     units = 'days since 0001-01-01'
     times = pd.date_range('2001-04-01-00', end='2001-04-30-23', freq='H')
-    noleap_time = cftime.date2num(times.to_pydatetime(), units,
-                                  calendar=calendar)
-    if enable_cftimeindex:
-        expected = cftime.num2date(noleap_time, units, calendar=calendar)
-        expected_dtype = np.dtype('O')
+    non_standard_time = cftime.date2num(
+        times.to_pydatetime(), units, calendar=calendar)
+
+    if cftime.__name__ == 'cftime':
+        expected = cftime.num2date(
+            non_standard_time, units, calendar=calendar,
+            only_use_cftime_datetimes=True)
     else:
-        expected = times.values
-        expected_dtype = np.dtype('M8[ns]')
+        expected = cftime.num2date(non_standard_time, units,
+                                   calendar=calendar)
-    with warnings.catch_warnings():
-        warnings.filterwarnings('ignore', 'Unable to decode time axis')
-        actual = coding.times.decode_cf_datetime(
-            noleap_time, units, calendar=calendar,
-            enable_cftimeindex=enable_cftimeindex)
+    expected_dtype = np.dtype('O')
+
+    actual = coding.times.decode_cf_datetime(
+        non_standard_time, units, calendar=calendar)
     assert actual.dtype == expected_dtype
     abs_diff = abs(actual - expected)
     # once we no longer support versions of netCDF4 older than 1.1.5,
@@ -225,33 +216,27 @@ def test_decode_non_standard_calendar_inside_timestamp_range(

 @pytest.mark.skipif(not has_cftime_or_netCDF4,
                     reason='cftime not installed')
-@pytest.mark.parametrize(
-    ['calendar', 'enable_cftimeindex'],
-    product(_ALL_CALENDARS, [False, True]))
-def test_decode_dates_outside_timestamp_range(
-        calendar, enable_cftimeindex):
+@pytest.mark.parametrize('calendar', _ALL_CALENDARS)
+def test_decode_dates_outside_timestamp_range(calendar):
     from datetime import datetime
-
-    if enable_cftimeindex:
-        pytest.importorskip('cftime')
-
     cftime = _import_cftime()
     units = 'days since 0001-01-01'
     times = [datetime(1, 4, 1, h) for h in range(1, 5)]
-    noleap_time = cftime.date2num(times, units, calendar=calendar)
-    if enable_cftimeindex:
-        expected = cftime.num2date(noleap_time, units, calendar=calendar,
+    time = cftime.date2num(times, units, calendar=calendar)
+
+    if cftime.__name__ == 'cftime':
+        expected = cftime.num2date(time, units, calendar=calendar,
                                    only_use_cftime_datetimes=True)
     else:
-        expected = cftime.num2date(noleap_time, units, calendar=calendar)
+        expected = cftime.num2date(time, units, calendar=calendar)
+
     expected_date_type = type(expected[0])
     with warnings.catch_warnings():
         warnings.filterwarnings('ignore', 'Unable to decode time axis')
         actual = coding.times.decode_cf_datetime(
-            noleap_time, units, calendar=calendar,
-            enable_cftimeindex=enable_cftimeindex)
+            time, units, calendar=calendar)
     assert all(isinstance(value, expected_date_type) for value in actual)
     abs_diff = abs(actual - expected)
     # once we no longer support versions of netCDF4 older than 1.1.5,
@@ -261,57 +246,37 @@ def test_decode_dates_outside_timestamp_range(

 @pytest.mark.skipif(not has_cftime_or_netCDF4,
                     reason='cftime not installed')
-@pytest.mark.parametrize(
-    ['calendar', 'enable_cftimeindex'],
-    product(coding.times._STANDARD_CALENDARS, [False, True]))
+@pytest.mark.parametrize('calendar', _STANDARD_CALENDARS)
 def test_decode_standard_calendar_single_element_inside_timestamp_range(
-        calendar, enable_cftimeindex):
-    if enable_cftimeindex:
-        pytest.importorskip('cftime')
-
+        calendar):
     units = 'days since 0001-01-01'
     for num_time in [735368, [735368], [[735368]]]:
         with warnings.catch_warnings():
             warnings.filterwarnings('ignore', 'Unable to decode time axis')
             actual = coding.times.decode_cf_datetime(
-                num_time, units, calendar=calendar,
-                enable_cftimeindex=enable_cftimeindex)
+                num_time, units, calendar=calendar)
         assert actual.dtype == np.dtype('M8[ns]')


 @pytest.mark.skipif(not has_cftime_or_netCDF4,
                     reason='cftime not installed')
-@pytest.mark.parametrize(
-    ['calendar', 'enable_cftimeindex'],
-    product(_NON_STANDARD_CALENDARS, [False, True]))
+@pytest.mark.parametrize('calendar', _NON_STANDARD_CALENDARS)
 def test_decode_non_standard_calendar_single_element_inside_timestamp_range(
-        calendar, enable_cftimeindex):
-    if enable_cftimeindex:
-        pytest.importorskip('cftime')
-
+        calendar):
     units = 'days since 0001-01-01'
     for num_time in [735368, [735368], [[735368]]]:
         with warnings.catch_warnings():
             warnings.filterwarnings('ignore', 'Unable to decode time axis')
             actual = coding.times.decode_cf_datetime(
-                num_time, units, calendar=calendar,
-                enable_cftimeindex=enable_cftimeindex)
-        if enable_cftimeindex:
-            assert actual.dtype == np.dtype('O')
-        else:
-            assert actual.dtype == np.dtype('M8[ns]')
+                num_time, units, calendar=calendar)
+        assert actual.dtype == np.dtype('O')


 @pytest.mark.skipif(not has_cftime_or_netCDF4,
                     reason='cftime not installed')
-@pytest.mark.parametrize(
-    ['calendar', 'enable_cftimeindex'],
-    product(_NON_STANDARD_CALENDARS, [False, True]))
+@pytest.mark.parametrize('calendar', _NON_STANDARD_CALENDARS)
 def test_decode_single_element_outside_timestamp_range(
-        calendar, enable_cftimeindex):
-    if enable_cftimeindex:
-        pytest.importorskip('cftime')
-
+        calendar):
     cftime = _import_cftime()
     units = 'days since 0001-01-01'
     for days in [1, 1470376]:
@@ -320,40 +285,39 @@ def test_decode_single_element_outside_timestamp_range(
             warnings.filterwarnings('ignore', 'Unable to decode time axis')
             actual = coding.times.decode_cf_datetime(
-                num_time, units, calendar=calendar,
-                enable_cftimeindex=enable_cftimeindex)
-            expected = cftime.num2date(days, units, calendar)
+                num_time, units, calendar=calendar)
+
+            if cftime.__name__ == 'cftime':
+                expected = cftime.num2date(days, units, calendar,
+                                           only_use_cftime_datetimes=True)
+            else:
+                expected = cftime.num2date(days, units, calendar)
+
             assert isinstance(actual.item(), type(expected))


 @pytest.mark.skipif(not has_cftime_or_netCDF4,
                     reason='cftime not installed')
-@pytest.mark.parametrize(
-    ['calendar', 'enable_cftimeindex'],
-    product(coding.times._STANDARD_CALENDARS, [False, True]))
+@pytest.mark.parametrize('calendar', _STANDARD_CALENDARS)
 def test_decode_standard_calendar_multidim_time_inside_timestamp_range(
-        calendar, enable_cftimeindex):
-    if enable_cftimeindex:
-        pytest.importorskip('cftime')
-
+        calendar):
     cftime = _import_cftime()
     units = 'days since 0001-01-01'
     times1 = pd.date_range('2001-04-01', end='2001-04-05', freq='D')
     times2 = pd.date_range('2001-05-01', end='2001-05-05', freq='D')
-    noleap_time1 = cftime.date2num(times1.to_pydatetime(),
-                                   units, calendar=calendar)
-    noleap_time2 = cftime.date2num(times2.to_pydatetime(),
-                                   units, calendar=calendar)
-    mdim_time = np.empty((len(noleap_time1), 2), )
-    mdim_time[:, 0] = noleap_time1
-    mdim_time[:, 1] = noleap_time2
+    time1 = cftime.date2num(times1.to_pydatetime(),
+                            units, calendar=calendar)
+    time2 = cftime.date2num(times2.to_pydatetime(),
+                            units, calendar=calendar)
+    mdim_time = np.empty((len(time1), 2), )
+    mdim_time[:, 0] = time1
+    mdim_time[:, 1] = time2
     expected1 = times1.values
     expected2 = times2.values
     actual = coding.times.decode_cf_datetime(
-        mdim_time, units, calendar=calendar,
-        enable_cftimeindex=enable_cftimeindex)
+        mdim_time, units, calendar=calendar)
     assert actual.dtype == np.dtype('M8[ns]')
     abs_diff1 = abs(actual[:, 0] - expected1)
@@ -366,39 +330,35 @@ def test_decode_standard_calendar_multidim_time_inside_timestamp_range(

 @pytest.mark.skipif(not has_cftime_or_netCDF4,
                     reason='cftime not installed')
-@pytest.mark.parametrize(
-    ['calendar', 'enable_cftimeindex'],
-    product(_NON_STANDARD_CALENDARS, [False, True]))
+@pytest.mark.parametrize('calendar', _NON_STANDARD_CALENDARS)
 def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range(
-        calendar, enable_cftimeindex):
-    if enable_cftimeindex:
-        pytest.importorskip('cftime')
-
+        calendar):
     cftime = _import_cftime()
     units = 'days since 0001-01-01'
     times1 = pd.date_range('2001-04-01', end='2001-04-05', freq='D')
     times2 = pd.date_range('2001-05-01', end='2001-05-05', freq='D')
-    noleap_time1 = cftime.date2num(times1.to_pydatetime(),
-                                   units, calendar=calendar)
-    noleap_time2 = cftime.date2num(times2.to_pydatetime(),
-                                   units, calendar=calendar)
-    mdim_time = np.empty((len(noleap_time1), 2), )
-    mdim_time[:, 0] = noleap_time1
-    mdim_time[:, 1] = noleap_time2
-
-    if enable_cftimeindex:
-        expected1 = cftime.num2date(noleap_time1, units, calendar)
-        expected2 = cftime.num2date(noleap_time2, units, calendar)
-        expected_dtype = np.dtype('O')
+    time1 = cftime.date2num(times1.to_pydatetime(),
+                            units, calendar=calendar)
+    time2 = cftime.date2num(times2.to_pydatetime(),
+                            units, calendar=calendar)
+    mdim_time = np.empty((len(time1), 2), )
+    mdim_time[:, 0] = time1
+    mdim_time[:, 1] = time2
+
+    if cftime.__name__ == 'cftime':
+        expected1 = cftime.num2date(time1, units, calendar,
+                                    only_use_cftime_datetimes=True)
+        expected2 = cftime.num2date(time2, units, calendar,
+                                    only_use_cftime_datetimes=True)
     else:
-        expected1 = times1.values
-        expected2 = times2.values
-        expected_dtype = np.dtype('M8[ns]')
+        expected1 = cftime.num2date(time1, units, calendar)
+        expected2 = cftime.num2date(time2, units, calendar)
+
+    expected_dtype = np.dtype('O')
     actual = coding.times.decode_cf_datetime(
-        mdim_time, units, calendar=calendar,
-        enable_cftimeindex=enable_cftimeindex)
+        mdim_time, units, calendar=calendar)
     assert actual.dtype == expected_dtype
     abs_diff1 = abs(actual[:, 0] - expected1)
@@ -411,41 +371,34 @@ def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range(

 @pytest.mark.skipif(not has_cftime_or_netCDF4,
                     reason='cftime not installed')
-@pytest.mark.parametrize(
-    ['calendar', 'enable_cftimeindex'],
-    product(_ALL_CALENDARS, [False, True]))
+@pytest.mark.parametrize('calendar', _ALL_CALENDARS)
 def test_decode_multidim_time_outside_timestamp_range(
-        calendar, enable_cftimeindex):
+        calendar):
     from datetime import datetime
-
-    if enable_cftimeindex:
-        pytest.importorskip('cftime')
-
     cftime = _import_cftime()
     units = 'days since 0001-01-01'
     times1 = [datetime(1, 4, day) for day in range(1, 6)]
     times2 = [datetime(1, 5, day) for day in range(1,
6)] - noleap_time1 = cftime.date2num(times1, units, calendar=calendar) - noleap_time2 = cftime.date2num(times2, units, calendar=calendar) - mdim_time = np.empty((len(noleap_time1), 2), ) - mdim_time[:, 0] = noleap_time1 - mdim_time[:, 1] = noleap_time2 - - if enable_cftimeindex: - expected1 = cftime.num2date(noleap_time1, units, calendar, + time1 = cftime.date2num(times1, units, calendar=calendar) + time2 = cftime.date2num(times2, units, calendar=calendar) + mdim_time = np.empty((len(time1), 2), ) + mdim_time[:, 0] = time1 + mdim_time[:, 1] = time2 + + if cftime.__name__ == 'cftime': + expected1 = cftime.num2date(time1, units, calendar, only_use_cftime_datetimes=True) - expected2 = cftime.num2date(noleap_time2, units, calendar, + expected2 = cftime.num2date(time2, units, calendar, only_use_cftime_datetimes=True) else: - expected1 = cftime.num2date(noleap_time1, units, calendar) - expected2 = cftime.num2date(noleap_time2, units, calendar) + expected1 = cftime.num2date(time1, units, calendar) + expected2 = cftime.num2date(time2, units, calendar) with warnings.catch_warnings(): warnings.filterwarnings('ignore', 'Unable to decode time axis') actual = coding.times.decode_cf_datetime( - mdim_time, units, calendar=calendar, - enable_cftimeindex=enable_cftimeindex) + mdim_time, units, calendar=calendar) assert actual.dtype == np.dtype('O') @@ -459,66 +412,51 @@ def test_decode_multidim_time_outside_timestamp_range( @pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize( - ['calendar', 'enable_cftimeindex'], - product(['360_day', 'all_leap', '366_day'], [False, True])) -def test_decode_non_standard_calendar_single_element_fallback( - calendar, enable_cftimeindex): - if enable_cftimeindex: - pytest.importorskip('cftime') - +@pytest.mark.parametrize('calendar', ['360_day', 'all_leap', '366_day']) +def test_decode_non_standard_calendar_single_element( + calendar): cftime = _import_cftime() - units = 'days since 0001-01-01' + try: dt = cftime.netcdftime.datetime(2001, 2, 29) except AttributeError: - # Must be using standalone netcdftime library + # Must be using the standalone cftime library dt = cftime.datetime(2001, 2, 29) num_time = cftime.date2num(dt, units, calendar) - if enable_cftimeindex: - actual = coding.times.decode_cf_datetime( - num_time, units, calendar=calendar, - enable_cftimeindex=enable_cftimeindex) - else: - with pytest.warns(SerializationWarning, - match='Unable to decode time axis'): - actual = coding.times.decode_cf_datetime( - num_time, units, calendar=calendar, - enable_cftimeindex=enable_cftimeindex) + actual = coding.times.decode_cf_datetime( + num_time, units, calendar=calendar) - expected = np.asarray(cftime.num2date(num_time, units, calendar)) + if cftime.__name__ == 'cftime': + expected = np.asarray(cftime.num2date( + num_time, units, calendar, only_use_cftime_datetimes=True)) + else: + expected = np.asarray(cftime.num2date(num_time, units, calendar)) assert actual.dtype == np.dtype('O') assert expected == actual @pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize( - ['calendar', 'enable_cftimeindex'], - product(['360_day'], [False, True])) -def test_decode_non_standard_calendar_fallback( - calendar, enable_cftimeindex): - if enable_cftimeindex: - pytest.importorskip('cftime') - +def test_decode_360_day_calendar(): cftime = _import_cftime() + calendar = '360_day' # ensure leap year doesn't matter for year in [2010, 2011, 2012, 2013, 2014]: units = 'days since 
{0}-01-01'.format(year) num_times = np.arange(100) - expected = cftime.num2date(num_times, units, calendar) + + if cftime.__name__ == 'cftime': + expected = cftime.num2date(num_times, units, calendar, + only_use_cftime_datetimes=True) + else: + expected = cftime.num2date(num_times, units, calendar) with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') actual = coding.times.decode_cf_datetime( - num_times, units, calendar=calendar, - enable_cftimeindex=enable_cftimeindex) - if enable_cftimeindex: - assert len(w) == 0 - else: - assert len(w) == 1 - assert 'Unable to decode time axis' in str(w[0].message) + num_times, units, calendar=calendar) + assert len(w) == 0 assert actual.dtype == np.dtype('O') assert_array_equal(actual, expected) @@ -536,7 +474,8 @@ def test_cf_datetime_nan(num_dates, units, expected_list): with warnings.catch_warnings(): warnings.filterwarnings('ignore', 'All-NaN') actual = coding.times.decode_cf_datetime(num_dates, units) - expected = np.array(expected_list, dtype='datetime64[ns]') + # use pandas because numpy will deprecate timezone-aware conversions + expected = pd.to_datetime(expected_list) assert_array_equal(expected, actual) @@ -572,28 +511,24 @@ def test_infer_datetime_units(dates, expected): assert expected == coding.times.infer_datetime_units(dates) +_CFTIME_DATETIME_UNITS_TESTS = [ + ([(1900, 1, 1), (1900, 1, 1)], 'days since 1900-01-01 00:00:00.000000'), + ([(1900, 1, 1), (1900, 1, 2), (1900, 1, 2, 0, 0, 1)], + 'seconds since 1900-01-01 00:00:00.000000'), + ([(1900, 1, 1), (1900, 1, 8), (1900, 1, 16)], + 'days since 1900-01-01 00:00:00.000000') +] + + @pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -def test_infer_cftime_datetime_units(): - date_types = _all_cftime_date_types() - for date_type in date_types.values(): - for dates, expected in [ - ([date_type(1900, 1, 1), - date_type(1900, 1, 2)], - 'days since 1900-01-01 00:00:00.000000'), - ([date_type(1900, 1, 1, 12), - date_type(1900, 1, 1, 13)], - 'seconds since 1900-01-01 12:00:00.000000'), - ([date_type(1900, 1, 1), - date_type(1900, 1, 2), - date_type(1900, 1, 2, 0, 0, 1)], - 'seconds since 1900-01-01 00:00:00.000000'), - ([date_type(1900, 1, 1), - date_type(1900, 1, 2, 0, 0, 0, 5)], - 'days since 1900-01-01 00:00:00.000000'), - ([date_type(1900, 1, 1), date_type(1900, 1, 8), - date_type(1900, 1, 16)], - 'days since 1900-01-01 00:00:00.000000')]: - assert expected == coding.times.infer_datetime_units(dates) +@pytest.mark.parametrize( + 'calendar', _NON_STANDARD_CALENDARS + ['gregorian', 'proleptic_gregorian']) +@pytest.mark.parametrize(('date_args', 'expected'), + _CFTIME_DATETIME_UNITS_TESTS) +def test_infer_cftime_datetime_units(calendar, date_args, expected): + date_type = _all_cftime_date_types()[calendar] + dates = [date_type(*args) for args in date_args] + assert expected == coding.times.infer_datetime_units(dates) @pytest.mark.parametrize( @@ -667,11 +602,8 @@ def test_format_cftime_datetime(date_args, expected): assert result == expected -@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize( - ['calendar', 'enable_cftimeindex'], - product(_ALL_CALENDARS, [False, True])) -def test_decode_cf_enable_cftimeindex(calendar, enable_cftimeindex): +@pytest.mark.parametrize('calendar', _ALL_CALENDARS) +def test_decode_cf(calendar): days = [1., 2., 3.] 
da = DataArray(days, coords=[days], dims=['time'], name='test') ds = da.to_dataset() @@ -680,17 +612,13 @@ def test_decode_cf_enable_cftimeindex(calendar, enable_cftimeindex): ds[v].attrs['units'] = 'days since 2001-01-01' ds[v].attrs['calendar'] = calendar - if (not has_cftime and enable_cftimeindex and - calendar not in coding.times._STANDARD_CALENDARS): + if not has_cftime_or_netCDF4 and calendar not in _STANDARD_CALENDARS: with pytest.raises(ValueError): - with set_options(enable_cftimeindex=enable_cftimeindex): - ds = decode_cf(ds) - else: - with set_options(enable_cftimeindex=enable_cftimeindex): ds = decode_cf(ds) + else: + ds = decode_cf(ds) - if (enable_cftimeindex and - calendar not in coding.times._STANDARD_CALENDARS): + if calendar not in _STANDARD_CALENDARS: assert ds.test.dtype == np.dtype('O') else: assert ds.test.dtype == np.dtype('M8[ns]') @@ -760,3 +688,16 @@ def test_contains_cftime_datetimes_non_cftimes(non_cftime_data): @pytest.mark.parametrize('non_cftime_data', [DataArray([]), DataArray([1, 2])]) def test_contains_cftime_datetimes_non_cftimes_dask(non_cftime_data): assert not contains_cftime_datetimes(non_cftime_data.chunk()) + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('shape', [(24,), (8, 3), (2, 4, 3)]) +def test_encode_cf_datetime_overflow(shape): + # Test for fix to GH 2272 + dates = pd.date_range('2100', periods=24).values.reshape(shape) + units = 'days since 1800-01-01' + calendar = 'standard' + + num, _, _ = encode_cf_datetime(dates, units, calendar) + roundtrip = decode_cf_datetime(num, units, calendar) + np.testing.assert_array_equal(dates, roundtrip) diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 482a280b355..2004b1e660f 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -10,12 +10,12 @@ from xarray.core.pycompat import OrderedDict, iteritems from . import ( - InaccessibleArray, TestCase, assert_array_equal, assert_equal, - assert_identical, raises_regex, requires_dask) + InaccessibleArray, assert_array_equal, assert_equal, assert_identical, + raises_regex, requires_dask) from .test_dataset import create_test_data -class TestConcatDataset(TestCase): +class TestConcatDataset(object): def test_concat(self): # TODO: simplify and split this test case @@ -235,7 +235,7 @@ def test_concat_multiindex(self): assert isinstance(actual.x.to_index(), pd.MultiIndex) -class TestConcatDataArray(TestCase): +class TestConcatDataArray(object): def test_concat(self): ds = Dataset({'foo': (['x', 'y'], np.random.random((2, 3))), 'bar': (['x', 'y'], np.random.random((2, 3)))}, @@ -295,7 +295,7 @@ def test_concat_lazy(self): assert combined.dims == ('z', 'x', 'y') -class TestAutoCombine(TestCase): +class TestAutoCombine(object): @requires_dask # only for toolz def test_auto_combine(self): diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index e30e7e31390..1003c531018 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -15,7 +15,7 @@ join_dict_keys, ordered_set_intersection, ordered_set_union, result_name, unified_dim_sizes) -from . import raises_regex, requires_dask, has_dask +from . 
import has_dask, raises_regex, requires_dask def assert_identical(a, b): @@ -274,6 +274,22 @@ def func(x): assert_identical(expected_dataset_x, first_element(dataset.groupby('y'), 'x')) + def multiply(*args): + val = args[0] + for arg in args[1:]: + val = val * arg + return val + + # regression test for GH:2341 + with pytest.raises(ValueError): + apply_ufunc(multiply, data_array, data_array['y'].values, + input_core_dims=[['y']], output_core_dims=[['y']]) + expected = xr.DataArray(multiply(data_array, data_array['y']), + dims=['x', 'y'], coords=data_array.coords) + actual = apply_ufunc(multiply, data_array, data_array['y'].values, + input_core_dims=[['y'], []], output_core_dims=[['y']]) + assert_identical(expected, actual) + def test_apply_output_core_dimension(): diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 5ed482ed2bd..5fa518f5112 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -8,20 +8,20 @@ import pandas as pd import pytest -from xarray import (Dataset, Variable, SerializationWarning, coding, - conventions, open_dataset) +from xarray import ( + Dataset, SerializationWarning, Variable, coding, conventions, open_dataset) from xarray.backends.common import WritableCFDataStore from xarray.backends.memory import InMemoryDataStore from xarray.conventions import decode_cf from xarray.testing import assert_identical from . import ( - TestCase, assert_array_equal, raises_regex, requires_netCDF4, - requires_cftime_or_netCDF4, unittest, requires_dask) -from .test_backends import CFEncodedDataTest + assert_array_equal, raises_regex, requires_cftime_or_netCDF4, + requires_dask, requires_netCDF4) +from .test_backends import CFEncodedBase -class TestBoolTypeArray(TestCase): +class TestBoolTypeArray(object): def test_booltype_array(self): x = np.array([1, 0, 1, 1, 0], dtype='i1') bx = conventions.BoolTypeArray(x) @@ -30,7 +30,7 @@ def test_booltype_array(self): dtype=np.bool)) -class TestNativeEndiannessArray(TestCase): +class TestNativeEndiannessArray(object): def test(self): x = np.arange(5, dtype='>i8') expected = np.arange(5, dtype='int64') @@ -69,7 +69,7 @@ def test_decode_cf_with_conflicting_fill_missing_value(): @requires_cftime_or_netCDF4 -class TestEncodeCFVariable(TestCase): +class TestEncodeCFVariable(object): def test_incompatible_attributes(self): invalid_vars = [ Variable(['t'], pd.date_range('2000-01-01', periods=3), @@ -134,7 +134,7 @@ def test_string_object_warning(self): @requires_cftime_or_netCDF4 -class TestDecodeCF(TestCase): +class TestDecodeCF(object): def test_dataset(self): original = Dataset({ 't': ('t', [0, 1, 2], {'units': 'days since 2000-01-01'}), @@ -255,7 +255,7 @@ def encode_variable(self, var): @requires_netCDF4 -class TestCFEncodedDataStore(CFEncodedDataTest, TestCase): +class TestCFEncodedDataStore(CFEncodedBase): @contextlib.contextmanager def create_store(self): yield CFEncodedInMemoryStore() @@ -267,9 +267,10 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}, data.dump_to_store(store, **save_kwargs) yield open_dataset(store, **open_kwargs) + @pytest.mark.skip('cannot roundtrip coordinates yet for ' + 'CFEncodedInMemoryStore') def test_roundtrip_coordinates(self): - raise unittest.SkipTest('cannot roundtrip coordinates yet for ' - 'CFEncodedInMemoryStore') + pass def test_invalid_dataarray_names_raise(self): # only relevant for on-disk file formats diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f6c47cce8d8..62ce7d074fa 100644 --- 
a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function import pickle +from distutils.version import LooseVersion from textwrap import dedent import numpy as np @@ -14,18 +15,22 @@ from xarray.tests import mock from . import ( - TestCase, assert_allclose, assert_array_equal, assert_equal, - assert_frame_equal, assert_identical, raises_regex) + assert_allclose, assert_array_equal, assert_equal, assert_frame_equal, + assert_identical, raises_regex) dask = pytest.importorskip('dask') da = pytest.importorskip('dask.array') dd = pytest.importorskip('dask.dataframe') -class DaskTestCase(TestCase): +class DaskTestCase(object): def assertLazyAnd(self, expected, actual, test): - with dask.set_options(get=dask.get): + + with (dask.config.set(scheduler='single-threaded') + if LooseVersion(dask.__version__) >= LooseVersion('0.18.0') + else dask.set_options(get=dask.get)): test(actual, expected) + if isinstance(actual, Dataset): for k, v in actual.variables.items(): if k in actual.dims: @@ -52,6 +57,7 @@ def assertLazyAndIdentical(self, expected, actual): def assertLazyAndAllClose(self, expected, actual): self.assertLazyAnd(expected, actual, assert_allclose) + @pytest.fixture(autouse=True) def setUp(self): self.values = np.random.RandomState(0).randn(4, 6) self.data = da.from_array(self.values, chunks=(2, 2)) @@ -196,11 +202,13 @@ def test_missing_methods(self): except NotImplementedError as err: assert 'dask' in str(err) + @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') def test_univariate_ufunc(self): u = self.eager_var v = self.lazy_var self.assertLazyAndAllClose(np.sin(u), xu.sin(v)) + @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') def test_bivariate_ufunc(self): u = self.eager_var v = self.lazy_var @@ -242,6 +250,7 @@ def assertLazyAndAllClose(self, expected, actual): def assertLazyAndEqual(self, expected, actual): self.assertLazyAnd(expected, actual, assert_equal) + @pytest.fixture(autouse=True) def setUp(self): self.values = np.random.randn(4, 6) self.data = da.from_array(self.values, chunks=(2, 2)) @@ -378,8 +387,8 @@ def test_groupby(self): u = self.eager_array v = self.lazy_array - expected = u.groupby('x').mean() - actual = v.groupby('x').mean() + expected = u.groupby('x').mean(xr.ALL_DIMS) + actual = v.groupby('x').mean(xr.ALL_DIMS) self.assertLazyAndAllClose(expected, actual) def test_groupby_first(self): @@ -421,6 +430,7 @@ def duplicate_and_merge(array): actual = duplicate_and_merge(self.lazy_array) self.assertLazyAndEqual(expected, actual) + @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') def test_ufuncs(self): u = self.eager_array v = self.lazy_array @@ -446,7 +456,11 @@ def counting_get(*args, **kwargs): count[0] += 1 return dask.get(*args, **kwargs) - ds.load(get=counting_get) + if dask.__version__ < '0.19.4': + ds.load(get=counting_get) + else: + ds.load(scheduler=counting_get) + assert count[0] == 1 def test_stack(self): @@ -573,7 +587,7 @@ def test_from_dask_variable(self): self.assertLazyAndIdentical(self.lazy_array, a) -class TestToDaskDataFrame(TestCase): +class TestToDaskDataFrame(object): def test_to_dask_dataframe(self): # Test conversion of Datasets to dask DataFrames @@ -821,7 +835,11 @@ def test_basic_compute(): dask.multiprocessing.get, dask.local.get_sync, None]: - with dask.set_options(get=get): + with (dask.config.set(scheduler=get) + if LooseVersion(dask.__version__) >= LooseVersion('0.19.4') + else dask.config.set(scheduler=get) + 
if LooseVersion(dask.__version__) >= LooseVersion('0.18.0') + else dask.set_options(get=get)): ds.compute() ds.foo.compute() ds.foo.variable.compute() diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index e0b1496c7bf..87ee60715a1 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1,9 +1,9 @@ from __future__ import absolute_import, division, print_function import pickle +import warnings from copy import deepcopy from textwrap import dedent -import warnings import numpy as np import pandas as pd @@ -12,18 +12,20 @@ import xarray as xr from xarray import ( DataArray, Dataset, IndexVariable, Variable, align, broadcast, set_options) -from xarray.convert import from_cdms2 from xarray.coding.times import CFDatetimeCoder, _import_cftime -from xarray.core.common import full_like +from xarray.convert import from_cdms2 +from xarray.core.common import ALL_DIMS, full_like from xarray.core.pycompat import OrderedDict, iteritems from xarray.tests import ( - ReturnItem, TestCase, assert_allclose, assert_array_equal, assert_equal, - assert_identical, raises_regex, requires_bottleneck, requires_cftime, - requires_dask, requires_np113, requires_scipy, source_ndarray, unittest) + LooseVersion, ReturnItem, assert_allclose, assert_array_equal, + assert_equal, assert_identical, raises_regex, requires_bottleneck, + requires_cftime, requires_dask, requires_iris, requires_np113, + requires_scipy, source_ndarray) -class TestDataArray(TestCase): - def setUp(self): +class TestDataArray(object): + @pytest.fixture(autouse=True) + def setup(self): self.attrs = {'attr1': 'value1', 'attr2': 2929} self.x = np.random.random((10, 20)) self.v = Variable(['x', 'y'], self.x) @@ -439,7 +441,7 @@ def test_getitem(self): assert_identical(self.ds['x'], x) assert_identical(self.ds['y'], y) - I = ReturnItem() # noqa: E741 # allow ambiguous name + I = ReturnItem() # noqa for i in [I[:], I[...], I[x.values], I[x.variable], I[x], I[x, y], I[x.values > -1], I[x.variable > -1], I[x > -1], I[x > -1, y > -1]]: @@ -616,9 +618,9 @@ def get_data(): da[dict(x=ind)] = value # should not raise def test_contains(self): - data_array = DataArray(1, coords={'x': 2}) - with pytest.warns(FutureWarning): - assert 'x' in data_array + data_array = DataArray([1, 2]) + assert 1 in data_array + assert 3 not in data_array def test_attr_sources_multiindex(self): # make sure attr-style access for multi-index levels @@ -671,6 +673,7 @@ def test_isel_types(self): assert_identical(da.isel(x=np.array([0], dtype="int64")), da.isel(x=np.array([0]))) + @pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_isel_fancy(self): shape = (10, 7, 6) np_array = np.random.random(shape) @@ -844,6 +847,7 @@ def test_isel_drop(self): selected = data.isel(x=0, drop=False) assert_identical(expected, selected) + @pytest.mark.filterwarnings("ignore:Dataset.isel_points") def test_isel_points(self): shape = (10, 5, 6) np_array = np.random.random(shape) @@ -998,7 +1002,7 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, assert da.dims[0] == renamed_dim da = da.rename({renamed_dim: 'x'}) assert_identical(da.variable, expected_da.variable) - self.assertVariableNotEqual(da['x'], expected_da['x']) + assert not da['x'].equals(expected_da['x']) test_sel(('a', 1, -1), 0) test_sel(('b', 2, -2), -1) @@ -1151,7 +1155,7 @@ def test_reset_coords(self): assert_identical(actual, expected) actual = data.copy() - actual.reset_coords(drop=True, inplace=True) + actual = actual.reset_coords(drop=True) 
assert_identical(actual, expected) actual = data.reset_coords('bar', drop=True) @@ -1160,8 +1164,9 @@ def test_reset_coords(self): dims=['x', 'y'], name='foo') assert_identical(actual, expected) - with raises_regex(ValueError, 'cannot reset coord'): - data.reset_coords(inplace=True) + with pytest.warns(FutureWarning, message='The inplace argument'): + with raises_regex(ValueError, 'cannot reset coord'): + data = data.reset_coords(inplace=True) with raises_regex(ValueError, 'cannot be found'): data.reset_coords('foo', drop=True) with raises_regex(ValueError, 'cannot be found'): @@ -1236,6 +1241,7 @@ def test_reindex_like_no_index(self): ValueError, 'different size for unlabeled'): foo.reindex_like(bar) + @pytest.mark.filterwarnings('ignore:Indexer has dimensions') def test_reindex_regressions(self): # regression test for #279 expected = DataArray(np.random.randn(5), coords=[("time", range(5))]) @@ -1285,7 +1291,7 @@ def test_swap_dims(self): def test_expand_dims_error(self): array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'], - coords={'x': np.linspace(0.0, 1.0, 3.0)}, + coords={'x': np.linspace(0.0, 1.0, 3)}, attrs={'key': 'entry'}) with raises_regex(ValueError, 'dim should be str or'): @@ -1393,7 +1399,7 @@ def test_set_index(self): expected = array.set_index(x=['level_1', 'level_2', 'level_3']) assert_identical(obj, expected) - array.set_index(x=['level_1', 'level_2', 'level_3'], inplace=True) + array = array.set_index(x=['level_1', 'level_2', 'level_3']) assert_identical(array, expected) array2d = DataArray(np.random.rand(2, 2), @@ -1426,7 +1432,7 @@ def test_reset_index(self): assert_identical(obj, expected) array = self.mda.copy() - array.reset_index(['x'], drop=True, inplace=True) + array = array.reset_index(['x'], drop=True) assert_identical(array, expected) # single index @@ -1442,9 +1448,10 @@ def test_reorder_levels(self): obj = self.mda.reorder_levels(x=['level_2', 'level_1']) assert_identical(obj, expected) - array = self.mda.copy() - array.reorder_levels(x=['level_2', 'level_1'], inplace=True) - assert_identical(array, expected) + with pytest.warns(FutureWarning, message='The inplace argument'): + array = self.mda.copy() + array.reorder_levels(x=['level_2', 'level_1'], inplace=True) + assert_identical(array, expected) array = DataArray([1, 2], dims='x') with pytest.raises(KeyError): @@ -1659,9 +1666,23 @@ def test_dataset_math(self): def test_stack_unstack(self): orig = DataArray([[0, 1], [2, 3]], dims=['x', 'y'], attrs={'foo': 2}) + assert_identical(orig, orig.unstack()) + actual = orig.stack(z=['x', 'y']).unstack('z').drop(['x', 'y']) assert_identical(orig, actual) + dims = ['a', 'b', 'c', 'd', 'e'] + orig = xr.DataArray(np.random.rand(1, 2, 3, 2, 1), dims=dims) + stacked = orig.stack(ab=['a', 'b'], cd=['c', 'd']) + + unstacked = stacked.unstack(['ab', 'cd']) + roundtripped = unstacked.drop(['a', 'b', 'c', 'd']).transpose(*dims) + assert_identical(orig, roundtripped) + + unstacked = stacked.unstack() + roundtripped = unstacked.drop(['a', 'b', 'c', 'd']).transpose(*dims) + assert_identical(orig, roundtripped) + def test_stack_unstack_decreasing_coordinate(self): # regression test for GH980 orig = DataArray(np.random.rand(3, 4), dims=('y', 'x'), @@ -1982,15 +2003,15 @@ def test_groupby_sum(self): self.x[:, 10:].sum(), self.x[:, 9:10].sum()]).T), 'abc': Variable(['abc'], np.array(['a', 'b', 'c']))})['foo'] - assert_allclose(expected_sum_all, grouped.reduce(np.sum)) - assert_allclose(expected_sum_all, grouped.sum()) + assert_allclose(expected_sum_all, 
grouped.reduce(np.sum, dim=ALL_DIMS)) + assert_allclose(expected_sum_all, grouped.sum(ALL_DIMS)) expected = DataArray([array['y'].values[idx].sum() for idx in [slice(9), slice(10, None), slice(9, 10)]], [['a', 'b', 'c']], ['abc']) actual = array['y'].groupby('abc').apply(np.sum) assert_allclose(expected, actual) - actual = array['y'].groupby('abc').sum() + actual = array['y'].groupby('abc').sum(ALL_DIMS) assert_allclose(expected, actual) expected_sum_axis1 = Dataset( @@ -2001,6 +2022,27 @@ def test_groupby_sum(self): assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, 'y')) assert_allclose(expected_sum_axis1, grouped.sum('y')) + def test_groupby_warning(self): + array = self.make_groupby_example_array() + grouped = array.groupby('y') + with pytest.warns(FutureWarning): + grouped.sum() + + @pytest.mark.skipif(LooseVersion(xr.__version__) < LooseVersion('0.12'), + reason="not to forget the behavior change") + def test_groupby_sum_default(self): + array = self.make_groupby_example_array() + grouped = array.groupby('abc') + + expected_sum_all = Dataset( + {'foo': Variable(['x', 'abc'], + np.array([self.x[:, :9].sum(axis=-1), + self.x[:, 10:].sum(axis=-1), + self.x[:, 9:10].sum(axis=-1)]).T), + 'abc': Variable(['abc'], np.array(['a', 'b', 'c']))})['foo'] + + assert_allclose(expected_sum_all, grouped.sum()) + def test_groupby_count(self): array = DataArray( [0, 0, np.nan, np.nan, 0, 0], @@ -2010,7 +2052,7 @@ def test_groupby_count(self): expected = DataArray([1, 1, 2], coords=[('cat', ['a', 'b', 'c'])]) assert_identical(actual, expected) - @unittest.skip('needs to be fixed for shortcut=False, keep_attrs=False') + @pytest.mark.skip('needs to be fixed for shortcut=False, keep_attrs=False') def test_groupby_reduce_attrs(self): array = self.make_groupby_example_array() array.attrs['foo'] = 'bar' @@ -2081,9 +2123,9 @@ def test_groupby_math(self): assert_identical(expected, actual) grouped = array.groupby('abc') - expected_agg = (grouped.mean() - np.arange(3)).rename(None) + expected_agg = (grouped.mean(ALL_DIMS) - np.arange(3)).rename(None) actual = grouped - DataArray(range(3), [('abc', ['a', 'b', 'c'])]) - actual_agg = actual.groupby('abc').mean() + actual_agg = actual.groupby('abc').mean(ALL_DIMS) assert_allclose(expected_agg, actual_agg) with raises_regex(TypeError, 'only support binary ops'): @@ -2157,7 +2199,7 @@ def test_groupby_multidim(self): ('lon', DataArray([5, 28, 23], coords=[('lon', [30., 40., 50.])])), ('lat', DataArray([16, 40], coords=[('lat', [10., 20.])]))]: - actual_sum = array.groupby(dim).sum() + actual_sum = array.groupby(dim).sum(ALL_DIMS) assert_identical(expected_sum, actual_sum) def test_groupby_multidim_apply(self): @@ -2239,12 +2281,9 @@ def test_resample_cftimeindex(self): cftime = _import_cftime() times = cftime.num2date(np.arange(12), units='hours since 0001-01-01', calendar='noleap') - with set_options(enable_cftimeindex=True): - array = DataArray(np.arange(12), [('time', times)]) + array = DataArray(np.arange(12), [('time', times)]) - with raises_regex(TypeError, - 'Only valid with DatetimeIndex, ' - 'TimedeltaIndex or PeriodIndex'): + with raises_regex(NotImplementedError, 'to_datetimeindex'): array.resample(time='6H').mean() def test_resample_first(self): @@ -2319,53 +2358,24 @@ def test_resample_drop_nondim_coords(self): actual = array.resample(time="1H").interpolate('linear') assert 'tc' not in actual.coords - def test_resample_old_vs_new_api(self): + def test_resample_keep_attrs(self): times = pd.date_range('2000-01-01', freq='6H', periods=10) array = 
DataArray(np.ones(10), [('time', times)]) + array.attrs['meta'] = 'data' - # Simple mean - with pytest.warns(FutureWarning): - old_mean = array.resample('1D', 'time', how='mean') - new_mean = array.resample(time='1D').mean() - assert_identical(old_mean, new_mean) - - # Mean, while keeping attributes - attr_array = array.copy() - attr_array.attrs['meta'] = 'data' + result = array.resample(time='1D').mean(keep_attrs=True) + expected = DataArray([1, 1, 1], [('time', times[::4])], + attrs=array.attrs) + assert_identical(result, expected) - with pytest.warns(FutureWarning): - old_mean = attr_array.resample('1D', dim='time', how='mean', - keep_attrs=True) - new_mean = attr_array.resample(time='1D').mean(keep_attrs=True) - assert old_mean.attrs == new_mean.attrs - assert_identical(old_mean, new_mean) - - # Mean, with NaN to skip - nan_array = array.copy() - nan_array[1] = np.nan + def test_resample_skipna(self): + times = pd.date_range('2000-01-01', freq='6H', periods=10) + array = DataArray(np.ones(10), [('time', times)]) + array[1] = np.nan - with pytest.warns(FutureWarning): - old_mean = nan_array.resample('1D', 'time', how='mean', - skipna=False) - new_mean = nan_array.resample(time='1D').mean(skipna=False) + result = array.resample(time='1D').mean(skipna=False) expected = DataArray([np.nan, 1, 1], [('time', times[::4])]) - assert_identical(old_mean, expected) - assert_identical(new_mean, expected) - - # Try other common resampling methods - resampler = array.resample(time='1D') - for method in ['mean', 'median', 'sum', 'first', 'last', 'count']: - # Discard attributes on the call using the new api to match - # convention from old api - new_api = getattr(resampler, method)(keep_attrs=False) - with pytest.warns(FutureWarning): - old_api = array.resample('1D', dim='time', how=method) - assert_identical(new_api, old_api) - for method in [np.mean, np.sum, np.max, np.min]: - new_api = resampler.reduce(method) - with pytest.warns(FutureWarning): - old_api = array.resample('1D', dim='time', how=method) - assert_identical(new_api, old_api) + assert_identical(result, expected) def test_upsample(self): times = pd.date_range('2000-01-01', freq='6H', periods=5) @@ -2493,6 +2503,7 @@ def test_upsample_interpolate_regression_1605(self): assert_allclose(actual, expected, rtol=1e-16) @requires_dask + @requires_scipy def test_upsample_interpolate_dask(self): import dask.array as da @@ -2786,7 +2797,7 @@ def test_to_and_from_series(self): def test_series_categorical_index(self): # regression test for GH700 if not hasattr(pd, 'CategoricalIndex'): - raise unittest.SkipTest('requires pandas with CategoricalIndex') + pytest.skip('requires pandas with CategoricalIndex') s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list('aabbc'))) arr = DataArray(s) @@ -2926,9 +2937,9 @@ def test_to_and_from_cdms2_classic(self): expected_coords = [IndexVariable('distance', [-2, 2]), IndexVariable('time', [0, 1, 2])] actual = original.to_cdms2() - assert_array_equal(actual, original) + assert_array_equal(actual.asma(), original) assert actual.id == original.name - self.assertItemsEqual(actual.getAxisIds(), original.dims) + assert tuple(actual.getAxisIds()) == original.dims for axis, coord in zip(actual.getAxisList(), expected_coords): assert axis.id == coord.name assert_array_equal(axis, coord.values) @@ -2942,8 +2953,8 @@ def test_to_and_from_cdms2_classic(self): assert_identical(original, roundtripped) back = from_cdms2(actual) - self.assertItemsEqual(original.dims, back.dims) - 
self.assertItemsEqual(original.coords.keys(), back.coords.keys()) + assert original.dims == back.dims + assert original.coords.keys() == back.coords.keys() for coord_name in original.coords.keys(): assert_array_equal(original.coords[coord_name], back.coords[coord_name]) @@ -2964,13 +2975,15 @@ def test_to_and_from_cdms2_sgrid(self): coords=OrderedDict(x=x, y=y, lon=lon, lat=lat), name='sst') actual = original.to_cdms2() - self.assertItemsEqual(actual.getAxisIds(), original.dims) - assert_array_equal(original.coords['lon'], actual.getLongitude()) - assert_array_equal(original.coords['lat'], actual.getLatitude()) + assert tuple(actual.getAxisIds()) == original.dims + assert_array_equal(original.coords['lon'], + actual.getLongitude().asma()) + assert_array_equal(original.coords['lat'], + actual.getLatitude().asma()) back = from_cdms2(actual) - self.assertItemsEqual(original.dims, back.dims) - self.assertItemsEqual(original.coords.keys(), back.coords.keys()) + assert original.dims == back.dims + assert set(original.coords.keys()) == set(back.coords.keys()) assert_array_equal(original.coords['lat'], back.coords['lat']) assert_array_equal(original.coords['lon'], back.coords['lon']) @@ -2984,158 +2997,18 @@ def test_to_and_from_cdms2_ugrid(self): original = DataArray(np.arange(5), dims=['cell'], coords={'lon': lon, 'lat': lat, 'cell': cell}) actual = original.to_cdms2() - self.assertItemsEqual(actual.getAxisIds(), original.dims) - assert_array_equal(original.coords['lon'], actual.getLongitude()) - assert_array_equal(original.coords['lat'], actual.getLatitude()) + assert tuple(actual.getAxisIds()) == original.dims + assert_array_equal(original.coords['lon'], + actual.getLongitude().getValue()) + assert_array_equal(original.coords['lat'], + actual.getLatitude().getValue()) back = from_cdms2(actual) - self.assertItemsEqual(original.dims, back.dims) - self.assertItemsEqual(original.coords.keys(), back.coords.keys()) + assert set(original.dims) == set(back.dims) + assert set(original.coords.keys()) == set(back.coords.keys()) assert_array_equal(original.coords['lat'], back.coords['lat']) assert_array_equal(original.coords['lon'], back.coords['lon']) - def test_to_and_from_iris(self): - try: - import iris - import cf_units - except ImportError: - raise unittest.SkipTest('iris not installed') - - coord_dict = OrderedDict() - coord_dict['distance'] = ('distance', [-2, 2], {'units': 'meters'}) - coord_dict['time'] = ('time', pd.date_range('2000-01-01', periods=3)) - coord_dict['height'] = 10 - coord_dict['distance2'] = ('distance', [0, 1], {'foo': 'bar'}) - coord_dict['time2'] = (('distance', 'time'), [[0, 1, 2], [2, 3, 4]]) - - original = DataArray(np.arange(6, dtype='float').reshape(2, 3), - coord_dict, name='Temperature', - attrs={'baz': 123, 'units': 'Kelvin', - 'standard_name': 'fire_temperature', - 'long_name': 'Fire Temperature'}, - dims=('distance', 'time')) - - # Set a bad value to test the masking logic - original.data[0, 2] = np.NaN - - original.attrs['cell_methods'] = \ - 'height: mean (comment: A cell method)' - actual = original.to_iris() - assert_array_equal(actual.data, original.data) - assert actual.var_name == original.name - self.assertItemsEqual([d.var_name for d in actual.dim_coords], - original.dims) - assert (actual.cell_methods == (iris.coords.CellMethod( - method='mean', - coords=('height', ), - intervals=(), - comments=('A cell method', )), )) - - for coord, orginal_key in zip((actual.coords()), original.coords): - original_coord = original.coords[orginal_key] - assert 
coord.var_name == original_coord.name - assert_array_equal( - coord.points, CFDatetimeCoder().encode(original_coord).values) - assert (actual.coord_dims(coord) == - original.get_axis_num( - original.coords[coord.var_name].dims)) - - assert (actual.coord('distance2').attributes['foo'] == - original.coords['distance2'].attrs['foo']) - assert (actual.coord('distance').units == - cf_units.Unit(original.coords['distance'].units)) - assert actual.attributes['baz'] == original.attrs['baz'] - assert actual.standard_name == original.attrs['standard_name'] - - roundtripped = DataArray.from_iris(actual) - assert_identical(original, roundtripped) - - actual.remove_coord('time') - auto_time_dimension = DataArray.from_iris(actual) - assert auto_time_dimension.dims == ('distance', 'dim_1') - - actual.coord('distance').var_name = None - with raises_regex(ValueError, 'no var_name attribute'): - DataArray.from_iris(actual) - - @requires_dask - def test_to_and_from_iris_dask(self): - import dask.array as da - try: - import iris - import cf_units - except ImportError: - raise unittest.SkipTest('iris not installed') - - coord_dict = OrderedDict() - coord_dict['distance'] = ('distance', [-2, 2], {'units': 'meters'}) - coord_dict['time'] = ('time', pd.date_range('2000-01-01', periods=3)) - coord_dict['height'] = 10 - coord_dict['distance2'] = ('distance', [0, 1], {'foo': 'bar'}) - coord_dict['time2'] = (('distance', 'time'), [[0, 1, 2], [2, 3, 4]]) - - original = DataArray( - da.from_array(np.arange(-1, 5, dtype='float').reshape(2, 3), 3), - coord_dict, - name='Temperature', - attrs=dict(baz=123, units='Kelvin', - standard_name='fire_temperature', - long_name='Fire Temperature'), - dims=('distance', 'time')) - - # Set a bad value to test the masking logic - original.data = da.ma.masked_less(original.data, 0) - - original.attrs['cell_methods'] = \ - 'height: mean (comment: A cell method)' - actual = original.to_iris() - - # Be careful not to trigger the loading of the iris data - actual_data = actual.core_data() if \ - hasattr(actual, 'core_data') else actual.data - assert_array_equal(actual_data, original.data) - assert actual.var_name == original.name - self.assertItemsEqual([d.var_name for d in actual.dim_coords], - original.dims) - assert (actual.cell_methods == (iris.coords.CellMethod( - method='mean', - coords=('height', ), - intervals=(), - comments=('A cell method', )), )) - - for coord, orginal_key in zip((actual.coords()), original.coords): - original_coord = original.coords[orginal_key] - assert coord.var_name == original_coord.name - assert_array_equal( - coord.points, CFDatetimeCoder().encode(original_coord).values) - assert (actual.coord_dims(coord) == - original.get_axis_num( - original.coords[coord.var_name].dims)) - - assert (actual.coord('distance2').attributes['foo'] == original.coords[ - 'distance2'].attrs['foo']) - assert (actual.coord('distance').units == - cf_units.Unit(original.coords['distance'].units)) - assert actual.attributes['baz'] == original.attrs['baz'] - assert actual.standard_name == original.attrs['standard_name'] - - roundtripped = DataArray.from_iris(actual) - assert_identical(original, roundtripped) - - # If the Iris version supports it then we should have a dask array - # at each stage of the conversion - if hasattr(actual, 'core_data'): - self.assertEqual(type(original.data), type(actual.core_data())) - self.assertEqual(type(original.data), type(roundtripped.data)) - - actual.remove_coord('time') - auto_time_dimension = DataArray.from_iris(actual) - assert 
auto_time_dimension.dims == ('distance', 'dim_1') - - actual.coord('distance').var_name = None - with raises_regex(ValueError, 'no var_name attribute'): - DataArray.from_iris(actual) - def test_to_dataset_whole(self): unnamed = DataArray([1, 2], dims='x') with raises_regex(ValueError, 'unable to convert unnamed'): @@ -3225,24 +3098,51 @@ def test_coordinate_diff(self): actual = lon.diff('lon') assert_equal(expected, actual) - def test_shift(self): + @pytest.mark.parametrize('offset', [-5, -2, -1, 0, 1, 2, 5]) + def test_shift(self, offset): arr = DataArray([1, 2, 3], dims='x') actual = arr.shift(x=1) expected = DataArray([np.nan, 1, 2], dims='x') assert_identical(expected, actual) arr = DataArray([1, 2, 3], [('x', ['a', 'b', 'c'])]) - for offset in [-5, -2, -1, 0, 1, 2, 5]: - expected = DataArray(arr.to_pandas().shift(offset)) - actual = arr.shift(x=offset) - assert_identical(expected, actual) + expected = DataArray(arr.to_pandas().shift(offset)) + actual = arr.shift(x=offset) + assert_identical(expected, actual) + + def test_roll_coords(self): + arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') + actual = arr.roll(x=1, roll_coords=True) + expected = DataArray([3, 1, 2], coords=[('x', [2, 0, 1])]) + assert_identical(expected, actual) + + def test_roll_no_coords(self): + arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') + actual = arr.roll(x=1, roll_coords=False) + expected = DataArray([3, 1, 2], coords=[('x', [0, 1, 2])]) + assert_identical(expected, actual) - def test_roll(self): + def test_roll_coords_none(self): arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') - actual = arr.roll(x=1) + + with pytest.warns(FutureWarning): + actual = arr.roll(x=1, roll_coords=None) + expected = DataArray([3, 1, 2], coords=[('x', [2, 0, 1])]) assert_identical(expected, actual) + def test_copy_with_data(self): + orig = DataArray(np.random.random(size=(2, 2)), + dims=('x', 'y'), + attrs={'attr1': 'value1'}, + coords={'x': [4, 3]}, + name='helloworld') + new_data = np.arange(4).reshape(2, 2) + actual = orig.copy(data=new_data) + expected = orig.copy() + expected.data = new_data + assert_identical(expected, actual) + def test_real_and_imag(self): array = DataArray(1 + 2j) assert_identical(array.real, DataArray(1)) @@ -3467,7 +3367,9 @@ def test_isin(da): def test_rolling_iter(da): rolling_obj = da.rolling(time=7) - rolling_obj_mean = rolling_obj.mean() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Mean of empty slice') + rolling_obj_mean = rolling_obj.mean() assert len(rolling_obj.window_labels) == len(da['time']) assert_identical(rolling_obj.window_labels, da['time']) @@ -3475,8 +3377,10 @@ def test_rolling_iter(da): for i, (label, window_da) in enumerate(rolling_obj): assert label == da['time'].isel(time=i) - actual = rolling_obj_mean.isel(time=i) - expected = window_da.mean('time') + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Mean of empty slice') + actual = rolling_obj_mean.isel(time=i) + expected = window_da.mean('time') # TODO add assert_allclose_with_nan, which compares nan position # as well as the closeness of the values. 
@@ -3693,3 +3597,214 @@ def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: xr.DataArray([1, 2, np.NaN]) > 0 assert len(record) == 0 + + +class TestIrisConversion(object): + @requires_iris + def test_to_and_from_iris(self): + import iris + import cf_units # iris requirement + + # to iris + coord_dict = OrderedDict() + coord_dict['distance'] = ('distance', [-2, 2], {'units': 'meters'}) + coord_dict['time'] = ('time', pd.date_range('2000-01-01', periods=3)) + coord_dict['height'] = 10 + coord_dict['distance2'] = ('distance', [0, 1], {'foo': 'bar'}) + coord_dict['time2'] = (('distance', 'time'), [[0, 1, 2], [2, 3, 4]]) + + original = DataArray(np.arange(6, dtype='float').reshape(2, 3), + coord_dict, name='Temperature', + attrs={'baz': 123, 'units': 'Kelvin', + 'standard_name': 'fire_temperature', + 'long_name': 'Fire Temperature'}, + dims=('distance', 'time')) + + # Set a bad value to test the masking logic + original.data[0, 2] = np.NaN + + original.attrs['cell_methods'] = \ + 'height: mean (comment: A cell method)' + actual = original.to_iris() + assert_array_equal(actual.data, original.data) + assert actual.var_name == original.name + assert tuple(d.var_name for d in actual.dim_coords) == original.dims + assert (actual.cell_methods == (iris.coords.CellMethod( + method='mean', + coords=('height', ), + intervals=(), + comments=('A cell method', )), )) + + for coord, orginal_key in zip((actual.coords()), original.coords): + original_coord = original.coords[orginal_key] + assert coord.var_name == original_coord.name + assert_array_equal( + coord.points, CFDatetimeCoder().encode(original_coord).values) + assert (actual.coord_dims(coord) == + original.get_axis_num( + original.coords[coord.var_name].dims)) + + assert (actual.coord('distance2').attributes['foo'] == + original.coords['distance2'].attrs['foo']) + assert (actual.coord('distance').units == + cf_units.Unit(original.coords['distance'].units)) + assert actual.attributes['baz'] == original.attrs['baz'] + assert actual.standard_name == original.attrs['standard_name'] + + roundtripped = DataArray.from_iris(actual) + assert_identical(original, roundtripped) + + actual.remove_coord('time') + auto_time_dimension = DataArray.from_iris(actual) + assert auto_time_dimension.dims == ('distance', 'dim_1') + + @requires_iris + @requires_dask + def test_to_and_from_iris_dask(self): + import dask.array as da + import iris + import cf_units # iris requirement + + coord_dict = OrderedDict() + coord_dict['distance'] = ('distance', [-2, 2], {'units': 'meters'}) + coord_dict['time'] = ('time', pd.date_range('2000-01-01', periods=3)) + coord_dict['height'] = 10 + coord_dict['distance2'] = ('distance', [0, 1], {'foo': 'bar'}) + coord_dict['time2'] = (('distance', 'time'), [[0, 1, 2], [2, 3, 4]]) + + original = DataArray( + da.from_array(np.arange(-1, 5, dtype='float').reshape(2, 3), 3), + coord_dict, + name='Temperature', + attrs=dict(baz=123, units='Kelvin', + standard_name='fire_temperature', + long_name='Fire Temperature'), + dims=('distance', 'time')) + + # Set a bad value to test the masking logic + original.data = da.ma.masked_less(original.data, 0) + + original.attrs['cell_methods'] = \ + 'height: mean (comment: A cell method)' + actual = original.to_iris() + + # Be careful not to trigger the loading of the iris data + actual_data = actual.core_data() if \ + hasattr(actual, 'core_data') else actual.data + assert_array_equal(actual_data, original.data) + assert actual.var_name == original.name + assert tuple(d.var_name 
for d in actual.dim_coords) == original.dims + assert (actual.cell_methods == (iris.coords.CellMethod( + method='mean', + coords=('height', ), + intervals=(), + comments=('A cell method', )), )) + + for coord, orginal_key in zip((actual.coords()), original.coords): + original_coord = original.coords[orginal_key] + assert coord.var_name == original_coord.name + assert_array_equal( + coord.points, CFDatetimeCoder().encode(original_coord).values) + assert (actual.coord_dims(coord) == + original.get_axis_num( + original.coords[coord.var_name].dims)) + + assert (actual.coord('distance2').attributes['foo'] == original.coords[ + 'distance2'].attrs['foo']) + assert (actual.coord('distance').units == + cf_units.Unit(original.coords['distance'].units)) + assert actual.attributes['baz'] == original.attrs['baz'] + assert actual.standard_name == original.attrs['standard_name'] + + roundtripped = DataArray.from_iris(actual) + assert_identical(original, roundtripped) + + # If the Iris version supports it then we should have a dask array + # at each stage of the conversion + if hasattr(actual, 'core_data'): + assert isinstance(original.data, type(actual.core_data())) + assert isinstance(original.data, type(roundtripped.data)) + + actual.remove_coord('time') + auto_time_dimension = DataArray.from_iris(actual) + assert auto_time_dimension.dims == ('distance', 'dim_1') + + @requires_iris + @pytest.mark.parametrize('var_name, std_name, long_name, name, attrs', [ + ('var_name', 'height', 'Height', + 'var_name', {'standard_name': 'height', 'long_name': 'Height'}), + (None, 'height', 'Height', + 'height', {'standard_name': 'height', 'long_name': 'Height'}), + (None, None, 'Height', + 'Height', {'long_name': 'Height'}), + (None, None, None, + None, {}), + ]) + def test_da_name_from_cube(self, std_name, long_name, var_name, name, + attrs): + from iris.cube import Cube + + data = [] + cube = Cube(data, var_name=var_name, standard_name=std_name, + long_name=long_name) + result = xr.DataArray.from_iris(cube) + expected = xr.DataArray(data, name=name, attrs=attrs) + xr.testing.assert_identical(result, expected) + + @requires_iris + @pytest.mark.parametrize('var_name, std_name, long_name, name, attrs', [ + ('var_name', 'height', 'Height', + 'var_name', {'standard_name': 'height', 'long_name': 'Height'}), + (None, 'height', 'Height', + 'height', {'standard_name': 'height', 'long_name': 'Height'}), + (None, None, 'Height', + 'Height', {'long_name': 'Height'}), + (None, None, None, + 'unknown', {}), + ]) + def test_da_coord_name_from_cube(self, std_name, long_name, var_name, + name, attrs): + from iris.cube import Cube + from iris.coords import DimCoord + + latitude = DimCoord([-90, 0, 90], standard_name=std_name, + var_name=var_name, long_name=long_name) + data = [0, 0, 0] + cube = Cube(data, dim_coords_and_dims=[(latitude, 0)]) + result = xr.DataArray.from_iris(cube) + expected = xr.DataArray(data, coords=[(name, [-90, 0, 90], attrs)]) + xr.testing.assert_identical(result, expected) + + @requires_iris + def test_prevent_duplicate_coord_names(self): + from iris.cube import Cube + from iris.coords import DimCoord + + # Iris enforces unique coordinate names. 
Because we use a different + # name resolution order a valid iris Cube with coords that have the + # same var_name would lead to duplicate dimension names in the + # DataArray + longitude = DimCoord([0, 360], standard_name='longitude', + var_name='duplicate') + latitude = DimCoord([-90, 0, 90], standard_name='latitude', + var_name='duplicate') + data = [[0, 0, 0], [0, 0, 0]] + cube = Cube(data, dim_coords_and_dims=[(longitude, 0), (latitude, 1)]) + with pytest.raises(ValueError): + xr.DataArray.from_iris(cube) + + @requires_iris + @pytest.mark.parametrize('coord_values', [ + ['IA', 'IL', 'IN'], # non-numeric values + [0, 2, 1], # non-monotonic values + ]) + def test_fallback_to_iris_AuxCoord(self, coord_values): + from iris.cube import Cube + from iris.coords import AuxCoord + + data = [0, 0, 0] + da = xr.DataArray(data, coords=[coord_values], dims=['space']) + result = xr.DataArray.to_iris(da) + expected = Cube(data, aux_coords_and_dims=[ + (AuxCoord(coord_values, var_name='space'), 0)]) + assert result == expected diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 4aa99b8ee5a..89ea3ba78a0 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function +import sys +import warnings from copy import copy, deepcopy from io import StringIO from textwrap import dedent -import warnings import numpy as np import pandas as pd @@ -12,17 +13,18 @@ import xarray as xr from xarray import ( - DataArray, Dataset, IndexVariable, MergeError, Variable, align, backends, - broadcast, open_dataset, set_options) -from xarray.core import indexing, utils + ALL_DIMS, DataArray, Dataset, IndexVariable, MergeError, Variable, align, + backends, broadcast, open_dataset, set_options) +from xarray.core import indexing, npcompat, utils from xarray.core.common import full_like from xarray.core.pycompat import ( OrderedDict, integer_types, iteritems, unicode_type) from . import ( - InaccessibleArray, TestCase, UnexpectedDataAccess, assert_allclose, - assert_array_equal, assert_equal, assert_identical, has_dask, raises_regex, - requires_bottleneck, requires_dask, requires_scipy, source_ndarray) + InaccessibleArray, UnexpectedDataAccess, assert_allclose, + assert_array_equal, assert_equal, assert_identical, has_cftime, has_dask, + raises_regex, requires_bottleneck, requires_dask, requires_scipy, + source_ndarray) try: import cPickle as pickle @@ -62,8 +64,8 @@ def create_test_multiindex(): class InaccessibleVariableDataStore(backends.InMemoryDataStore): - def __init__(self, writer=None): - super(InaccessibleVariableDataStore, self).__init__(writer) + def __init__(self): + super(InaccessibleVariableDataStore, self).__init__() self._indexvars = set() def store(self, variables, *args, **kwargs): @@ -84,7 +86,7 @@ def lazy_inaccessible(k, v): k, v in iteritems(self._variables)) -class TestDataset(TestCase): +class TestDataset(object): def test_repr(self): data = create_test_data(seed=123) data.attrs['foo'] = 'bar' @@ -93,15 +95,15 @@ def test_repr(self): Dimensions: (dim1: 8, dim2: 9, dim3: 10, time: 20) Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ... + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 
2000-01-20 * dim2 (dim2) float64 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 * dim3 (dim3) %s 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' numbers (dim3) int64 0 1 2 0 0 1 1 2 2 3 Dimensions without coordinates: dim1 Data variables: - var1 (dim1, dim2) float64 -1.086 0.9973 0.283 -1.506 -0.5786 1.651 ... - var2 (dim1, dim2) float64 1.162 -1.097 -2.123 1.04 -0.4034 -0.126 ... - var3 (dim3, dim1) float64 0.5565 -0.2121 0.4563 1.545 -0.2397 0.1433 ... + var1 (dim1, dim2) float64 -1.086 0.9973 0.283 ... 0.1995 0.4684 -0.8312 + var2 (dim1, dim2) float64 1.162 -1.097 -2.123 ... 0.1302 1.267 0.3328 + var3 (dim3, dim1) float64 0.5565 -0.2121 0.4563 ... -0.2452 -0.3616 Attributes: foo: bar""") % data['dim3'].dtype # noqa: E501 actual = '\n'.join(x.rstrip() for x in repr(data).split('\n')) @@ -182,15 +184,16 @@ def test_unicode_data(self): data = Dataset({u'foø': [u'ba®']}, attrs={u'å': u'∑'}) repr(data) # should not raise + byteorder = '<' if sys.byteorder == 'little' else '>' expected = dedent(u"""\ Dimensions: (foø: 1) Coordinates: - * foø (foø) 2.0, np.nan, ds['foo']) + + actual = ds.resample(time='1D').sum(min_count=1) + expected = xr.concat([ + ds.isel(time=slice(i * 4, (i + 1) * 4)).sum('time', min_count=1) + for i in range(3)], dim=actual['time']) + assert_equal(expected, actual) + def test_resample_by_mean_with_keep_attrs(self): times = pd.date_range('2000-01-01', freq='6H', periods=10) ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), @@ -2783,22 +2858,21 @@ def test_resample_drop_nondim_coords(self): actual = ds.resample(time="1H").interpolate('linear') assert 'tc' not in actual.coords - def test_resample_old_vs_new_api(self): + def test_resample_old_api(self): times = pd.date_range('2000-01-01', freq='6H', periods=10) ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), 'bar': ('time', np.random.randn(10), {'meta': 'data'}), 'time': times}) - ds.attrs['dsmeta'] = 'dsdata' - for method in ['mean', 'sum', 'count', 'first', 'last']: - resampler = ds.resample(time='1D') - # Discard attributes on the call using the new api to match - # convention from old api - new_api = getattr(resampler, method)(keep_attrs=False) - with pytest.warns(FutureWarning): - old_api = ds.resample('1D', dim='time', how=method) - assert_identical(new_api, old_api) + with raises_regex(TypeError, r'resample\(\) no longer supports'): + ds.resample('1D', 'time') + + with raises_regex(TypeError, r'resample\(\) no longer supports'): + ds.resample('1D', dim='time', how='mean') + + with raises_regex(TypeError, r'resample\(\) no longer supports'): + ds.resample('1D', dim='time') def test_to_array(self): ds = Dataset(OrderedDict([('a', 1), ('b', ('x', [1, 2, 3]))]), @@ -3350,9 +3424,8 @@ def test_reduce(self): (['dim2', 'time'], ['dim1', 'dim3']), (('dim2', 'time'), ['dim1', 'dim3']), ((), ['dim1', 'dim2', 'dim3', 'time'])]: - actual = data.min(dim=reduct).dims - print(reduct, actual, expected) - self.assertItemsEqual(actual, expected) + actual = list(data.min(dim=reduct).dims) + assert actual == expected assert_equal(data.mean(dim=[]), data) @@ -3406,8 +3479,7 @@ def test_reduce_cumsum_test_dims(self): ('time', ['dim1', 'dim2', 'dim3']) ]: actual = getattr(data, cumfunc)(dim=reduct).dims - print(reduct, actual, expected) - self.assertItemsEqual(actual, expected) + assert list(actual) == expected def test_reduce_non_numeric(self): data1 = create_test_data(seed=44) @@ -3545,14 +3617,14 @@ def test_rank(self): ds = create_test_data(seed=1234) # only ds.var3 depends on dim3 z = ds.rank('dim3') - 
self.assertItemsEqual(['var3'], list(z.data_vars)) + assert ['var3'] == list(z.data_vars) # same as dataarray version x = z.var3 y = ds.var3.rank('dim3') assert_equal(x, y) # coordinates stick - self.assertItemsEqual(list(z.coords), list(ds.coords)) - self.assertItemsEqual(list(x.coords), list(y.coords)) + assert list(z.coords) == list(ds.coords) + assert list(x.coords) == list(y.coords) # invalid dim with raises_regex(ValueError, 'does not contain'): x.rank('invalid_dim') @@ -3733,10 +3805,6 @@ def test_dataset_transpose(self): expected = ds.apply(lambda x: x.transpose()) assert_identical(expected, actual) - with pytest.warns(FutureWarning): - actual = ds.T - assert_identical(expected, actual) - actual = ds.transpose('x', 'y') expected = ds.apply(lambda x: x.transpose('x', 'y')) assert_identical(expected, actual) @@ -3835,18 +3903,52 @@ def test_shift(self): with raises_regex(ValueError, 'dimensions'): ds.shift(foo=123) - def test_roll(self): + def test_roll_coords(self): coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} attrs = {'meta': 'data'} ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) - actual = ds.roll(x=1) + actual = ds.roll(x=1, roll_coords=True) ex_coords = {'bar': ('x', list('cab')), 'x': [2, -4, 3]} expected = Dataset({'foo': ('x', [3, 1, 2])}, ex_coords, attrs) assert_identical(expected, actual) with raises_regex(ValueError, 'dimensions'): - ds.roll(foo=123) + ds.roll(foo=123, roll_coords=True) + + def test_roll_no_coords(self): + coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} + attrs = {'meta': 'data'} + ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) + actual = ds.roll(x=1, roll_coords=False) + + expected = Dataset({'foo': ('x', [3, 1, 2])}, coords, attrs) + assert_identical(expected, actual) + + with raises_regex(ValueError, 'dimensions'): + ds.roll(abc=321, roll_coords=False) + + def test_roll_coords_none(self): + coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} + attrs = {'meta': 'data'} + ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) + + with pytest.warns(FutureWarning): + actual = ds.roll(x=1, roll_coords=None) + + ex_coords = {'bar': ('x', list('cab')), 'x': [2, -4, 3]} + expected = Dataset({'foo': ('x', [3, 1, 2])}, ex_coords, attrs) + assert_identical(expected, actual) + + def test_roll_multidim(self): + # regression test for 2445 + arr = xr.DataArray( + [[1, 2, 3], [4, 5, 6]], coords={'x': range(3), 'y': range(2)}, + dims=('y', 'x')) + actual = arr.roll(x=1, roll_coords=True) + expected = xr.DataArray([[3, 1, 2], [6, 4, 5]], + coords=[('y', [0, 1]), ('x', [2, 0, 1])]) + assert_identical(expected, actual) def test_real_and_imag(self): attrs = {'foo': 'bar'} @@ -3906,6 +4008,26 @@ def test_filter_by_attrs(self): for var in new_ds.data_vars: assert new_ds[var].height == '10 m' + # Test return empty Dataset due to conflicting filters + new_ds = ds.filter_by_attrs( + standard_name='convective_precipitation_flux', + height='0 m') + assert not bool(new_ds.data_vars) + + # Test return one DataArray with two filter conditions + new_ds = ds.filter_by_attrs( + standard_name='air_potential_temperature', + height='0 m') + for var in new_ds.data_vars: + assert new_ds[var].standard_name == 'air_potential_temperature' + assert new_ds[var].height == '0 m' + assert new_ds[var].height != '10 m' + + # Test return empty Dataset due to conflicting callables + new_ds = ds.filter_by_attrs(standard_name=lambda v: False, + height=lambda v: True) + assert not bool(new_ds.data_vars) + def test_binary_op_join_setting(self): # arithmetic_join applies to 
data array coordinates missing_2 = xr.Dataset({'x': [0, 1]}) @@ -4218,6 +4340,11 @@ def test_dataset_constructor_aligns_to_explicit_coords( assert_equal(expected, result) +def test_error_message_on_set_supplied(): + with pytest.raises(TypeError, message='has invalid type set'): + xr.Dataset(dict(date=[1, 2, 3], sec={4})) + + @pytest.mark.parametrize('unaligned_coords', ( {'y': ('b', np.asarray([2, 1, 0]))}, )) @@ -4399,3 +4526,107 @@ def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: Dataset(data_vars={'x': ('y', [1, 2, np.NaN])}) > 0 assert len(record) == 0 + + +@pytest.mark.parametrize('dask', [True, False]) +@pytest.mark.parametrize('edge_order', [1, 2]) +def test_differentiate(dask, edge_order): + rs = np.random.RandomState(42) + coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] + + da = xr.DataArray(rs.randn(8, 6), dims=['x', 'y'], + coords={'x': coord, + 'z': 3, 'x2d': (('x', 'y'), rs.randn(8, 6))}) + if dask and has_dask: + da = da.chunk({'x': 4}) + + ds = xr.Dataset({'var': da}) + + # along x + actual = da.differentiate('x', edge_order) + expected_x = xr.DataArray( + npcompat.gradient(da, da['x'], axis=0, edge_order=edge_order), + dims=da.dims, coords=da.coords) + assert_equal(expected_x, actual) + assert_equal(ds['var'].differentiate('x', edge_order=edge_order), + ds.differentiate('x', edge_order=edge_order)['var']) + # coordinate should not change + assert_equal(da['x'], actual['x']) + + # along y + actual = da.differentiate('y', edge_order) + expected_y = xr.DataArray( + npcompat.gradient(da, da['y'], axis=1, edge_order=edge_order), + dims=da.dims, coords=da.coords) + assert_equal(expected_y, actual) + assert_equal(actual, ds.differentiate('y', edge_order=edge_order)['var']) + assert_equal(ds['var'].differentiate('y', edge_order=edge_order), + ds.differentiate('y', edge_order=edge_order)['var']) + + with pytest.raises(ValueError): + da.differentiate('x2d') + + +@pytest.mark.parametrize('dask', [True, False]) +def test_differentiate_datetime(dask): + rs = np.random.RandomState(42) + coord = np.array( + ['2004-07-13', '2006-01-13', '2010-08-13', '2010-09-13', + '2010-10-11', '2010-12-13', '2011-02-13', '2012-08-13'], + dtype='datetime64') + + da = xr.DataArray(rs.randn(8, 6), dims=['x', 'y'], + coords={'x': coord, + 'z': 3, 'x2d': (('x', 'y'), rs.randn(8, 6))}) + if dask and has_dask: + da = da.chunk({'x': 4}) + + # along x + actual = da.differentiate('x', edge_order=1, datetime_unit='D') + expected_x = xr.DataArray( + npcompat.gradient( + da, utils.datetime_to_numeric(da['x'], datetime_unit='D'), + axis=0, edge_order=1), dims=da.dims, coords=da.coords) + assert_equal(expected_x, actual) + + actual2 = da.differentiate('x', edge_order=1, datetime_unit='h') + assert np.allclose(actual, actual2 * 24) + + # for datetime variable + actual = da['x'].differentiate('x', edge_order=1, datetime_unit='D') + assert np.allclose(actual, 1.0) + + # with different date unit + da = xr.DataArray(coord.astype('datetime64[ms]'), dims=['x'], + coords={'x': coord}) + actual = da.differentiate('x', edge_order=1) + assert np.allclose(actual, 1.0) + + +@pytest.mark.skipif(not has_cftime, reason='Test requires cftime.') +@pytest.mark.parametrize('dask', [True, False]) +def test_differentiate_cftime(dask): + rs = np.random.RandomState(42) + coord = xr.cftime_range('2000', periods=8, freq='2M') + + da = xr.DataArray( + rs.randn(8, 6), + coords={'time': coord, 'z': 3, 't2d': (('time', 'y'), rs.randn(8, 6))}, + dims=['time', 'y']) + + if dask and has_dask: + da = 
da.chunk({'time': 4}) + + actual = da.differentiate('time', edge_order=1, datetime_unit='D') + expected_data = npcompat.gradient( + da, utils.datetime_to_numeric(da['time'], datetime_unit='D'), + axis=0, edge_order=1) + expected = xr.DataArray(expected_data, coords=da.coords, dims=da.dims) + assert_equal(expected, actual) + + actual2 = da.differentiate('time', edge_order=1, datetime_unit='h') + assert_allclose(actual, actual2 * 24) + + # Test the differentiation of datetimes themselves + actual = da['time'].differentiate('time', edge_order=1, datetime_unit='D') + assert_allclose(actual, xr.ones_like(da['time']).astype(float)) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 32035afdc57..1837a0fe4ef 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -15,16 +15,18 @@ from distributed.utils_test import cluster, gen_cluster from distributed.utils_test import loop # flake8: noqa from distributed.client import futures_of +import numpy as np import xarray as xr +from xarray.backends.locks import HDF5_LOCK, CombinedLock from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file, - create_tmp_geotiff) + create_tmp_geotiff, + open_example_dataset) from xarray.tests.test_dataset import create_test_data -from xarray.backends.common import HDF5_LOCK, CombinedLock from . import ( assert_allclose, has_h5netcdf, has_netCDF4, requires_rasterio, has_scipy, - requires_zarr, raises_regex) + requires_zarr, requires_cfgrib, raises_regex) # this is to stop isort throwing errors. May have been easier to just use # `isort:skip` in retrospect @@ -33,6 +35,11 @@ da = pytest.importorskip('dask.array') +@pytest.fixture +def tmp_netcdf_filename(tmpdir): + return str(tmpdir.join('testfile.nc')) + + ENGINES = [] if has_scipy: ENGINES.append('scipy') @@ -45,81 +52,69 @@ 'NETCDF3_64BIT_DATA', 'NETCDF4_CLASSIC', 'NETCDF4'], 'scipy': ['NETCDF3_CLASSIC', 'NETCDF3_64BIT'], 'h5netcdf': ['NETCDF4']} -TEST_FORMATS = ['NETCDF3_CLASSIC', 'NETCDF4_CLASSIC', 'NETCDF4'] - - -@pytest.mark.xfail(sys.platform == 'win32', - reason='https://github.com/pydata/xarray/issues/1738') -@pytest.mark.parametrize('engine', ['netcdf4']) -@pytest.mark.parametrize('autoclose', [True, False]) -@pytest.mark.parametrize('nc_format', TEST_FORMATS) -def test_dask_distributed_netcdf_roundtrip(monkeypatch, loop, - engine, autoclose, nc_format): - monkeypatch.setenv('HDF5_USE_FILE_LOCKING', 'FALSE') +ENGINES_AND_FORMATS = [ + ('netcdf4', 'NETCDF3_CLASSIC'), + ('netcdf4', 'NETCDF4_CLASSIC'), + ('netcdf4', 'NETCDF4'), + ('h5netcdf', 'NETCDF4'), + ('scipy', 'NETCDF3_64BIT'), +] - chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: +@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) +def test_dask_distributed_netcdf_roundtrip( + loop, tmp_netcdf_filename, engine, nc_format): - original = create_test_data().chunk(chunks) - original.to_netcdf(filename, engine=engine, format=nc_format) - - with xr.open_dataset(filename, - chunks=chunks, - engine=engine, - autoclose=autoclose) as restored: - assert isinstance(restored.var1.data, da.Array) - computed = restored.compute() - assert_allclose(original, computed) + if engine not in ENGINES: + pytest.skip('engine not available') + chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} -@pytest.mark.xfail(sys.platform == 'win32', - reason='https://github.com/pydata/xarray/issues/1738') 
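Editorial note, not part of the patch: the differentiate tests above check the result against npcompat.gradient; the short sketch below shows the public API those tests exercise, using made-up sample data, and assumes an xarray build that already includes DataArray.differentiate (the feature under test here).

import numpy as np
import xarray as xr

# Differentiate along an unevenly spaced coordinate; edge_order follows the
# numpy.gradient convention (1st- or 2nd-order accurate at the boundaries).
da = xr.DataArray(np.sin(np.linspace(0, 1, 8)), dims='x',
                  coords={'x': [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8]})
ddx = da.differentiate('x', edge_order=2)

# For datetime coordinates, datetime_unit sets the unit of the denominator.
times = np.array(['2000-01-01', '2000-01-02', '2000-01-04',
                  '2000-01-07', '2000-01-11'], dtype='datetime64[ns]')
series = xr.DataArray(np.arange(5.0), dims='time', coords={'time': times})
rate_per_day = series.differentiate('time', datetime_unit='D')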
-@pytest.mark.parametrize('engine', ENGINES) -@pytest.mark.parametrize('autoclose', [True, False]) -@pytest.mark.parametrize('nc_format', TEST_FORMATS) -def test_dask_distributed_read_netcdf_integration_test(loop, engine, autoclose, - nc_format): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: - if engine == 'h5netcdf' and autoclose: - pytest.skip('h5netcdf does not support autoclose') + original = create_test_data().chunk(chunks) - if nc_format not in NC_FORMATS[engine]: - pytest.skip('invalid format for engine') + if engine == 'scipy': + with pytest.raises(NotImplementedError): + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) + return - chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with xr.open_dataset(tmp_netcdf_filename, + chunks=chunks, engine=engine) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) - original = create_test_data() - original.to_netcdf(filename, engine=engine, format=nc_format) - with xr.open_dataset(filename, - chunks=chunks, - engine=engine, - autoclose=autoclose) as restored: - assert isinstance(restored.var1.data, da.Array) - computed = restored.compute() - assert_allclose(original, computed) +@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) +def test_dask_distributed_read_netcdf_integration_test( + loop, tmp_netcdf_filename, engine, nc_format): + if engine not in ENGINES: + pytest.skip('engine not available') -@pytest.mark.parametrize('engine', ['h5netcdf', 'scipy']) -def test_dask_distributed_netcdf_integration_test_not_implemented(loop, engine): chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: - original = create_test_data().chunk(chunks) + original = create_test_data() + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) + + with xr.open_dataset(tmp_netcdf_filename, + chunks=chunks, + engine=engine) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) - with raises_regex(NotImplementedError, 'distributed'): - original.to_netcdf(filename, engine=engine) @requires_zarr @@ -148,6 +143,20 @@ def test_dask_distributed_rasterio_integration_test(loop): assert_allclose(actual, expected) +@requires_cfgrib +def test_dask_distributed_cfgrib_integration_test(loop): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + with open_example_dataset('example.grib', + engine='cfgrib', + chunks={'time': 1}) as ds: + with open_example_dataset('example.grib', + engine='cfgrib') as expected: + assert isinstance(ds['t'].data, da.Array) + actual = ds.compute() + assert_allclose(actual, expected) + + @pytest.mark.skipif(distributed.__version__ <= '1.19.3', reason='Need recent distributed version to clean up get') @gen_cluster(client=True, timeout=None) diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py index 833df85f8af..292c60b4d05 100644 --- a/xarray/tests/test_dtypes.py +++ b/xarray/tests/test_dtypes.py @@ -50,3 +50,39 @@ def error(): def 
test_inf(obj): assert dtypes.INF > obj assert dtypes.NINF < obj + + +@pytest.mark.parametrize("kind, expected", [ + ('a', (np.dtype('O'), 'nan')), # dtype('S') + ('b', (np.float32, 'nan')), # dtype('int8') + ('B', (np.float32, 'nan')), # dtype('uint8') + ('c', (np.dtype('O'), 'nan')), # dtype('S1') + ('D', (np.complex128, '(nan+nanj)')), # dtype('complex128') + ('d', (np.float64, 'nan')), # dtype('float64') + ('e', (np.float16, 'nan')), # dtype('float16') + ('F', (np.complex64, '(nan+nanj)')), # dtype('complex64') + ('f', (np.float32, 'nan')), # dtype('float32') + ('h', (np.float32, 'nan')), # dtype('int16') + ('H', (np.float32, 'nan')), # dtype('uint16') + ('i', (np.float64, 'nan')), # dtype('int32') + ('I', (np.float64, 'nan')), # dtype('uint32') + ('l', (np.float64, 'nan')), # dtype('int64') + ('L', (np.float64, 'nan')), # dtype('uint64') + ('m', (np.timedelta64, 'NaT')), # dtype(' 0: + assert isinstance(da.data, dask_array_type) + + @pytest.mark.parametrize('dim_num', [1, 2]) @pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) @pytest.mark.parametrize('dask', [False, True]) @pytest.mark.parametrize('func', ['sum', 'min', 'max', 'mean', 'var']) +# TODO test cumsum, cumprod @pytest.mark.parametrize('skipna', [False, True]) @pytest.mark.parametrize('aggdim', [None, 'x']) def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): @@ -251,6 +269,9 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): if dask and not has_dask: pytest.skip('requires dask') + if dask and skipna is False and dtype in [np.bool_]: + pytest.skip('dask does not compute object-typed array') + rtol = 1e-04 if dtype == np.float32 else 1e-05 da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) @@ -259,6 +280,7 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): # TODO: remove these after resolving # https://github.com/dask/dask/issues/3245 with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Mean of empty slice') warnings.filterwarnings('ignore', 'All-NaN slice') warnings.filterwarnings('ignore', 'invalid value encountered in') @@ -272,6 +294,7 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): expected = getattr(np, func)(da.values, axis=axis) actual = getattr(da, func)(skipna=skipna, dim=aggdim) + assert_dask_array(actual, dask) assert np.allclose(actual.values, np.array(expected), rtol=1.0e-4, equal_nan=True) except (TypeError, AttributeError, ZeroDivisionError): @@ -279,14 +302,21 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): # nanmean for object dtype pass - # make sure the compatiblility with pandas' results. actual = getattr(da, func)(skipna=skipna, dim=aggdim) - if func == 'var': + + # for dask case, make sure the result is the same for numpy backend + expected = getattr(da.compute(), func)(skipna=skipna, dim=aggdim) + assert_allclose(actual, expected, rtol=rtol) + + # make sure the compatiblility with pandas' results. 
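Editorial aside, not part of the patch: the pandas comparison the surrounding test performs, written out for a single 1-D case with hypothetical sample values; only ddof needs to be matched explicitly, since xarray defaults to ddof=0 while pandas defaults to ddof=1.

import numpy as np
import pandas as pd
import xarray as xr

values = np.array([1.0, np.nan, 3.0, 4.0])
da = xr.DataArray(values, dims='x')
series = pd.Series(values)

# With skipna=True both libraries ignore the NaN; results agree once the
# same ddof is passed to each.
assert np.allclose(da.var(skipna=True, ddof=0), series.var(skipna=True, ddof=0))
assert np.allclose(da.std(skipna=True, ddof=1), series.std(skipna=True, ddof=1))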
+ if func in ['var', 'std']: expected = series_reduce(da, func, skipna=skipna, dim=aggdim, ddof=0) assert_allclose(actual, expected, rtol=rtol) # also check ddof!=0 case actual = getattr(da, func)(skipna=skipna, dim=aggdim, ddof=5) + if dask: + assert isinstance(da.data, dask_array_type) expected = series_reduce(da, func, skipna=skipna, dim=aggdim, ddof=5) assert_allclose(actual, expected, rtol=rtol) @@ -297,11 +327,14 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): # make sure the dtype argument if func not in ['max', 'min']: actual = getattr(da, func)(skipna=skipna, dim=aggdim, dtype=float) + assert_dask_array(actual, dask) assert actual.dtype == float # without nan da = construct_dataarray(dim_num, dtype, contains_nan=False, dask=dask) actual = getattr(da, func)(skipna=skipna) + if dask: + assert isinstance(da.data, dask_array_type) expected = getattr(np, 'nan{}'.format(func))(da.values) if actual.dtype == object: assert actual.values == np.array(expected) @@ -338,13 +371,6 @@ def test_argmin_max(dim_num, dtype, contains_nan, dask, func, skipna, aggdim): with warnings.catch_warnings(): warnings.filterwarnings('ignore', 'All-NaN slice') - if aggdim == 'y' and contains_nan and skipna: - with pytest.raises(ValueError): - actual = da.isel(**{ - aggdim: getattr(da, 'arg' + func)( - dim=aggdim, skipna=skipna).compute()}) - return - actual = da.isel(**{aggdim: getattr(da, 'arg' + func) (dim=aggdim, skipna=skipna).compute()}) expected = getattr(da, func)(dim=aggdim, skipna=skipna) @@ -354,6 +380,7 @@ def test_argmin_max(dim_num, dtype, contains_nan, dask, func, skipna, aggdim): def test_argmin_max_error(): da = construct_dataarray(2, np.bool_, contains_nan=True, dask=False) + da[0] = np.nan with pytest.raises(ValueError): da.argmin(dim='y') @@ -388,3 +415,139 @@ def test_dask_rolling(axis, window, center): with pytest.raises(ValueError): rolling_window(dx, axis=axis, window=100, center=center, fill_value=np.nan) + + +@pytest.mark.skipif(not has_dask, reason='This is for dask.') +@pytest.mark.parametrize('axis', [0, -1, 1]) +@pytest.mark.parametrize('edge_order', [1, 2]) +def test_dask_gradient(axis, edge_order): + import dask.array as da + + array = np.array(np.random.randn(100, 5, 40)) + x = np.exp(np.linspace(0, 1, array.shape[axis])) + + darray = da.from_array(array, chunks=[(6, 30, 30, 20, 14), 5, 8]) + expected = gradient(array, x, axis=axis, edge_order=edge_order) + actual = gradient(darray, x, axis=axis, edge_order=edge_order) + + assert isinstance(actual, da.Array) + assert_array_equal(actual, expected) + + +@pytest.mark.parametrize('dim_num', [1, 2]) +@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize('dask', [False, True]) +@pytest.mark.parametrize('func', ['sum', 'prod']) +@pytest.mark.parametrize('aggdim', [None, 'x']) +def test_min_count(dim_num, dtype, dask, func, aggdim): + if dask and not has_dask: + pytest.skip('requires dask') + + da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) + min_count = 3 + + actual = getattr(da, func)(dim=aggdim, skipna=True, min_count=min_count) + + if LooseVersion(pd.__version__) >= LooseVersion('0.22.0'): + # min_count is only implenented in pandas > 0.22 + expected = series_reduce(da, func, skipna=True, dim=aggdim, + min_count=min_count) + assert_allclose(actual, expected) + + assert_dask_array(actual, dask) + + +@pytest.mark.parametrize('func', ['sum', 'prod']) +def test_min_count_dataset(func): + da = construct_dataarray(2, dtype=float, contains_nan=True, 
dask=False) + ds = Dataset({'var1': da}, coords={'scalar': 0}) + actual = getattr(ds, func)(dim='x', skipna=True, min_count=3)['var1'] + expected = getattr(ds['var1'], func)(dim='x', skipna=True, min_count=3) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize('dask', [False, True]) +@pytest.mark.parametrize('func', ['sum', 'prod']) +def test_multiple_dims(dtype, dask, func): + if dask and not has_dask: + pytest.skip('requires dask') + da = construct_dataarray(3, dtype, contains_nan=True, dask=dask) + + actual = getattr(da, func)(('x', 'y')) + expected = getattr(getattr(da, func)('x'), func)('y') + assert_allclose(actual, expected) + + +def test_docs(): + # with min_count + actual = DataArray.sum.__doc__ + expected = dedent("""\ + Reduce this DataArray's data by applying `sum` along some dimension(s). + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to apply `sum`. + axis : int or sequence of int, optional + Axis(es) over which to apply `sum`. Only one of the 'dim' + and 'axis' arguments can be supplied. If neither are supplied, then + `sum` is calculated over axes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + min_count : int, default None + The required number of valid values to perform the operation. + If fewer than min_count non-NA values are present the result will + be NA. New in version 0.10.8: Added with the default being None. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `sum` on this object's data. + + Returns + ------- + reduced : DataArray + New DataArray object with `sum` applied to its data and the + indicated dimension(s) removed. + """) + assert actual == expected + + # without min_count + actual = DataArray.std.__doc__ + expected = dedent("""\ + Reduce this DataArray's data by applying `std` along some dimension(s). + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to apply `std`. + axis : int or sequence of int, optional + Axis(es) over which to apply `std`. Only one of the 'dim' + and 'axis' arguments can be supplied. If neither are supplied, then + `std` is calculated over axes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `std` on this object's data. + + Returns + ------- + reduced : DataArray + New DataArray object with `std` applied to its data and the + indicated dimension(s) removed. 
+ """) + assert actual == expected diff --git a/xarray/tests/test_extensions.py b/xarray/tests/test_extensions.py index 24b710ae223..ffefa78aa34 100644 --- a/xarray/tests/test_extensions.py +++ b/xarray/tests/test_extensions.py @@ -4,7 +4,7 @@ import xarray as xr -from . import TestCase, raises_regex +from . import raises_regex try: import cPickle as pickle @@ -21,7 +21,7 @@ def __init__(self, xarray_obj): self.obj = xarray_obj -class TestAccessor(TestCase): +class TestAccessor(object): def test_register(self): @xr.register_dataset_accessor('demo') diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 34552891778..024c669bed9 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -7,34 +7,55 @@ from xarray.core import formatting from xarray.core.pycompat import PY3 -from . import TestCase, raises_regex +from . import raises_regex -class TestFormatting(TestCase): +class TestFormatting(object): def test_get_indexer_at_least_n_items(self): cases = [ - ((20,), (slice(10),)), - ((3, 20,), (0, slice(10))), - ((2, 10,), (0, slice(10))), - ((2, 5,), (slice(2), slice(None))), - ((1, 2, 5,), (0, slice(2), slice(None))), - ((2, 3, 5,), (0, slice(2), slice(None))), - ((1, 10, 1,), (0, slice(10), slice(None))), - ((2, 5, 1,), (slice(2), slice(None), slice(None))), - ((2, 5, 3,), (0, slice(4), slice(None))), - ((2, 3, 3,), (slice(2), slice(None), slice(None))), + ((20,), (slice(10),), (slice(-10, None),)), + ((3, 20,), (0, slice(10)), (-1, slice(-10, None))), + ((2, 10,), (0, slice(10)), (-1, slice(-10, None))), + ((2, 5,), (slice(2), slice(None)), + (slice(-2, None), slice(None))), + ((1, 2, 5,), (0, slice(2), slice(None)), + (-1, slice(-2, None), slice(None))), + ((2, 3, 5,), (0, slice(2), slice(None)), + (-1, slice(-2, None), slice(None))), + ((1, 10, 1,), (0, slice(10), slice(None)), + (-1, slice(-10, None), slice(None))), + ((2, 5, 1,), (slice(2), slice(None), slice(None)), + (slice(-2, None), slice(None), slice(None))), + ((2, 5, 3,), (0, slice(4), slice(None)), + (-1, slice(-4, None), slice(None))), + ((2, 3, 3,), (slice(2), slice(None), slice(None)), + (slice(-2, None), slice(None), slice(None))), ] - for shape, expected in cases: - actual = formatting._get_indexer_at_least_n_items(shape, 10) - assert expected == actual + for shape, start_expected, end_expected in cases: + actual = formatting._get_indexer_at_least_n_items(shape, 10, + from_end=False) + assert start_expected == actual + actual = formatting._get_indexer_at_least_n_items(shape, 10, + from_end=True) + assert end_expected == actual def test_first_n_items(self): array = np.arange(100).reshape(10, 5, 2) for n in [3, 10, 13, 100, 200]: actual = formatting.first_n_items(array, n) expected = array.flat[:n] - self.assertItemsEqual(expected, actual) + assert (expected == actual).all() + + with raises_regex(ValueError, 'at least one item'): + formatting.first_n_items(array, 0) + + def test_last_n_items(self): + array = np.arange(100).reshape(10, 5, 2) + for n in [3, 10, 13, 100, 200]: + actual = formatting.last_n_items(array, n) + expected = array.flat[-n:] + assert (expected == actual).all() with raises_regex(ValueError, 'at least one item'): formatting.first_n_items(array, 0) @@ -87,16 +108,32 @@ def test_format_items(self): assert expected == actual def test_format_array_flat(self): + actual = formatting.format_array_flat(np.arange(100), 2) + expected = '0 ... 99' + assert expected == actual + + actual = formatting.format_array_flat(np.arange(100), 9) + expected = '0 ... 
99' + assert expected == actual + + actual = formatting.format_array_flat(np.arange(100), 10) + expected = '0 1 ... 99' + assert expected == actual + actual = formatting.format_array_flat(np.arange(100), 13) - expected = '0 1 2 3 4 ...' + expected = '0 1 ... 98 99' + assert expected == actual + + actual = formatting.format_array_flat(np.arange(100), 15) + expected = '0 1 2 ... 98 99' assert expected == actual actual = formatting.format_array_flat(np.arange(100.0), 11) - expected = '0.0 1.0 ...' + expected = '0.0 ... 99.0' assert expected == actual actual = formatting.format_array_flat(np.arange(100.0), 1) - expected = '0.0 ...' + expected = '0.0 ... 99.0' assert expected == actual actual = formatting.format_array_flat(np.arange(3), 5) @@ -104,11 +141,23 @@ def test_format_array_flat(self): assert expected == actual actual = formatting.format_array_flat(np.arange(4.0), 11) - expected = '0.0 1.0 ...' + expected = '0.0 ... 3.0' + assert expected == actual + + actual = formatting.format_array_flat(np.arange(0), 0) + expected = '' + assert expected == actual + + actual = formatting.format_array_flat(np.arange(1), 0) + expected = '0' + assert expected == actual + + actual = formatting.format_array_flat(np.arange(2), 0) + expected = '0 1' assert expected == actual actual = formatting.format_array_flat(np.arange(4), 0) - expected = '0 ...' + expected = '0 ... 3' assert expected == actual def test_pretty_print(self): diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 6dd14f5d6ad..8ace55be66b 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -5,9 +5,10 @@ import pytest import xarray as xr -from . import assert_identical from xarray.core.groupby import _consolidate_slices +from . import assert_identical + def test_consolidate_slices(): diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 0d1045d35c0..701eefcb462 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -10,13 +10,12 @@ from xarray.core import indexing, nputils from xarray.core.pycompat import native_int_types -from . import ( - IndexerMaker, ReturnItem, TestCase, assert_array_equal, raises_regex) +from . 
import IndexerMaker, ReturnItem, assert_array_equal, raises_regex B = IndexerMaker(indexing.BasicIndexer) -class TestIndexers(TestCase): +class TestIndexers(object): def set_to_zero(self, x, i): x = x.copy() x[i] = 0 @@ -25,7 +24,7 @@ def set_to_zero(self, x, i): def test_expanded_indexer(self): x = np.random.randn(10, 11, 12, 13, 14) y = np.arange(5) - I = ReturnItem() # noqa: E741 # allow ambiguous name + I = ReturnItem() # noqa for i in [I[:], I[...], I[0, :, 10], I[..., 10], I[:5, ..., 0], I[..., 0, :], I[y], I[y, y], I[..., y, y], I[..., 0, 1, 2, 3, 4]]: @@ -133,7 +132,7 @@ def test_indexer(data, x, expected_pos, expected_idx=None): pd.MultiIndex.from_product([[1, 2], [-1, -2]])) -class TestLazyArray(TestCase): +class TestLazyArray(object): def test_slice_slice(self): I = ReturnItem() # noqa: E741 # allow ambiguous name for size in [100, 99]: @@ -248,7 +247,7 @@ def check_indexing(v_eager, v_lazy, indexers): check_indexing(v_eager, v_lazy, indexers) -class TestCopyOnWriteArray(TestCase): +class TestCopyOnWriteArray(object): def test_setitem(self): original = np.arange(10) wrapped = indexing.CopyOnWriteArray(original) @@ -272,7 +271,7 @@ def test_index_scalar(self): assert np.array(x[B[0]][B[()]]) == 'foo' -class TestMemoryCachedArray(TestCase): +class TestMemoryCachedArray(object): def test_wrapper(self): original = indexing.LazilyOuterIndexedArray(np.arange(10)) wrapped = indexing.MemoryCachedArray(original) @@ -385,8 +384,9 @@ def test_vectorized_indexer(): np.arange(5, dtype=np.int64))) -class Test_vectorized_indexer(TestCase): - def setUp(self): +class Test_vectorized_indexer(object): + @pytest.fixture(autouse=True) + def setup(self): self.data = indexing.NumpyIndexingAdapter(np.random.randn(10, 12, 13)) self.indexers = [np.array([[0, 3, 2], ]), np.array([[0, 3, 3], [4, 6, 7]]), diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 69a4644bc97..624879cce1f 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -5,8 +5,11 @@ import pytest import xarray as xr -from xarray.tests import assert_allclose, assert_equal, requires_scipy +from xarray.tests import ( + assert_allclose, assert_equal, requires_cftime, requires_scipy) + from . 
import has_dask, has_scipy +from ..coding.cftimeindex import _parse_array_of_cftime_strings from .test_dataset import create_test_data try: @@ -462,19 +465,111 @@ def test_interp_like(): @requires_scipy -def test_datetime(): - da = xr.DataArray(np.random.randn(24), dims='time', +@pytest.mark.parametrize('x_new, expected', [ + (pd.date_range('2000-01-02', periods=3), [1, 2, 3]), + (np.array([np.datetime64('2000-01-01T12:00'), + np.datetime64('2000-01-02T12:00')]), [0.5, 1.5]), + (['2000-01-01T12:00', '2000-01-02T12:00'], [0.5, 1.5]), + (['2000-01-01T12:00'], 0.5), + pytest.param('2000-01-01T12:00', 0.5, marks=pytest.mark.xfail) +]) +def test_datetime(x_new, expected): + da = xr.DataArray(np.arange(24), dims='time', coords={'time': pd.date_range('2000-01-01', periods=24)}) - x_new = pd.date_range('2000-01-02', periods=3) actual = da.interp(time=x_new) - expected = da.isel(time=[1, 2, 3]) + expected_da = xr.DataArray(np.atleast_1d(expected), dims=['time'], + coords={'time': (np.atleast_1d(x_new) + .astype('datetime64[ns]'))}) + + assert_allclose(actual, expected_da) + + +@requires_scipy +def test_datetime_single_string(): + da = xr.DataArray(np.arange(24), dims='time', + coords={'time': pd.date_range('2000-01-01', periods=24)}) + actual = da.interp(time='2000-01-01T12:00') + expected = xr.DataArray(0.5) + + assert_allclose(actual.drop('time'), expected) + + +@requires_cftime +@requires_scipy +def test_cftime(): + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = xr.cftime_range('2000-01-01T12:00:00', periods=3, freq='D') + actual = da.interp(time=times_new) + expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new], dims=['time']) + assert_allclose(actual, expected) - x_new = np.array([np.datetime64('2000-01-01T12:00'), - np.datetime64('2000-01-02T12:00')]) - actual = da.interp(time=x_new) - assert_allclose(actual.isel(time=0).drop('time'), - 0.5 * (da.isel(time=0) + da.isel(time=1))) - assert_allclose(actual.isel(time=1).drop('time'), - 0.5 * (da.isel(time=1) + da.isel(time=2))) + +@requires_cftime +@requires_scipy +def test_cftime_type_error(): + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = xr.cftime_range('2000-01-01T12:00:00', periods=3, freq='D', + calendar='noleap') + with pytest.raises(TypeError): + da.interp(time=times_new) + + +@requires_cftime +@requires_scipy +def test_cftime_list_of_strings(): + from cftime import DatetimeProlepticGregorian + + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = ['2000-01-01T12:00', '2000-01-02T12:00', '2000-01-03T12:00'] + actual = da.interp(time=times_new) + + times_new_array = _parse_array_of_cftime_strings( + np.array(times_new), DatetimeProlepticGregorian) + expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new_array], + dims=['time']) + + assert_allclose(actual, expected) + + +@requires_cftime +@requires_scipy +def test_cftime_single_string(): + from cftime import DatetimeProlepticGregorian + + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = '2000-01-01T12:00' + actual = da.interp(time=times_new) + + times_new_array = _parse_array_of_cftime_strings( + np.array(times_new), DatetimeProlepticGregorian) + expected = xr.DataArray(0.5, coords={'time': times_new_array}) + + assert_allclose(actual, 
expected) + + +@requires_scipy +def test_datetime_to_non_datetime_error(): + da = xr.DataArray(np.arange(24), dims='time', + coords={'time': pd.date_range('2000-01-01', periods=24)}) + with pytest.raises(TypeError): + da.interp(time=0.5) + + +@requires_cftime +@requires_scipy +def test_cftime_to_non_cftime_error(): + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + with pytest.raises(TypeError): + da.interp(time=0.5) diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 4d89be8ce55..300c490cff6 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -6,11 +6,11 @@ import xarray as xr from xarray.core import merge -from . import TestCase, raises_regex +from . import raises_regex from .test_dataset import create_test_data -class TestMergeInternals(TestCase): +class TestMergeInternals(object): def test_broadcast_dimension_size(self): actual = merge.broadcast_dimension_size( [xr.Variable('x', [1]), xr.Variable('y', [2, 1])]) @@ -25,7 +25,7 @@ def test_broadcast_dimension_size(self): [xr.Variable(('x', 'y'), [[1, 2]]), xr.Variable('y', [2])]) -class TestMergeFunction(TestCase): +class TestMergeFunction(object): def test_merge_arrays(self): data = create_test_data() actual = xr.merge([data.var1, data.var2]) @@ -130,7 +130,7 @@ def test_merge_no_conflicts_broadcast(self): assert expected.identical(actual) -class TestMergeMethod(TestCase): +class TestMergeMethod(object): def test_merge(self): data = create_test_data() @@ -195,7 +195,7 @@ def test_merge_compat(self): with pytest.raises(xr.MergeError): ds1.merge(ds2, compat='identical') - with raises_regex(ValueError, 'compat=\S+ invalid'): + with raises_regex(ValueError, 'compat=.* invalid'): ds1.merge(ds2, compat='foobar') def test_merge_auto_align(self): diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 5c7e384c789..47224e55473 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -93,14 +93,14 @@ def test_interpolate_pd_compat(): @requires_scipy -def test_scipy_methods_function(): - for method in ['barycentric', 'krog', 'pchip', 'spline', 'akima']: - kwargs = {} - # Note: Pandas does some wacky things with these methods and the full - # integration tests wont work. - da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) - actual = da.interpolate_na(method=method, dim='time', **kwargs) - assert (da.count('time') <= actual.count('time')).all() +@pytest.mark.parametrize('method', ['barycentric', 'krog', + 'pchip', 'spline', 'akima']) +def test_scipy_methods_function(method): + # Note: Pandas does some wacky things with these methods and the full + # integration tests wont work. 
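Editorial sketch, not part of the patch: a minimal call to the interpolate_na methods parametrized here, on made-up data; it assumes scipy is installed, as the requires_scipy decorator above implies.

import numpy as np
import xarray as xr

da = xr.DataArray([0.0, np.nan, np.nan, 3.0, np.nan, 5.0], dims='time',
                  coords={'time': np.arange(6)})

# Fill interior gaps along 'time' with a scipy-backed interpolant.
filled = da.interpolate_na(dim='time', method='pchip')

# Interpolation can only add valid points, never remove them.
assert (da.count('time') <= filled.count('time')).all()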
+ da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) + actual = da.interpolate_na(method=method, dim='time') + assert (da.count('time') <= actual.count('time')).all() @requires_scipy diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py index aed96f1acb6..d594e1dcd18 100644 --- a/xarray/tests/test_options.py +++ b/xarray/tests/test_options.py @@ -3,7 +3,10 @@ import pytest import xarray -from xarray.core.options import OPTIONS +from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.backends.file_manager import FILE_CACHE +from xarray.tests.test_dataset import create_test_data +from xarray import concat, merge def test_invalid_option_raises(): @@ -11,6 +14,51 @@ def test_invalid_option_raises(): xarray.set_options(not_a_valid_options=True) +def test_display_width(): + with pytest.raises(ValueError): + xarray.set_options(display_width=0) + with pytest.raises(ValueError): + xarray.set_options(display_width=-10) + with pytest.raises(ValueError): + xarray.set_options(display_width=3.5) + + +def test_arithmetic_join(): + with pytest.raises(ValueError): + xarray.set_options(arithmetic_join='invalid') + with xarray.set_options(arithmetic_join='exact'): + assert OPTIONS['arithmetic_join'] == 'exact' + + +def test_enable_cftimeindex(): + with pytest.raises(ValueError): + xarray.set_options(enable_cftimeindex=None) + with pytest.warns(FutureWarning, match='no-op'): + with xarray.set_options(enable_cftimeindex=True): + assert OPTIONS['enable_cftimeindex'] + + +def test_file_cache_maxsize(): + with pytest.raises(ValueError): + xarray.set_options(file_cache_maxsize=0) + original_size = FILE_CACHE.maxsize + with xarray.set_options(file_cache_maxsize=123): + assert FILE_CACHE.maxsize == 123 + assert FILE_CACHE.maxsize == original_size + + +def test_keep_attrs(): + with pytest.raises(ValueError): + xarray.set_options(keep_attrs='invalid_str') + with xarray.set_options(keep_attrs=True): + assert OPTIONS['keep_attrs'] + with xarray.set_options(keep_attrs=False): + assert not OPTIONS['keep_attrs'] + with xarray.set_options(keep_attrs='default'): + assert _get_keep_attrs(default=True) + assert not _get_keep_attrs(default=False) + + def test_nested_options(): original = OPTIONS['display_width'] with xarray.set_options(display_width=1): @@ -19,3 +67,105 @@ def test_nested_options(): assert OPTIONS['display_width'] == 2 assert OPTIONS['display_width'] == 1 assert OPTIONS['display_width'] == original + + +def create_test_dataset_attrs(seed=0): + ds = create_test_data(seed) + ds.attrs = {'attr1': 5, 'attr2': 'history', + 'attr3': {'nested': 'more_info'}} + return ds + + +def create_test_dataarray_attrs(seed=0, var='var1'): + da = create_test_data(seed)[var] + da.attrs = {'attr1': 5, 'attr2': 'history', + 'attr3': {'nested': 'more_info'}} + return da + + +class TestAttrRetention(object): + def test_dataset_attr_retention(self): + # Use .mean() for all tests: a typical reduction operation + ds = create_test_dataset_attrs() + original_attrs = ds.attrs + + # Test default behaviour + result = ds.mean() + assert result.attrs == {} + with xarray.set_options(keep_attrs='default'): + result = ds.mean() + assert result.attrs == {} + + with xarray.set_options(keep_attrs=True): + result = ds.mean() + assert result.attrs == original_attrs + + with xarray.set_options(keep_attrs=False): + result = ds.mean() + assert result.attrs == {} + + def test_dataarray_attr_retention(self): + # Use .mean() for all tests: a typical reduction operation + da = create_test_dataarray_attrs() + 
original_attrs = da.attrs + + # Test default behaviour + result = da.mean() + assert result.attrs == {} + with xarray.set_options(keep_attrs='default'): + result = da.mean() + assert result.attrs == {} + + with xarray.set_options(keep_attrs=True): + result = da.mean() + assert result.attrs == original_attrs + + with xarray.set_options(keep_attrs=False): + result = da.mean() + assert result.attrs == {} + + def test_groupby_attr_retention(self): + da = xarray.DataArray([1, 2, 3], [('x', [1, 1, 2])]) + da.attrs = {'attr1': 5, 'attr2': 'history', + 'attr3': {'nested': 'more_info'}} + original_attrs = da.attrs + + # Test default behaviour + result = da.groupby('x').sum(keep_attrs=True) + assert result.attrs == original_attrs + with xarray.set_options(keep_attrs='default'): + result = da.groupby('x').sum(keep_attrs=True) + assert result.attrs == original_attrs + + with xarray.set_options(keep_attrs=True): + result1 = da.groupby('x') + result = result1.sum() + assert result.attrs == original_attrs + + with xarray.set_options(keep_attrs=False): + result = da.groupby('x').sum() + assert result.attrs == {} + + def test_concat_attr_retention(self): + ds1 = create_test_dataset_attrs() + ds2 = create_test_dataset_attrs() + ds2.attrs = {'wrong': 'attributes'} + original_attrs = ds1.attrs + + # Test default behaviour of keeping the attrs of the first + # dataset in the supplied list + # global keep_attrs option current doesn't affect concat + result = concat([ds1, ds2], dim='dim1') + assert result.attrs == original_attrs + + @pytest.mark.xfail + def test_merge_attr_retention(self): + da1 = create_test_dataarray_attrs(var='var1') + da2 = create_test_dataarray_attrs(var='var2') + da2.attrs = {'wrong': 'attributes'} + original_attrs = da1.attrs + + # merge currently discards attrs, and the global keep_attrs + # option doesn't affect this + result = merge([da1, da2]) + assert result.attrs == original_attrs diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index f56d1f460f3..39fd55fece6 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -7,6 +7,7 @@ import pandas as pd import pytest +import xarray as xr import xarray.plot as xplt from xarray import DataArray, Dataset from xarray.coding.times import _import_cftime @@ -16,9 +17,8 @@ import_seaborn, label_from_attrs) from . 
import ( - TestCase, assert_array_equal, assert_equal, raises_regex, - requires_matplotlib, requires_matplotlib2, requires_seaborn, - requires_cftime) + assert_array_equal, assert_equal, raises_regex, requires_cftime, + requires_matplotlib, requires_matplotlib2, requires_seaborn) # import mpl and change the backend before other mpl imports try: @@ -64,8 +64,10 @@ def easy_array(shape, start=0, stop=1): @requires_matplotlib -class PlotTestCase(TestCase): - def tearDown(self): +class PlotTestCase(object): + @pytest.fixture(autouse=True) + def setup(self): + yield # Remove all matplotlib figures plt.close('all') @@ -87,7 +89,8 @@ def contourf_called(self, plotmethod): class TestPlot(PlotTestCase): - def setUp(self): + @pytest.fixture(autouse=True) + def setup_array(self): self.darray = DataArray(easy_array((2, 3, 4))) def test_label_from_attrs(self): @@ -159,8 +162,8 @@ def test_2d_line_accepts_legend_kw(self): self.darray[:, :, 0].plot.line(x='dim_0', add_legend=True) assert plt.gca().get_legend() # check whether legend title is set - assert plt.gca().get_legend().get_title().get_text() \ - == 'dim_1' + assert (plt.gca().get_legend().get_title().get_text() + == 'dim_1') def test_2d_line_accepts_x_kw(self): self.darray[:, :, 0].plot.line(x='dim_0') @@ -171,12 +174,31 @@ def test_2d_line_accepts_x_kw(self): def test_2d_line_accepts_hue_kw(self): self.darray[:, :, 0].plot.line(hue='dim_0') - assert plt.gca().get_legend().get_title().get_text() \ - == 'dim_0' + assert (plt.gca().get_legend().get_title().get_text() + == 'dim_0') plt.cla() self.darray[:, :, 0].plot.line(hue='dim_1') - assert plt.gca().get_legend().get_title().get_text() \ - == 'dim_1' + assert (plt.gca().get_legend().get_title().get_text() + == 'dim_1') + + def test_2d_coords_line_plot(self): + lon, lat = np.meshgrid(np.linspace(-20, 20, 5), + np.linspace(0, 30, 4)) + lon += lat / 10 + lat += lon / 10 + da = xr.DataArray(np.arange(20).reshape(4, 5), dims=['y', 'x'], + coords={'lat': (('y', 'x'), lat), + 'lon': (('y', 'x'), lon)}) + + hdl = da.plot.line(x='lon', hue='x') + assert len(hdl) == 5 + + plt.clf() + hdl = da.plot.line(x='lon', hue='y') + assert len(hdl) == 4 + + with pytest.raises(ValueError, message='If x or y are 2D '): + da.plot.line(x='lon', hue='lat') def test_2d_before_squeeze(self): a = DataArray(easy_array((1, 5))) @@ -267,6 +289,7 @@ def test_datetime_dimension(self): assert ax.has_data() @pytest.mark.slow + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid(self): a = easy_array((10, 15, 4)) d = DataArray(a, dims=['y', 'x', 'z']) @@ -328,6 +351,7 @@ def test_plot_size(self): self.darray.plot(aspect=1) @pytest.mark.slow + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid_4d(self): a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) @@ -340,8 +364,13 @@ def test_convenient_facetgrid_4d(self): with raises_regex(ValueError, '[Ff]acet'): d.plot(x='x', y='y', col='columns', ax=plt.gca()) + def test_coord_with_interval(self): + bins = [-1, 0, 1, 2] + self.darray.groupby_bins('dim_0', bins).mean(xr.ALL_DIMS).plot() + class TestPlot1D(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): d = [0, 1.1, 0, 2] self.darray = DataArray( @@ -354,7 +383,7 @@ def test_xlabel_is_index_name(self): def test_no_label_name_on_x_axis(self): self.darray.plot(y='period') - self.assertEqual('', plt.gca().get_xlabel()) + assert '' == plt.gca().get_xlabel() def test_no_label_name_on_y_axis(self): self.darray.plot() @@ -413,7 
+442,22 @@ def test_slice_in_title(self): assert 'd = 10' == title +class TestPlotStep(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self): + self.darray = DataArray(easy_array((2, 3, 4))) + + def test_step(self): + self.darray[0, 0].plot.step() + + def test_coord_with_interval_step(self): + bins = [-1, 0, 1, 2] + self.darray.groupby_bins('dim_0', bins).mean(xr.ALL_DIMS).plot.step() + assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + + class TestPlotHistogram(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): self.darray = DataArray(easy_array((2, 3, 4))) @@ -426,10 +470,6 @@ def test_xlabel_uses_name(self): self.darray.plot.hist() assert 'testpoints [testunits]' == plt.gca().get_xlabel() - def test_ylabel_is_count(self): - self.darray.plot.hist() - assert 'Count' == plt.gca().get_ylabel() - def test_title_is_histogram(self): self.darray.plot.hist() assert 'Histogram' == plt.gca().get_title() @@ -451,9 +491,14 @@ def test_plot_nans(self): self.darray[0, 0, 0] = np.nan self.darray.plot.hist() + def test_hist_coord_with_interval(self): + (self.darray.groupby_bins('dim_0', [-1, 0, 1, 2]).mean(xr.ALL_DIMS) + .plot.hist(range=(-1, 2))) + @requires_matplotlib -class TestDetermineCmapParams(TestCase): +class TestDetermineCmapParams(object): + @pytest.fixture(autouse=True) def setUp(self): self.data = np.linspace(0, 1, num=100) @@ -474,6 +519,21 @@ def test_center(self): assert cmap_params['levels'] is None assert cmap_params['norm'] is None + def test_cmap_sequential_option(self): + with xr.set_options(cmap_sequential='magma'): + cmap_params = _determine_cmap_params(self.data) + assert cmap_params['cmap'] == 'magma' + + def test_cmap_sequential_explicit_option(self): + with xr.set_options(cmap_sequential=mpl.cm.magma): + cmap_params = _determine_cmap_params(self.data) + assert cmap_params['cmap'] == mpl.cm.magma + + def test_cmap_divergent_option(self): + with xr.set_options(cmap_divergent='magma'): + cmap_params = _determine_cmap_params(self.data, center=0.5) + assert cmap_params['cmap'] == 'magma' + def test_nan_inf_are_ignored(self): cmap_params1 = _determine_cmap_params(self.data) data = self.data @@ -609,9 +669,30 @@ def test_divergentcontrol(self): assert cmap_params['vmax'] == 0.6 assert cmap_params['cmap'] == "viridis" + def test_norm_sets_vmin_vmax(self): + vmin = self.data.min() + vmax = self.data.max() + + for norm, extend in zip([mpl.colors.LogNorm(), + mpl.colors.LogNorm(vmin + 1, vmax - 1), + mpl.colors.LogNorm(None, vmax - 1), + mpl.colors.LogNorm(vmin + 1, None)], + ['neither', 'both', 'max', 'min']): + + test_min = vmin if norm.vmin is None else norm.vmin + test_max = vmax if norm.vmax is None else norm.vmax + + cmap_params = _determine_cmap_params(self.data, norm=norm) + + assert cmap_params['vmin'] == test_min + assert cmap_params['vmax'] == test_max + assert cmap_params['extend'] == extend + assert cmap_params['norm'] == norm + @requires_matplotlib -class TestDiscreteColorMap(TestCase): +class TestDiscreteColorMap(object): + @pytest.fixture(autouse=True) def setUp(self): x = np.arange(start=0, stop=10, step=2) y = np.arange(start=9, stop=-7, step=-3) @@ -645,10 +726,10 @@ def test_build_discrete_cmap(self): @pytest.mark.slow def test_discrete_colormap_list_of_levels(self): - for extend, levels in [('max', [-1, 2, 4, 8, 10]), ('both', - [2, 5, 10, 11]), - ('neither', [0, 5, 10, 15]), ('min', - [2, 5, 10, 15])]: + for extend, levels in [('max', [-1, 2, 4, 8, 10]), + ('both', [2, 5, 10, 11]), + ('neither', [0, 5, 10, 15]), + ('min', 
[2, 5, 10, 15])]: for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: primitive = getattr(self.darray.plot, kind)(levels=levels) assert_array_equal(levels, primitive.norm.boundaries) @@ -662,10 +743,10 @@ def test_discrete_colormap_list_of_levels(self): @pytest.mark.slow def test_discrete_colormap_int_levels(self): - for extend, levels, vmin, vmax in [('neither', 7, None, - None), ('neither', 7, None, 20), - ('both', 7, 4, 8), ('min', 10, 4, - 15)]: + for extend, levels, vmin, vmax in [('neither', 7, None, None), + ('neither', 7, None, 20), + ('both', 7, 4, 8), + ('min', 10, 4, 15)]: for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: primitive = getattr(self.darray.plot, kind)( levels=levels, vmin=vmin, vmax=vmax) @@ -691,8 +772,13 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self): assert primitive.norm.vmax == max(levels) assert primitive.norm.vmin == min(levels) + def test_discrete_colormap_provided_boundary_norm(self): + norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4) + primitive = self.darray.plot.contourf(norm=norm) + np.testing.assert_allclose(primitive.levels, norm.boundaries) + -class Common2dMixin: +class Common2dMixin(object): """ Common tests for 2d plotting go here. @@ -700,6 +786,7 @@ class Common2dMixin: Should have the same name as the method. """ + @pytest.fixture(autouse=True) def setUp(self): da = DataArray(easy_array((10, 15), start=-1), dims=['y', 'x'], @@ -710,7 +797,7 @@ def setUp(self): x, y = np.meshgrid(da.x.values, da.y.values) ds['x2d'] = DataArray(x, dims=['y', 'x']) ds['y2d'] = DataArray(y, dims=['y', 'x']) - ds.set_coords(['x2d', 'y2d'], inplace=True) + ds = ds.set_coords(['x2d', 'y2d']) # set darray and plot method self.darray = ds.testvar @@ -748,6 +835,24 @@ def test_nonnumeric_index_raises_typeerror(self): def test_can_pass_in_axis(self): self.pass_in_axis(self.plotmethod) + def test_xyincrease_defaults(self): + + # With default settings the axis must be ordered regardless + # of the coords order. 
+ self.plotfunc(DataArray(easy_array((3, 2)), coords=[[1, 2, 3], + [1, 2]])) + bounds = plt.gca().get_ylim() + assert bounds[0] < bounds[1] + bounds = plt.gca().get_xlim() + assert bounds[0] < bounds[1] + # Inverted coords + self.plotfunc(DataArray(easy_array((3, 2)), coords=[[3, 2, 1], + [2, 1]])) + bounds = plt.gca().get_ylim() + assert bounds[0] < bounds[1] + bounds = plt.gca().get_xlim() + assert bounds[0] < bounds[1] + def test_xyincrease_false_changes_axes(self): self.plotmethod(xincrease=False, yincrease=False) xlim = plt.gca().get_xlim() @@ -779,10 +884,13 @@ def test_plot_nans(self): clim2 = self.plotfunc(x2).get_clim() assert clim1 == clim2 + @pytest.mark.filterwarnings('ignore::UserWarning') + @pytest.mark.filterwarnings('ignore:invalid value encountered') def test_can_plot_all_nans(self): # regression test for issue #1780 self.plotfunc(DataArray(np.full((2, 2), np.nan))) + @pytest.mark.filterwarnings('ignore: Attempting to set') def test_can_plot_axis_size_one(self): if self.plotfunc.__name__ not in ('contour', 'contourf'): self.plotfunc(DataArray(np.ones((1, 1)))) @@ -974,6 +1082,7 @@ def test_2d_function_and_method_signature_same(self): del func_sig['darray'] assert func_sig == method_sig + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid(self): a = easy_array((10, 15, 4)) d = DataArray(a, dims=['y', 'x', 'z']) @@ -1005,6 +1114,7 @@ def test_convenient_facetgrid(self): else: assert '' == ax.get_xlabel() + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid_4d(self): a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) @@ -1014,6 +1124,19 @@ def test_convenient_facetgrid_4d(self): for ax in g.axes.flat: assert ax.has_data() + @pytest.mark.filterwarnings('ignore:This figure includes') + def test_facetgrid_map_only_appends_mappables(self): + a = easy_array((10, 15, 2, 3)) + d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) + g = self.plotfunc(d, x='x', y='y', col='columns', row='rows') + + expected = g._mappables + + g.map(lambda: plt.plot(1, 1)) + actual = g._mappables + + assert expected == actual + def test_facetgrid_cmap(self): # Regression test for GH592 data = (np.random.random(size=(20, 25, 12)) + np.linspace(-3, 3, 12)) @@ -1024,10 +1147,42 @@ def test_facetgrid_cmap(self): # check that all colormaps are the same assert len(set(m.get_cmap().name for m in fg._mappables)) == 1 + def test_facetgrid_cbar_kwargs(self): + a = easy_array((10, 15, 2, 3)) + d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) + g = self.plotfunc(d, x='x', y='y', col='columns', row='rows', + cbar_kwargs={'label': 'test_label'}) + + # catch contour case + if hasattr(g, 'cbar'): + assert g.cbar._label == 'test_label' + + def test_facetgrid_no_cbar_ax(self): + a = easy_array((10, 15, 2, 3)) + d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) + with pytest.raises(ValueError): + g = self.plotfunc(d, x='x', y='y', col='columns', row='rows', + cbar_ax=1) + def test_cmap_and_color_both(self): with pytest.raises(ValueError): self.plotmethod(colors='k', cmap='RdBu') + def test_2d_coord_with_interval(self): + for dim in self.darray.dims: + gp = self.darray.groupby_bins(dim, range(15)).mean(dim) + for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: + getattr(gp.plot, kind)() + + def test_colormap_error_norm_and_vmin_vmax(self): + norm = mpl.colors.LogNorm(0.1, 1e1) + + with pytest.raises(ValueError): + self.darray.plot(norm=norm, vmin=2) + + with pytest.raises(ValueError): + 
self.darray.plot(norm=norm, vmax=2) + @pytest.mark.slow class TestContourf(Common2dMixin, PlotTestCase): @@ -1090,23 +1245,23 @@ def test_colors(self): def _color_as_tuple(c): return tuple(c[:3]) + # with single color, we don't want rgb array artist = self.plotmethod(colors='k') - assert _color_as_tuple(artist.cmap.colors[0]) == \ - (0.0, 0.0, 0.0) + assert artist.cmap.colors[0] == 'k' artist = self.plotmethod(colors=['k', 'b']) - assert _color_as_tuple(artist.cmap.colors[1]) == \ - (0.0, 0.0, 1.0) + assert (_color_as_tuple(artist.cmap.colors[1]) == + (0.0, 0.0, 1.0)) artist = self.darray.plot.contour( levels=[-0.5, 0., 0.5, 1.], colors=['k', 'r', 'w', 'b']) - assert _color_as_tuple(artist.cmap.colors[1]) == \ - (1.0, 0.0, 0.0) - assert _color_as_tuple(artist.cmap.colors[2]) == \ - (1.0, 1.0, 1.0) + assert (_color_as_tuple(artist.cmap.colors[1]) == + (1.0, 0.0, 0.0)) + assert (_color_as_tuple(artist.cmap.colors[2]) == + (1.0, 1.0, 1.0)) # the last color is now under "over" - assert _color_as_tuple(artist.cmap._rgba_over) == \ - (0.0, 0.0, 1.0) + assert (_color_as_tuple(artist.cmap._rgba_over) == + (0.0, 0.0, 1.0)) def test_cmap_and_color_both(self): with pytest.raises(ValueError): @@ -1283,13 +1438,26 @@ def test_imshow_rgb_values_in_valid_range(self): assert out.dtype == np.uint8 assert (out[..., :3] == da.values).all() # Compare without added alpha + @pytest.mark.filterwarnings('ignore:Several dimensions of this array') def test_regression_rgb_imshow_dim_size_one(self): # Regression: https://github.com/pydata/xarray/issues/1966 da = DataArray(easy_array((1, 3, 3), start=0.0, stop=1.0)) da.plot.imshow() + def test_origin_overrides_xyincrease(self): + da = DataArray(easy_array((3, 2)), coords=[[-2, 0, 2], [-1, 1]]) + da.plot.imshow(origin='upper') + assert plt.xlim()[0] < 0 + assert plt.ylim()[1] < 0 + + plt.clf() + da.plot.imshow(origin='lower') + assert plt.xlim()[0] < 0 + assert plt.ylim()[0] < 0 + class TestFacetGrid(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): d = easy_array((10, 15, 3)) self.darray = DataArray( @@ -1461,7 +1629,9 @@ def test_num_ticks(self): @pytest.mark.slow def test_map(self): + assert self.g._finalized is False self.g.map(plt.contourf, 'x', 'y', Ellipsis) + assert self.g._finalized is True self.g.map(lambda: None) @pytest.mark.slow @@ -1515,7 +1685,9 @@ def test_facetgrid_polar(self): sharey=False) +@pytest.mark.filterwarnings('ignore:tight_layout cannot') class TestFacetGrid4d(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): a = easy_array((10, 15, 3, 2)) darray = DataArray(a, dims=['y', 'x', 'col', 'row']) @@ -1542,7 +1714,9 @@ def test_default_labels(self): assert substring_in_axes(label, ax) +@pytest.mark.filterwarnings('ignore:tight_layout cannot') class TestFacetedLinePlots(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): self.darray = DataArray(np.random.randn(10, 6, 3, 4), dims=['hue', 'x', 'col', 'row'], @@ -1690,6 +1864,7 @@ def test_not_same_dimensions(self): class TestDatetimePlot(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): ''' Create a DataArray with a time-axis that contains datetime objects. 
@@ -1742,3 +1917,68 @@ def test_plot_cftime_data_error(): data = DataArray(data, coords=[np.arange(5)], dims=['x']) with raises_regex(NotImplementedError, 'cftime.datetime'): data.plot() + + +test_da_list = [DataArray(easy_array((10, ))), + DataArray(easy_array((10, 3))), + DataArray(easy_array((10, 3, 2)))] + + +@requires_matplotlib +class TestAxesKwargs(object): + @pytest.mark.parametrize('da', test_da_list) + @pytest.mark.parametrize('xincrease', [True, False]) + def test_xincrease_kwarg(self, da, xincrease): + plt.clf() + da.plot(xincrease=xincrease) + assert plt.gca().xaxis_inverted() == (not xincrease) + + @pytest.mark.parametrize('da', test_da_list) + @pytest.mark.parametrize('yincrease', [True, False]) + def test_yincrease_kwarg(self, da, yincrease): + plt.clf() + da.plot(yincrease=yincrease) + assert plt.gca().yaxis_inverted() == (not yincrease) + + @pytest.mark.parametrize('da', test_da_list) + @pytest.mark.parametrize('xscale', ['linear', 'log', 'logit', 'symlog']) + def test_xscale_kwarg(self, da, xscale): + plt.clf() + da.plot(xscale=xscale) + assert plt.gca().get_xscale() == xscale + + @pytest.mark.parametrize('da', [DataArray(easy_array((10, ))), + DataArray(easy_array((10, 3)))]) + @pytest.mark.parametrize('yscale', ['linear', 'log', 'logit', 'symlog']) + def test_yscale_kwarg(self, da, yscale): + plt.clf() + da.plot(yscale=yscale) + assert plt.gca().get_yscale() == yscale + + @pytest.mark.parametrize('da', test_da_list) + def test_xlim_kwarg(self, da): + plt.clf() + expected = (0.0, 1000.0) + da.plot(xlim=[0, 1000]) + assert plt.gca().get_xlim() == expected + + @pytest.mark.parametrize('da', test_da_list) + def test_ylim_kwarg(self, da): + plt.clf() + da.plot(ylim=[0, 1000]) + expected = (0.0, 1000.0) + assert plt.gca().get_ylim() == expected + + @pytest.mark.parametrize('da', test_da_list) + def test_xticks_kwarg(self, da): + plt.clf() + da.plot(xticks=np.arange(5)) + expected = np.arange(5).tolist() + assert np.all(plt.gca().get_xticks() == expected) + + @pytest.mark.parametrize('da', test_da_list) + def test_yticks_kwarg(self, da): + plt.clf() + da.plot(yticks=np.arange(5)) + expected = np.arange(5) + assert np.all(plt.gca().get_yticks() == expected) diff --git a/xarray/tests/test_tutorial.py b/xarray/tests/test_tutorial.py index d550a85e8ce..6547311aa2f 100644 --- a/xarray/tests/test_tutorial.py +++ b/xarray/tests/test_tutorial.py @@ -2,15 +2,17 @@ import os +import pytest + from xarray import DataArray, tutorial from xarray.core.pycompat import suppress -from . import TestCase, assert_identical, network +from . 
import assert_identical, network @network -class TestLoadDataset(TestCase): - +class TestLoadDataset(object): + @pytest.fixture(autouse=True) def setUp(self): self.testfile = 'tiny' self.testfilepath = os.path.expanduser(os.sep.join( @@ -21,6 +23,11 @@ def setUp(self): os.remove('{}.md5'.format(self.testfilepath)) def test_download_from_github(self): - ds = tutorial.load_dataset(self.testfile) + ds = tutorial.open_dataset(self.testfile).load() tiny = DataArray(range(5), name='tiny').to_dataset() assert_identical(ds, tiny) + + def test_download_from_github_load_without_cache(self): + ds_nocache = tutorial.open_dataset(self.testfile, cache=False).load() + ds_cache = tutorial.open_dataset(self.testfile).load() + assert_identical(ds_cache, ds_nocache) diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 195bb36e36e..6941efb1c6e 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -8,9 +8,9 @@ import xarray as xr import xarray.ufuncs as xu -from . import ( - assert_array_equal, assert_identical as assert_identical_, mock, - raises_regex, requires_np113) +from . import assert_array_equal +from . import assert_identical as assert_identical_ +from . import mock, raises_regex, requires_np113 def assert_identical(a, b): diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index ed8045b78e4..ed07af0d7bb 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -6,19 +6,21 @@ import pandas as pd import pytest +import xarray as xr from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops, utils from xarray.core.options import set_options from xarray.core.pycompat import OrderedDict from xarray.core.utils import either_dict_or_kwargs +from xarray.testing import assert_identical from . 
import ( - TestCase, assert_array_equal, has_cftime, has_cftime_or_netCDF4, + assert_array_equal, has_cftime, has_cftime_or_netCDF4, requires_cftime, requires_dask) from .test_coding_times import _all_cftime_date_types -class TestAlias(TestCase): +class TestAlias(object): def test(self): def new_method(): pass @@ -44,19 +46,17 @@ def test_safe_cast_to_index(): @pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('enable_cftimeindex', [False, True]) -def test_safe_cast_to_index_cftimeindex(enable_cftimeindex): +def test_safe_cast_to_index_cftimeindex(): date_types = _all_cftime_date_types() for date_type in date_types.values(): dates = [date_type(1, 1, day) for day in range(1, 20)] - if enable_cftimeindex and has_cftime: + if has_cftime: expected = CFTimeIndex(dates) else: expected = pd.Index(dates) - with set_options(enable_cftimeindex=enable_cftimeindex): - actual = utils.safe_cast_to_index(np.array(dates)) + actual = utils.safe_cast_to_index(np.array(dates)) assert_array_equal(expected, actual) assert expected.dtype == actual.dtype assert isinstance(actual, type(expected)) @@ -64,13 +64,11 @@ def test_safe_cast_to_index_cftimeindex(enable_cftimeindex): # Test that datetime.datetime objects are never used in a CFTimeIndex @pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') -@pytest.mark.parametrize('enable_cftimeindex', [False, True]) -def test_safe_cast_to_index_datetime_datetime(enable_cftimeindex): +def test_safe_cast_to_index_datetime_datetime(): dates = [datetime(1, 1, day) for day in range(1, 20)] expected = pd.Index(dates) - with set_options(enable_cftimeindex=enable_cftimeindex): - actual = utils.safe_cast_to_index(np.array(dates)) + actual = utils.safe_cast_to_index(np.array(dates)) assert_array_equal(expected, actual) assert isinstance(actual, pd.Index) @@ -96,7 +94,7 @@ def test_multiindex_from_product_levels_non_unique(): np.testing.assert_array_equal(result.levels[1], [1, 2]) -class TestArrayEquiv(TestCase): +class TestArrayEquiv(object): def test_0d(self): # verify our work around for pd.isnull not working for 0-dimensional # object arrays @@ -106,8 +104,9 @@ def test_0d(self): assert not duck_array_ops.array_equiv(0, np.array(1, dtype=object)) -class TestDictionaries(TestCase): - def setUp(self): +class TestDictionaries(object): + @pytest.fixture(autouse=True) + def setup(self): self.x = {'a': 'A', 'b': 'B'} self.y = {'c': 'C', 'b': 'B'} self.z = {'a': 'Z'} @@ -174,7 +173,7 @@ def test_frozen(self): def test_sorted_keys_dict(self): x = {'a': 1, 'b': 2, 'c': 3} y = utils.SortedKeysDict(x) - self.assertItemsEqual(y, ['a', 'b', 'c']) + assert list(y) == ['a', 'b', 'c'] assert repr(utils.SortedKeysDict()) == \ "SortedKeysDict({})" @@ -189,7 +188,7 @@ def test_chain_map(self): m['x'] = 100 assert m['x'] == 100 assert m.maps[0]['x'] == 100 - self.assertItemsEqual(['x', 'y', 'z'], m) + assert set(m) == {'x', 'y', 'z'} def test_repr_object(): @@ -197,7 +196,23 @@ def test_repr_object(): assert repr(obj) == 'foo' -class Test_is_uniform_and_sorted(TestCase): +def test_is_remote_uri(): + assert utils.is_remote_uri('http://example.com') + assert utils.is_remote_uri('https://example.com') + assert not utils.is_remote_uri(' http://example.com') + assert not utils.is_remote_uri('example.nc') + + +def test_is_grib_path(): + assert not utils.is_grib_path('example.nc') + assert not utils.is_grib_path('example.grib ') + assert utils.is_grib_path('example.grib') + assert utils.is_grib_path('example.grib2') + assert 
utils.is_grib_path('example.grb') + assert utils.is_grib_path('example.grb2') + + +class Test_is_uniform_and_sorted(object): def test_sorted_uniform(self): assert utils.is_uniform_spaced(np.arange(5)) @@ -218,7 +233,7 @@ def test_relative_tolerance(self): assert utils.is_uniform_spaced([0, 0.97, 2], rtol=0.1) -class Test_hashable(TestCase): +class Test_hashable(object): def test_hashable(self): for v in [False, 1, (2, ), (3, 4), 'four']: @@ -263,3 +278,42 @@ def test_either_dict_or_kwargs(): with pytest.raises(ValueError, match=r'foo'): result = either_dict_or_kwargs(dict(a=1), dict(a=1), 'foo') + + +def test_datetime_to_numeric_datetime64(): + times = pd.date_range('2000', periods=5, freq='7D') + da = xr.DataArray(times, coords=[times], dims=['time']) + result = utils.datetime_to_numeric(da, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(0, 35, 7), coords=da.coords) + assert_identical(result, expected) + + offset = da.isel(time=1) + result = utils.datetime_to_numeric(da, offset=offset, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(-7, 28, 7), coords=da.coords) + assert_identical(result, expected) + + dtype = np.float32 + result = utils.datetime_to_numeric(da, datetime_unit='h', dtype=dtype) + expected = 24 * xr.DataArray( + np.arange(0, 35, 7), coords=da.coords).astype(dtype) + assert_identical(result, expected) + + +@requires_cftime +def test_datetime_to_numeric_cftime(): + times = xr.cftime_range('2000', periods=5, freq='7D') + da = xr.DataArray(times, coords=[times], dims=['time']) + result = utils.datetime_to_numeric(da, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(0, 35, 7), coords=da.coords) + assert_identical(result, expected) + + offset = da.isel(time=1) + result = utils.datetime_to_numeric(da, offset=offset, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(-7, 28, 7), coords=da.coords) + assert_identical(result, expected) + + dtype = np.float32 + result = utils.datetime_to_numeric(da, datetime_unit='h', dtype=dtype) + expected = 24 * xr.DataArray( + np.arange(0, 35, 7), coords=da.coords).astype(dtype) + assert_identical(result, expected) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 290c7a6e308..0bd440781ac 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1,12 +1,11 @@ from __future__ import absolute_import, division, print_function -from collections import namedtuple +import warnings from copy import copy, deepcopy from datetime import datetime, timedelta from distutils.version import LooseVersion from textwrap import dedent -import warnings import numpy as np import pandas as pd @@ -26,11 +25,11 @@ from xarray.tests import requires_bottleneck from . 
import ( - TestCase, assert_allclose, assert_array_equal, assert_equal, - assert_identical, raises_regex, requires_dask, source_ndarray) + assert_allclose, assert_array_equal, assert_equal, assert_identical, + raises_regex, requires_dask, source_ndarray) -class VariableSubclassTestCases(object): +class VariableSubclassobjects(object): def test_properties(self): data = 0.5 * np.arange(10) v = self.cls(['time'], data, {'foo': 'bar'}) @@ -480,20 +479,20 @@ def test_concat_mixed_dtypes(self): assert_identical(expected, actual) assert actual.dtype == object - def test_copy(self): + @pytest.mark.parametrize('deep', [True, False]) + def test_copy(self, deep): v = self.cls('x', 0.5 * np.arange(10), {'foo': 'bar'}) - for deep in [True, False]: - w = v.copy(deep=deep) - assert type(v) is type(w) - assert_identical(v, w) - assert v.dtype == w.dtype - if self.cls is Variable: - if deep: - assert source_ndarray(v.values) is not \ - source_ndarray(w.values) - else: - assert source_ndarray(v.values) is \ - source_ndarray(w.values) + w = v.copy(deep=deep) + assert type(v) is type(w) + assert_identical(v, w) + assert v.dtype == w.dtype + if self.cls is Variable: + if deep: + assert (source_ndarray(v.values) is not + source_ndarray(w.values)) + else: + assert (source_ndarray(v.values) is + source_ndarray(w.values)) assert_identical(v, copy(v)) def test_copy_index(self): @@ -506,6 +505,34 @@ def test_copy_index(self): assert isinstance(w.to_index(), pd.MultiIndex) assert_array_equal(v._data.array, w._data.array) + def test_copy_with_data(self): + orig = Variable(('x', 'y'), [[1.5, 2.0], [3.1, 4.3]], {'foo': 'bar'}) + new_data = np.array([[2.5, 5.0], [7.1, 43]]) + actual = orig.copy(data=new_data) + expected = orig.copy() + expected.data = new_data + assert_identical(expected, actual) + + def test_copy_with_data_errors(self): + orig = Variable(('x', 'y'), [[1.5, 2.0], [3.1, 4.3]], {'foo': 'bar'}) + new_data = [2.5, 5.0] + with raises_regex(ValueError, 'must match shape of object'): + orig.copy(data=new_data) + + def test_copy_index_with_data(self): + orig = IndexVariable('x', np.arange(5)) + new_data = np.arange(5, 10) + actual = orig.copy(data=new_data) + expected = orig.copy() + expected.data = new_data + assert_identical(expected, actual) + + def test_copy_index_with_data_errors(self): + orig = IndexVariable('x', np.arange(5)) + new_data = np.arange(5, 20) + with raises_regex(ValueError, 'must match shape of object'): + orig.copy(data=new_data) + def test_real_and_imag(self): v = self.cls('x', np.arange(3) - 1j * np.arange(3), {'foo': 'bar'}) expected_re = self.cls('x', np.arange(3), {'foo': 'bar'}) @@ -787,10 +814,11 @@ def test_rolling_window(self): v_loaded[0] = 1.0 -class TestVariable(TestCase, VariableSubclassTestCases): +class TestVariable(VariableSubclassobjects): cls = staticmethod(Variable) - def setUp(self): + @pytest.fixture(autouse=True) + def setup(self): self.d = np.random.random((10, 3)).astype(np.float64) def test_data_and_values(self): @@ -938,27 +966,14 @@ def test_as_variable(self): assert not isinstance(ds['x'], Variable) assert isinstance(as_variable(ds['x']), Variable) - FakeVariable = namedtuple('FakeVariable', 'values dims') - fake_xarray = FakeVariable(expected.values, expected.dims) - assert_identical(expected, as_variable(fake_xarray)) - - FakeVariable = namedtuple('FakeVariable', 'data dims') - fake_xarray = FakeVariable(expected.data, expected.dims) - assert_identical(expected, as_variable(fake_xarray)) - - FakeVariable = namedtuple('FakeVariable', - 'data values dims attrs 
encoding') - fake_xarray = FakeVariable(expected_extra.data, expected_extra.values, - expected_extra.dims, expected_extra.attrs, - expected_extra.encoding) - assert_identical(expected_extra, as_variable(fake_xarray)) - xarray_tuple = (expected_extra.dims, expected_extra.values, expected_extra.attrs, expected_extra.encoding) assert_identical(expected_extra, as_variable(xarray_tuple)) - with raises_regex(TypeError, 'tuples to convert'): + with raises_regex(TypeError, 'tuple of form'): as_variable(tuple(data)) + with raises_regex(ValueError, 'tuple of form'): # GH1016 + as_variable(('five', 'six', 'seven')) with raises_regex( TypeError, 'without an explicit list of dimensions'): as_variable(data) @@ -979,6 +994,13 @@ def test_as_variable(self): ValueError, 'has more than 1-dimension'): as_variable(expected, name='x') + # test datetime, timedelta conversion + dt = np.array([datetime(1999, 1, 1) + timedelta(days=x) + for x in range(10)]) + assert as_variable(dt, 'time').dtype.kind == 'M' + td = np.array([timedelta(days=x) for x in range(10)]) + assert as_variable(td, 'time').dtype.kind == 'm' + def test_repr(self): v = Variable(['time', 'x'], [[1, 2, 3], [4, 5, 6]], {'foo': 'bar'}) expected = dedent(""" @@ -1503,8 +1525,8 @@ def test_reduce_funcs(self): assert_identical(v.all(dim='x'), Variable([], False)) v = Variable('t', pd.date_range('2000-01-01', periods=3)) - with pytest.raises(NotImplementedError): - v.argmax(skipna=True) + assert v.argmax(skipna=True) == 2 + assert_identical( v.max(), Variable([], pd.Timestamp('2000-01-03'))) @@ -1639,7 +1661,7 @@ def assert_assigned_2d(array, key_x, key_y, values): @requires_dask -class TestVariableWithDask(TestCase, VariableSubclassTestCases): +class TestVariableWithDask(VariableSubclassobjects): cls = staticmethod(lambda *args: Variable(*args).chunk()) @pytest.mark.xfail @@ -1665,6 +1687,12 @@ def test_getitem_fancy(self): def test_getitem_1d_fancy(self): super(TestVariableWithDask, self).test_getitem_1d_fancy() + def test_equals_all_dtypes(self): + import dask + if '0.18.2' <= LooseVersion(dask.__version__) < '0.19.1': + pytest.xfail('https://github.com/pydata/xarray/issues/2318') + super(TestVariableWithDask, self).test_equals_all_dtypes() + def test_getitem_with_mask_nd_indexer(self): import dask.array as da v = Variable(['x'], da.arange(3, chunks=3)) @@ -1673,7 +1701,7 @@ def test_getitem_with_mask_nd_indexer(self): self.cls(('x', 'y'), [[0, -1], [-1, 2]])) -class TestIndexVariable(TestCase, VariableSubclassTestCases): +class TestIndexVariable(VariableSubclassobjects): cls = staticmethod(IndexVariable) def test_init(self): @@ -1786,7 +1814,7 @@ def test_rolling_window(self): super(TestIndexVariable, self).test_rolling_window() -class TestAsCompatibleData(TestCase): +class TestAsCompatibleData(object): def test_unchanged_types(self): types = (np.asarray, PandasIndexAdapter, LazilyOuterIndexedArray) for t in types: @@ -1927,9 +1955,10 @@ def test_raise_no_warning_for_nan_in_binary_ops(): assert len(record) == 0 -class TestBackendIndexing(TestCase): +class TestBackendIndexing(object): """ Make sure all the array wrappers can be indexed. 
""" + @pytest.fixture(autouse=True) def setUp(self): self.d = np.random.random((10, 3)).astype(np.float64) diff --git a/xarray/tutorial.py b/xarray/tutorial.py index 83a8317f42b..064eed330cc 100644 --- a/xarray/tutorial.py +++ b/xarray/tutorial.py @@ -9,6 +9,7 @@ import hashlib import os as _os +import warnings from .backends.api import open_dataset as _open_dataset from .core.pycompat import urlretrieve as _urlretrieve @@ -24,7 +25,7 @@ def file_md5_checksum(fname): # idea borrowed from Seaborn -def load_dataset(name, cache=True, cache_dir=_default_cache_dir, +def open_dataset(name, cache=True, cache_dir=_default_cache_dir, github_url='https://github.com/pydata/xarray-data', branch='master', **kws): """ @@ -48,6 +49,10 @@ def load_dataset(name, cache=True, cache_dir=_default_cache_dir, kws : dict, optional Passed to xarray.open_dataset + See Also + -------- + xarray.open_dataset + """ longdir = _os.path.expanduser(cache_dir) fullname = name + '.nc' @@ -77,9 +82,27 @@ def load_dataset(name, cache=True, cache_dir=_default_cache_dir, """ raise IOError(msg) - ds = _open_dataset(localfile, **kws).load() + ds = _open_dataset(localfile, **kws) if not cache: + ds = ds.load() _os.remove(localfile) return ds + + +def load_dataset(*args, **kwargs): + """ + `load_dataset` will be removed in version 0.12. The current behavior of + this function can be achived by using `tutorial.open_dataset(...).load()`. + + See Also + -------- + open_dataset + """ + warnings.warn( + "load_dataset` will be removed in xarray version 0.12. The current " + "behavior of this function can be achived by using " + "`tutorial.open_dataset(...).load()`.", + DeprecationWarning, stacklevel=2) + return open_dataset(*args, **kwargs).load() diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py index 478b867b0af..5459e67e603 100755 --- a/xarray/util/print_versions.py +++ b/xarray/util/print_versions.py @@ -44,7 +44,7 @@ def get_sys_info(): (sysname, nodename, release, version, machine, processor) = platform.uname() blob.extend([ - ("python", "%d.%d.%d.%s.%s" % sys.version_info[:]), + ("python", sys.version), ("python-bits", struct.calcsize("P") * 8), ("OS", "%s" % (sysname)), ("OS-release", "%s" % (release)), @@ -63,9 +63,27 @@ def get_sys_info(): return blob +def netcdf_and_hdf5_versions(): + libhdf5_version = None + libnetcdf_version = None + try: + import netCDF4 + libhdf5_version = netCDF4.__hdf5libversion__ + libnetcdf_version = netCDF4.__netcdf4libversion__ + except ImportError: + try: + import h5py + libhdf5_version = h5py.__hdf5libversion__ + except ImportError: + pass + return [('libhdf5', libhdf5_version), ('libnetcdf', libnetcdf_version)] + + def show_versions(as_json=False): sys_info = get_sys_info() + sys_info.extend(netcdf_and_hdf5_versions()) + deps = [ # (MODULE_NAME, f(mod) -> mod version) ("xarray", lambda mod: mod.__version__), @@ -74,11 +92,16 @@ def show_versions(as_json=False): ("scipy", lambda mod: mod.__version__), # xarray optionals ("netCDF4", lambda mod: mod.__version__), - # ("pydap", lambda mod: mod.version.version), + ("pydap", lambda mod: mod.__version__), ("h5netcdf", lambda mod: mod.__version__), ("h5py", lambda mod: mod.__version__), ("Nio", lambda mod: mod.__version__), ("zarr", lambda mod: mod.__version__), + ("cftime", lambda mod: mod.__version__), + ("PseudonetCDF", lambda mod: mod.__version__), + ("rasterio", lambda mod: mod.__version__), + ("cfgrib", lambda mod: mod.__version__), + ("iris", lambda mod: mod.__version__), ("bottleneck", lambda mod: mod.__version__), 
("cyordereddict", lambda mod: mod.__version__), ("dask", lambda mod: mod.__version__), @@ -103,10 +126,14 @@ def show_versions(as_json=False): mod = sys.modules[modname] else: mod = importlib.import_module(modname) - ver = ver_f(mod) - deps_blob.append((modname, ver)) except Exception: deps_blob.append((modname, None)) + else: + try: + ver = ver_f(mod) + deps_blob.append((modname, ver)) + except Exception: + deps_blob.append((modname, 'installed')) if (as_json): try: